diff --git a/autoPyTorch/pipeline/components/training/losses.py b/autoPyTorch/pipeline/components/training/losses.py
index 84d914f1c..de4578fbd 100644
--- a/autoPyTorch/pipeline/components/training/losses.py
+++ b/autoPyTorch/pipeline/components/training/losses.py
@@ -1,3 +1,15 @@
+"""
+Loss functions available in autoPyTorch
+
+Classification:
+    CrossEntropyLoss: supports multiclass, binary output types
+    BCEWithLogitsLoss: supports binary output types
+    Default: CrossEntropyLoss
+Regression:
+    MSELoss: supports continuous output types
+    L1Loss: supports continuous output types
+    Default: MSELoss
+"""
 from typing import Any, Dict, Optional, Type
 
 from torch.nn.modules.loss import (
@@ -11,21 +23,30 @@
 from autoPyTorch.constants import BINARY, CLASSIFICATION_TASKS, CONTINUOUS, MULTICLASS, REGRESSION_TASKS, \
     STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES, TASK_TYPES_TO_STRING
 
+
 losses = dict(classification=dict(
     CrossEntropyLoss=dict(
-        module=CrossEntropyLoss, supported_output_type=MULTICLASS),
+        module=CrossEntropyLoss, supported_output_types=[MULTICLASS, BINARY]),
     BCEWithLogitsLoss=dict(
-        module=BCEWithLogitsLoss, supported_output_type=BINARY)),
+        module=BCEWithLogitsLoss, supported_output_types=[BINARY])),
     regression=dict(
         MSELoss=dict(
-            module=MSELoss, supported_output_type=CONTINUOUS),
+            module=MSELoss, supported_output_types=[CONTINUOUS]),
         L1Loss=dict(
-            module=L1Loss, supported_output_type=CONTINUOUS)))
+            module=L1Loss, supported_output_types=[CONTINUOUS])))
 
 default_losses = dict(classification=CrossEntropyLoss, regression=MSELoss)
 
 
 def get_default(task: int) -> Type[Loss]:
+    """
+    Utility function to get default loss for the task
+    Args:
+        task (int):
+
+    Returns:
+        Type[torch.nn.modules.loss._Loss]
+    """
     if task in CLASSIFICATION_TASKS:
         return default_losses['classification']
     elif task in REGRESSION_TASKS:
@@ -35,19 +56,42 @@ def get_default(task: int) -> Type[Loss]:
 
 
 def get_supported_losses(task: int, output_type: int) -> Dict[str, Type[Loss]]:
+    """
+    Utility function to get supported losses for a given task and output type
+    Args:
+        task (int): integer identifier for the task
+        output_type: integer identifier for the output type of the task
+
+    Returns:
+        Returns a dictionary containing the losses supported for the given
+        inputs. Key-Name, Value-Module
+    """
     supported_losses = dict()
     if task in CLASSIFICATION_TASKS:
         for key, value in losses['classification'].items():
-            if output_type == value['supported_output_type']:
+            if output_type in value['supported_output_types']:
                 supported_losses[key] = value['module']
     elif task in REGRESSION_TASKS:
         for key, value in losses['regression'].items():
-            if output_type == value['supported_output_type']:
+            if output_type in value['supported_output_types']:
                 supported_losses[key] = value['module']
 
     return supported_losses
 
 
-def get_loss_instance(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
+def get_loss(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
+    """
+    Utility function to get losses for the given dataset properties.
+    If name is mentioned, checks if the loss is compatible with
+    the dataset properties and returns the specific loss
+    Args:
+        dataset_properties (Dict[str, Any]): Dictionary containing
+            properties of the dataset. Must contain task_type and
+            output_type as strings.
+        name (Optional[str]): name of the specific loss
+
+    Returns:
+        Type[torch.nn.modules.loss._Loss]
+    """
     assert 'task_type' in dataset_properties, \
         "Expected dataset_properties to have task_type got {}".format(dataset_properties.keys())
     assert 'output_type' in dataset_properties, \
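A minimal usage sketch of the reworked helper above (editorial illustration, not part of the patch; it relies only on the get_loss signature and the defaults documented in the module docstring):

    from autoPyTorch.pipeline.components.training.losses import get_loss

    # task_type / output_type are the string identifiers the asserts above expect
    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}

    # supported_output_types is now a list, so CrossEntropyLoss also qualifies
    # for binary targets; both calls return a loss *class*, not an instance
    default_loss_cls = get_loss(dataset_properties)                   # CrossEntropyLoss per the docstring
    bce_cls = get_loss(dataset_properties, name='BCEWithLogitsLoss')
    criterion = bce_cls()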
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index 26109fca6..ee8ea87bf 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -1,5 +1,5 @@
 import time
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 
@@ -10,6 +10,7 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.tensorboard.writer import SummaryWriter
 
+
 from autoPyTorch.constants import REGRESSION_TASKS
 from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
@@ -173,14 +174,13 @@ def prepare(
         self,
         metrics: List[Any],
         model: torch.nn.Module,
-        criterion: torch.nn.Module,
+        criterion: Type[torch.nn.Module],
         budget_tracker: BudgetTracker,
         optimizer: Optimizer,
         device: torch.device,
         metrics_during_training: bool,
         scheduler: _LRScheduler,
         task_type: int,
-        output_type: int,
         labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
     ) -> None:
 
@@ -191,19 +191,12 @@ def prepare(
         self.metrics = metrics
 
         # Weights for the loss function
-        weights = None
-        kwargs: Dict[str, Any] = {}
-        # if self.weighted_loss:
-        #     weights = self.get_class_weights(output_type, labels)
-        #     if output_type == BINARY:
-        #         kwargs['pos_weight'] = weights
-        #         pass
-        #     else:
-        #         kwargs['weight'] = weights
+        kwargs = {}
+        if self.weighted_loss:
+            kwargs = self.get_class_weights(criterion, labels)
 
         # Setup the loss function
-        self.criterion = criterion(**kwargs) if weights is not None else criterion()
-
+        self.criterion = criterion(**kwargs)
         # setup the model
         self.model = model.to(device)
 
@@ -384,13 +377,16 @@ def compute_metrics(self, outputs_data: np.ndarray, targets_data: np.ndarray
         targets_data = torch.cat(targets_data, dim=0).numpy()
         return calculate_score(targets_data, outputs_data, self.task_type, self.metrics)
 
-    def get_class_weights(self, output_type: int, labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
-                          ) -> np.ndarray:
-        strategy = get_loss_weight_strategy(output_type)
+    def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
+                          ) -> Dict[str, np.ndarray]:
+        strategy = get_loss_weight_strategy(criterion)
         weights = strategy(y=labels)
         weights = torch.from_numpy(weights)
         weights = weights.float().to(self.device)
-        return weights
+        if criterion.__name__ == 'BCEWithLogitsLoss':
+            return {'pos_weight': weights}
+        else:
+            return {'weight': weights}
 
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
                          ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py
index 667dd1ac6..88f7dd963 100755
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py
@@ -19,14 +19,14 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.tensorboard.writer import SummaryWriter
 
-from autoPyTorch.constants import STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES
+from autoPyTorch.constants import STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
 from autoPyTorch.pipeline.components.base_component import (
     ThirdPartyComponents,
     autoPyTorchComponent,
     find_components,
 )
-from autoPyTorch.pipeline.components.training.losses import get_loss_instance
+from autoPyTorch.pipeline.components.training.losses import get_loss
 from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics
 from autoPyTorch.pipeline.components.training.trainer.base_trainer import (
     BaseTrainerComponent,
@@ -265,15 +265,14 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Modu
             model=X['network'],
             metrics=get_metrics(dataset_properties=X['dataset_properties'],
                                 names=additional_metrics),
-            criterion=get_loss_instance(X['dataset_properties'],
-                                        name=additional_losses),
+            criterion=get_loss(X['dataset_properties'],
+                               name=additional_losses),
             budget_tracker=budget_tracker,
             optimizer=X['optimizer'],
             device=get_device_from_fit_dictionary(X),
             metrics_during_training=X['metrics_during_training'],
             scheduler=X['lr_scheduler'],
             task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
-            output_type=STRING_TO_OUTPUT_TYPES[X['dataset_properties']['output_type']],
             labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
         )
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
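Sketch of the weighting contract that prepare() now relies on (editorial illustration, not part of the patch): the criterion arrives as a class, get_class_weights() returns keyword arguments keyed by that class, and the criterion is instantiated with them. The weight values below are made up:

    import torch
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

    def build_criterion(criterion_cls, kwargs):
        # mirrors `self.criterion = criterion(**kwargs)` in prepare()
        return criterion_cls(**kwargs)

    # the two shapes of kwargs that get_class_weights() can hand back
    bce = build_criterion(BCEWithLogitsLoss, {'pos_weight': torch.tensor([3.0])})
    ce = build_criterion(CrossEntropyLoss, {'weight': torch.tensor([0.25, 0.75])})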
diff --git a/autoPyTorch/utils/implementations.py b/autoPyTorch/utils/implementations.py
index 15f1758e1..2130cfd6b 100644
--- a/autoPyTorch/utils/implementations.py
+++ b/autoPyTorch/utils/implementations.py
@@ -1,17 +1,24 @@
-from typing import Callable, Union
+from typing import Any, Callable, Dict, Type, Union
 
 import numpy as np
 
 import torch
 
-from autoPyTorch.constants import BINARY
-
-def get_loss_weight_strategy(output_type: int) -> Callable:
-    if output_type == BINARY:
+def get_loss_weight_strategy(loss: Type[torch.nn.Module]) -> Callable:
+    """
+    Utility function that returns strategy for the given loss
+    Args:
+        loss (Type[torch.nn.Module]): type of the loss function
+    Returns:
+        (Callable): Relevant Callable strategy
+    """
+    if loss.__name__ in LossWeightStrategyWeightedBinary.get_properties()['supported_losses']:
         return LossWeightStrategyWeightedBinary()
-    else:
+    elif loss.__name__ in LossWeightStrategyWeighted.get_properties()['supported_losses']:
         return LossWeightStrategyWeighted()
+    else:
+        raise ValueError("No strategy currently supports the given loss, {}".format(loss.__name__))
 
 
 class LossWeightStrategyWeighted():
@@ -34,6 +41,10 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
 
         return weights
 
+    @staticmethod
+    def get_properties() -> Dict[str, Any]:
+        return {'supported_losses': ['CrossEntropyLoss']}
+
 
 class LossWeightStrategyWeightedBinary():
     def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
@@ -46,3 +57,7 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
         weights = counts_zero / np.maximum(counts_one, 1)
 
         return np.array(weights)
+
+    @staticmethod
+    def get_properties() -> Dict[str, Any]:
+        return {'supported_losses': ['BCEWithLogitsLoss']}
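A short sketch of the new dispatch (editorial illustration, not part of the patch); the label array is arbitrary and only the helpers shown above are used:

    import numpy as np
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss

    from autoPyTorch.utils.implementations import get_loss_weight_strategy

    # dispatch happens on the class name, so the loss *type* is passed in
    binary_strategy = get_loss_weight_strategy(BCEWithLogitsLoss)    # LossWeightStrategyWeightedBinary
    weighted_strategy = get_loss_weight_strategy(CrossEntropyLoss)   # LossWeightStrategyWeighted

    weights = binary_strategy(y=np.array([0, 0, 0, 1]))              # positive-class weight(s) as an np.ndarray

    # unsupported losses now fail loudly instead of silently falling back
    try:
        get_loss_weight_strategy(L1Loss)
    except ValueError:
        pass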
diff --git a/test/conftest.py b/test/conftest.py
index a5d0fe0af..f05f573a7 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -14,6 +14,8 @@
 from sklearn.datasets import fetch_openml, make_classification, make_regression
 
+import torch
+
 from autoPyTorch.data.tabular_validator import TabularInputValidator
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.utils.backend import create
@@ -357,3 +359,61 @@ def error_search_space_updates():
                    value_range=[0, 0.5],
                    default_value=0.2)
     return updates
+
+
+@pytest.fixture
+def loss_cross_entropy_multiclass():
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'multiclass'}
+    predictions = torch.randn(4, 4, requires_grad=True)
+    name = 'CrossEntropyLoss'
+    targets = torch.empty(4, dtype=torch.long).random_(4)
+    # to ensure we have all classes in the labels
+    while True:
+        labels = torch.empty(20, dtype=torch.long).random_(4)
+        if len(torch.unique(labels)) == 4:
+            break
+
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_cross_entropy_binary():
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
+    predictions = torch.randn(4, 2, requires_grad=True)
+    name = 'CrossEntropyLoss'
+    targets = torch.empty(4, dtype=torch.long).random_(2)
+    # to ensure we have all classes in the labels
+    while True:
+        labels = torch.empty(20, dtype=torch.long).random_(2)
+        if len(torch.unique(labels)) == 2:
+            break
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_bce():
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
+    predictions = torch.empty(4).random_(2)
+    name = 'BCEWithLogitsLoss'
+    targets = torch.empty(4).random_(2)
+    # to ensure we have all classes in the labels
+    while True:
+        labels = torch.empty(20, dtype=torch.long).random_(2)
+        if len(torch.unique(labels)) == 2:
+            break
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_mse():
+    dataset_properties = {'task_type': 'tabular_regression', 'output_type': 'continuous'}
+    predictions = torch.randn(4)
+    name = 'MSELoss'
+    targets = torch.randn(4)
+    labels = None
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_details(request):
+    return request.getfixturevalue(request.param)
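A small sketch of how the indirect fixture above is meant to be consumed (editorial illustration; the test name is hypothetical, the pattern mirrors test_losses.py further below):

    import pytest

    @pytest.mark.parametrize('loss_details', ['loss_mse', 'loss_bce'], indirect=True)
    def test_unpacks_loss_fixture(loss_details):
        # pytest routes each string id through request.getfixturevalue in `loss_details`
        dataset_properties, predictions, name, targets, labels = loss_details
        assert 'task_type' in dataset_properties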
diff --git a/test/test_pipeline/components/base.py b/test/test_pipeline/components/base.py
index 8adbbd48a..8211172e7 100644
--- a/test/test_pipeline/components/base.py
+++ b/test/test_pipeline/components/base.py
@@ -97,7 +97,6 @@ def prepare_trainer(self,
             device=device,
             metrics_during_training=True,
             task_type=task_type,
-            output_type=output_type,
             labels=y
         )
         return trainer, model, optimizer, loader, criterion, epochs, logger
diff --git a/test/test_pipeline/components/preprocessing/__init__.py b/test/test_pipeline/components/preprocessing/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/test_pipeline/components/test_encoder_choice.py b/test/test_pipeline/components/preprocessing/test_encoder_choice.py
similarity index 100%
rename from test/test_pipeline/components/test_encoder_choice.py
rename to test/test_pipeline/components/preprocessing/test_encoder_choice.py
diff --git a/test/test_pipeline/components/test_encoders.py b/test/test_pipeline/components/preprocessing/test_encoders.py
similarity index 100%
rename from test/test_pipeline/components/test_encoders.py
rename to test/test_pipeline/components/preprocessing/test_encoders.py
diff --git a/test/test_pipeline/components/test_feature_preprocessor.py b/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
similarity index 100%
rename from test/test_pipeline/components/test_feature_preprocessor.py
rename to test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
diff --git a/test/test_pipeline/components/test_feature_preprocessor_choice.py b/test/test_pipeline/components/preprocessing/test_feature_preprocessor_choice.py
similarity index 100%
rename from test/test_pipeline/components/test_feature_preprocessor_choice.py
rename to test/test_pipeline/components/preprocessing/test_feature_preprocessor_choice.py
diff --git a/test/test_pipeline/components/test_imputers.py b/test/test_pipeline/components/preprocessing/test_imputers.py
similarity index 100%
rename from test/test_pipeline/components/test_imputers.py
rename to test/test_pipeline/components/preprocessing/test_imputers.py
diff --git a/test/test_pipeline/components/test_normalizer_choice.py b/test/test_pipeline/components/preprocessing/test_normalizer_choice.py
similarity index 100%
rename from test/test_pipeline/components/test_normalizer_choice.py
rename to test/test_pipeline/components/preprocessing/test_normalizer_choice.py
diff --git a/test/test_pipeline/components/test_normalizers.py b/test/test_pipeline/components/preprocessing/test_normalizers.py
similarity index 100%
rename from test/test_pipeline/components/test_normalizers.py
rename to test/test_pipeline/components/preprocessing/test_normalizers.py
diff --git a/test/test_pipeline/components/test_scaler_choice.py b/test/test_pipeline/components/preprocessing/test_scaler_choice.py
similarity index 100%
rename from test/test_pipeline/components/test_scaler_choice.py
rename to test/test_pipeline/components/preprocessing/test_scaler_choice.py
diff --git a/test/test_pipeline/components/test_scalers.py b/test/test_pipeline/components/preprocessing/test_scalers.py
similarity index 100%
rename from test/test_pipeline/components/test_scalers.py
rename to test/test_pipeline/components/preprocessing/test_scalers.py
diff --git a/test/test_pipeline/components/test_tabular_column_transformer.py b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py
similarity index 100%
rename from test/test_pipeline/components/test_tabular_column_transformer.py
rename to test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py
diff --git a/test/test_pipeline/components/setup/__init__.py b/test/test_pipeline/components/setup/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/test_pipeline/components/test_setup.py b/test/test_pipeline/components/setup/test_setup.py
similarity index 72%
rename from test/test_pipeline/components/test_setup.py
rename to test/test_pipeline/components/setup/test_setup.py
index 07e2f2f03..9349c4ac8 100644
--- a/test/test_pipeline/components/test_setup.py
+++ b/test/test_pipeline/components/setup/test_setup.py
@@ -1,9 +1,10 @@
 import copy
-import unittest.mock
 from typing import Any, Dict, Optional, Tuple
 
 from ConfigSpace.configuration_space import ConfigurationSpace
 
+import pytest
+
 from sklearn.base import clone
 
 import torch
@@ -33,6 +34,7 @@
     BaseOptimizerComponent,
     OptimizerChoice
 )
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 
 
 class DummyLR(BaseLRComponent):
@@ -117,7 +119,7 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]]
         return ConfigurationSpace()
 
 
-class SchedulerTest(unittest.TestCase):
+class TestScheduler:
     def test_every_scheduler_is_valid(self):
         """
         Makes sure that every scheduler is a valid estimator.
@@ -129,7 +131,7 @@ def test_every_scheduler_is_valid(self):
         scheduler_choice = SchedulerChoice(dataset_properties={})
 
         # Make sure all components are returned
-        self.assertEqual(len(scheduler_choice.get_components().keys()), 7)
+        assert len(scheduler_choice.get_components().keys()) == 7
 
         # For every scheduler in the components, make sure
         # that it complies with the scikit learn estimator.
@@ -143,7 +145,7 @@ def test_every_scheduler_is_valid(self):
 
             # Make sure all keys are copied properly
             for k, v in estimator.get_params().items():
-                self.assertIn(k, estimator_clone_params)
+                assert k in estimator_clone_params
 
             # Make sure the params getter of estimator are honored
             klass = estimator.__class__
@@ -156,7 +158,7 @@ def test_every_scheduler_is_valid(self):
             for name in new_object_params:
                 param1 = new_object_params[name]
                 param2 = params_set[name]
-                self.assertEqual(param1, param2)
+                assert param1 == param2
 
     def test_get_set_config_space(self):
         """Make sure that we can setup a valid choice in the scheduler
@@ -165,10 +167,8 @@ def test_get_set_config_space(self):
         cs = scheduler_choice.get_hyperparameter_search_space()
 
         # Make sure that all hyperparameters are part of the serach space
-        self.assertListEqual(
-            sorted(cs.get_hyperparameter('__choice__').choices),
-            sorted(list(scheduler_choice.get_components().keys()))
-        )
+        assert sorted(cs.get_hyperparameter('__choice__').choices) == \
+            sorted(list(scheduler_choice.get_components().keys()))
 
         # Make sure we can properly set some random configs
         # Whereas just one iteration will make sure the algorithm works,
@@ -179,8 +179,8 @@ def test_get_set_config_space(self):
             config_dict = copy.deepcopy(config.get_dictionary())
             scheduler_choice.set_hyperparameters(config)
 
-            self.assertEqual(scheduler_choice.choice.__class__,
-                             scheduler_choice.get_components()[config_dict['__choice__']])
+            assert scheduler_choice.choice.__class__ == \
+                scheduler_choice.get_components()[config_dict['__choice__']]
 
             # Then check the choice configuration
             selected_choice = config_dict.pop('__choice__', None)
@@ -188,22 +188,22 @@ def test_get_set_config_space(self):
                 # Remove the selected_choice string from the parameter
                 # so we can query in the object for it
                 key = key.replace(selected_choice + ':', '')
-                self.assertIn(key, vars(scheduler_choice.choice))
-                self.assertEqual(value, scheduler_choice.choice.__dict__[key])
+                assert key in vars(scheduler_choice.choice)
+                assert value == scheduler_choice.choice.__dict__[key]
 
     def test_scheduler_add(self):
         """Makes sure that a component can be added to the CS"""
         # No third party components to start with
-        self.assertEqual(len(lr_components._addons.components), 0)
+        assert len(lr_components._addons.components) == 0
 
         # Then make sure the scheduler can be added and query'ed
         lr_components.add_scheduler(DummyLR)
-        self.assertEqual(len(lr_components._addons.components), 1)
+        assert len(lr_components._addons.components) == 1
         cs = SchedulerChoice(dataset_properties={}).get_hyperparameter_search_space()
-        self.assertIn('DummyLR', str(cs))
+        assert 'DummyLR' in str(cs)
 
 
-class OptimizerTest(unittest.TestCase):
+class TestOptimizer:
     def test_every_optimizer_is_valid(self):
         """
         Makes sure that every optimizer is a valid estimator.
@@ -215,7 +215,7 @@ def test_every_optimizer_is_valid(self):
         optimizer_choice = OptimizerChoice(dataset_properties={})
 
         # Make sure all components are returned
-        self.assertEqual(len(optimizer_choice.get_components().keys()), 4)
+        assert len(optimizer_choice.get_components().keys()) == 4
 
         # For every optimizer in the components, make sure
         # that it complies with the scikit learn estimator.
@@ -229,7 +229,7 @@ def test_every_optimizer_is_valid(self):
 
             # Make sure all keys are copied properly
             for k, v in estimator.get_params().items():
-                self.assertIn(k, estimator_clone_params)
+                assert k in estimator_clone_params
 
             # Make sure the params getter of estimator are honored
             klass = estimator.__class__
@@ -242,7 +242,7 @@ def test_every_optimizer_is_valid(self):
             for name in new_object_params:
                 param1 = new_object_params[name]
                 param2 = params_set[name]
-                self.assertEqual(param1, param2)
+                assert param1 == param2
 
     def test_get_set_config_space(self):
         """Make sure that we can setup a valid choice in the optimizer
@@ -251,10 +251,8 @@ def test_get_set_config_space(self):
         cs = optimizer_choice.get_hyperparameter_search_space()
 
         # Make sure that all hyperparameters are part of the serach space
-        self.assertListEqual(
-            sorted(cs.get_hyperparameter('__choice__').choices),
-            sorted(list(optimizer_choice.get_components().keys()))
-        )
+        assert sorted(cs.get_hyperparameter('__choice__').choices) == \
+            sorted(list(optimizer_choice.get_components().keys()))
 
         # Make sure we can properly set some random configs
         # Whereas just one iteration will make sure the algorithm works,
@@ -265,8 +263,7 @@ def test_get_set_config_space(self):
             config_dict = copy.deepcopy(config.get_dictionary())
             optimizer_choice.set_hyperparameters(config)
 
-            self.assertEqual(optimizer_choice.choice.__class__,
-                             optimizer_choice.get_components()[config_dict['__choice__']])
+            assert optimizer_choice.choice.__class__ == optimizer_choice.get_components()[config_dict['__choice__']]
 
             # Then check the choice configuration
             selected_choice = config_dict.pop('__choice__', None)
@@ -274,61 +271,79 @@ def test_get_set_config_space(self):
                 # Remove the selected_choice string from the parameter
                 # so we can query in the object for it
                 key = key.replace(selected_choice + ':', '')
-                self.assertIn(key, vars(optimizer_choice.choice))
-                self.assertEqual(value, optimizer_choice.choice.__dict__[key])
+                assert key in vars(optimizer_choice.choice)
+                assert value == optimizer_choice.choice.__dict__[key]
 
     def test_optimizer_add(self):
         """Makes sure that a component can be added to the CS"""
         # No third party components to start with
-        self.assertEqual(len(optimizer_components._addons.components), 0)
+        assert len(optimizer_components._addons.components) == 0
 
         # Then make sure the optimizer can be added and query'ed
         optimizer_components.add_optimizer(DummyOptimizer)
-        self.assertEqual(len(optimizer_components._addons.components), 1)
+        assert len(optimizer_components._addons.components) == 1
         cs = OptimizerChoice(dataset_properties={}).get_hyperparameter_search_space()
-        self.assertIn('DummyOptimizer', str(cs))
+        assert 'DummyOptimizer' in str(cs)


-class NetworkBackboneTest(unittest.TestCase):
+class TestNetworkBackbone:
     def test_all_backbones_available(self):
         backbone_choice = NetworkBackboneChoice(dataset_properties={})
 
-        self.assertEqual(len(backbone_choice.get_components().keys()), 8)
+        assert len(backbone_choice.get_components().keys()) == 8
 
-    def test_dummy_forward_backward_pass(self):
+    @pytest.mark.parametrize('task_type_input_shape', [(constants.IMAGE_CLASSIFICATION, (3, 64, 64)),
+                                                       (constants.IMAGE_REGRESSION, (3, 64, 64)),
+                                                       (constants.TIMESERIES_CLASSIFICATION, (32, 6)),
+                                                       (constants.TIMESERIES_REGRESSION, (32, 6)),
+                                                       (constants.TABULAR_CLASSIFICATION, (100,)),
+                                                       (constants.TABULAR_REGRESSION, (100,))])
+    def test_dummy_forward_backward_pass(self, task_type_input_shape):
         network_backbone_choice = NetworkBackboneChoice(dataset_properties={})
-        task_types = {constants.IMAGE_CLASSIFICATION: (3, 64, 64),
-                      constants.IMAGE_REGRESSION: (3, 64, 64),
-                      constants.TIMESERIES_CLASSIFICATION: (32, 6),
-                      constants.TIMESERIES_REGRESSION: (32, 6),
-                      constants.TABULAR_CLASSIFICATION: (100,),
-                      constants.TABULAR_REGRESSION: (100,)}
 
-        device = torch.device("cpu")
-
-        for task_type, input_shape in task_types.items():
-            dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]}
-
-            cs = network_backbone_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties)
-
-            # test 10 random configurations
-            for i in range(10):
-                config = cs.sample_configuration()
-                network_backbone_choice.set_hyperparameters(config)
-                backbone = network_backbone_choice.choice.build_backbone(input_shape=input_shape)
-                self.assertNotEqual(backbone, None)
-                backbone = backbone.to(device)
-                dummy_input = torch.randn((2, *input_shape), dtype=torch.float)
-                output = backbone(dummy_input)
-                self.assertNotEqual(output.shape[1:], output)
-                loss = output.sum()
-                loss.backward()
+        # shorten search space as it causes out of memory errors in github actions
+        updates = HyperparameterSearchSpaceUpdates()
+        updates.append(node_name='network_backbone',
+                       hyperparameter='ConvNetImageBackbone:num_layers',
+                       value_range=[1, 3],
+                       default_value=2)
+        updates.append(node_name='network_backbone',
+                       hyperparameter='ConvNetImageBackbone:num_init_filters',
+                       value_range=[8, 16],
+                       default_value=8)
+        updates.append(node_name='network_backbone',
+                       hyperparameter='DenseNetImageBackbone:num_layers',
+                       value_range=[4, 8],
+                       default_value=6)
+        updates.append(node_name='network_backbone',
+                       hyperparameter='DenseNetImageBackbone:num_blocks',
+                       value_range=[1, 2],
+                       default_value=1)
+        updates.apply([('network_backbone', network_backbone_choice)])
+
+        task_type, input_shape = task_type_input_shape
+        device = torch.device("cpu")
+        dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]}
+
+        cs = network_backbone_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties)
+
+        # test 10 random configurations
+        for i in range(10):
+            config = cs.sample_configuration()
+            network_backbone_choice.set_hyperparameters(config)
+            backbone = network_backbone_choice.choice.build_backbone(input_shape=input_shape)
+            assert backbone is not None
+            backbone = backbone.to(device)
+            dummy_input = torch.randn((2, *input_shape), dtype=torch.float)
+            output = backbone(dummy_input)
+            assert output.shape[1:] != output
+            loss = output.sum()
+            loss.backward()
 
     def test_every_backbone_is_valid(self):
         backbone_choice = NetworkBackboneChoice(dataset_properties={})
 
-        self.assertEqual(len(backbone_choice.get_components().keys()), 8)
+        assert len(backbone_choice.get_components().keys()) == 8
 
         for name, backbone in backbone_choice.get_components().items():
             config = backbone.get_hyperparameter_search_space().sample_configuration()
@@ -338,7 +353,7 @@ def test_every_backbone_is_valid(self):
 
             # Make sure all keys are copied properly
             for k, v in estimator.get_params().items():
-                self.assertIn(k, estimator_clone_params)
+                assert k in estimator_clone_params
 
             # Make sure the params getter of estimator are honored
             klass = estimator.__class__
@@ -351,7 +366,7 @@ def test_every_backbone_is_valid(self):
             for name in new_object_params:
                 param1 = new_object_params[name]
                 param2 = params_set[name]
-                self.assertEqual(param1, param2)
+                assert param1 == param2
 
     def test_get_set_config_space(self):
         """
@@ -371,12 +386,12 @@ def test_get_set_config_space(self):
             config_dict = copy.deepcopy(config.get_dictionary())
             network_backbone_choice.set_hyperparameters(config)
 
-            self.assertEqual(network_backbone_choice.choice.__class__,
-                             network_backbone_choice.get_components()[config_dict['__choice__']])
+            assert network_backbone_choice.choice.__class__ == \
+                network_backbone_choice.get_components()[config_dict['__choice__']]
 
             # Then check the choice configuration
             selected_choice = config_dict.pop('__choice__', None)
-            self.assertNotEqual(selected_choice, None)
+            assert selected_choice is not None
             for key, value in config_dict.items():
                 # Remove the selected_choice string from the parameter
                 # so we can query in the object for it
@@ -384,63 +399,62 @@ def test_get_set_config_space(self):
                 # parameters are dynamic, so they exist in config
                 parameters = vars(network_backbone_choice.choice)
                 parameters.update(vars(network_backbone_choice.choice)['config'])
-                self.assertIn(key, parameters)
-                self.assertEqual(value, parameters[key])
+                assert key in parameters
+                assert value == parameters[key]
 
     def test_add_network_backbone(self):
         """Makes sure that a component can be added to the CS"""
         # No third party components to start with
-        self.assertEqual(len(base_network_backbone_choice._addons.components), 0)
+        assert len(base_network_backbone_choice._addons.components) == 0
 
         # Then make sure the backbone can be added
         base_network_backbone_choice.add_backbone(DummyBackbone)
-        self.assertEqual(len(base_network_backbone_choice._addons.components), 1)
+        assert len(base_network_backbone_choice._addons.components) == 1
 
         cs = NetworkBackboneChoice(dataset_properties={}). \
             get_hyperparameter_search_space(dataset_properties={"task_type": "tabular_classification"})
-        self.assertIn("DummyBackbone", str(cs))
+        assert "DummyBackbone" in str(cs)
 
         # clear addons
         base_network_backbone_choice._addons = ThirdPartyComponents(NetworkBackboneComponent)
 
 
-class NetworkHeadTest(unittest.TestCase):
+class TestNetworkHead:
     def test_all_heads_available(self):
         network_head_choice = NetworkHeadChoice(dataset_properties={})
 
-        self.assertEqual(len(network_head_choice.get_components().keys()), 2)
+        assert len(network_head_choice.get_components().keys()) == 2
 
-    def test_dummy_forward_backward_pass(self):
+    @pytest.mark.parametrize('task_type_input_output_shape', [(constants.IMAGE_CLASSIFICATION, (3, 64, 64), (5,)),
+                                                              (constants.IMAGE_REGRESSION, (3, 64, 64), (1,)),
+                                                              (constants.TIMESERIES_CLASSIFICATION, (32, 6), (5,)),
+                                                              (constants.TIMESERIES_REGRESSION, (32, 6), (1,)),
+                                                              (constants.TABULAR_CLASSIFICATION, (100,), (5,)),
+                                                              (constants.TABULAR_REGRESSION, (100,), (1,))])
+    def test_dummy_forward_backward_pass(self, task_type_input_output_shape):
         network_head_choice = NetworkHeadChoice(dataset_properties={})
 
-        task_types = {constants.IMAGE_CLASSIFICATION: ((3, 64, 64), (5,)),
-                      constants.IMAGE_REGRESSION: ((3, 64, 64), (1,)),
-                      constants.TIMESERIES_CLASSIFICATION: ((32, 6), (5,)),
-                      constants.TIMESERIES_REGRESSION: ((32, 6), (1,)),
-                      constants.TABULAR_CLASSIFICATION: ((100,), (5,)),
-                      constants.TABULAR_REGRESSION: ((100,), (1,))}
-
+        task_type, input_shape, output_shape = task_type_input_output_shape
         device = torch.device("cpu")
-        for task_type, (input_shape, output_shape) in task_types.items():
-            dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]}
-            if task_type in constants.CLASSIFICATION_TASKS:
-                dataset_properties["num_classes"] = output_shape[0]
+        dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]}
+        if task_type in constants.CLASSIFICATION_TASKS:
+            dataset_properties["num_classes"] = output_shape[0]
 
-            cs = network_head_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties)
-            # test 10 random configurations
-            for i in range(10):
-                config = cs.sample_configuration()
-                network_head_choice.set_hyperparameters(config)
-                head = network_head_choice.choice.build_head(input_shape=input_shape,
-                                                             output_shape=output_shape)
-                self.assertNotEqual(head, None)
-                head = head.to(device)
-                dummy_input = torch.randn((2, *input_shape), dtype=torch.float)
-                output = head(dummy_input)
-                self.assertEqual(output.shape[1:], output_shape)
-                loss = output.sum()
-                loss.backward()
+        cs = network_head_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties)
+        # test 10 random configurations
+        for i in range(10):
+            config = cs.sample_configuration()
+            network_head_choice.set_hyperparameters(config)
+            head = network_head_choice.choice.build_head(input_shape=input_shape,
+                                                         output_shape=output_shape)
+            assert head is not None
+            head = head.to(device)
+            dummy_input = torch.randn((2, *input_shape), dtype=torch.float)
+            output = head(dummy_input)
+            assert output.shape[1:] == output_shape
+            loss = output.sum()
+            loss.backward()
 
     def test_every_head_is_valid(self):
         """
@@ -464,7 +478,7 @@ def test_every_head_is_valid(self):
 
             # Make sure all keys are copied properly
             for k, v in estimator.get_params().items():
-                self.assertIn(k, estimator_clone_params)
+                assert k in estimator_clone_params
 
             # Make sure the params getter of estimator are honored
             klass = estimator.__class__
@@ -477,7 +491,7 @@ def test_every_head_is_valid(self):
             for name in new_object_params:
                 param1 = new_object_params[name]
                 param2 = params_set[name]
-                self.assertEqual(param1, param2)
+                assert param1 == param2
 
     def test_get_set_config_space(self):
         """
@@ -497,12 +511,12 @@ def test_get_set_config_space(self):
             config_dict = copy.deepcopy(config.get_dictionary())
             network_head_choice.set_hyperparameters(config)
 
-            self.assertEqual(network_head_choice.choice.__class__,
-                             network_head_choice.get_components()[config_dict['__choice__']])
+            assert network_head_choice.choice.__class__ == \
+                network_head_choice.get_components()[config_dict['__choice__']]
 
             # Then check the choice configuration
             selected_choice = config_dict.pop('__choice__', None)
-            self.assertNotEqual(selected_choice, None)
+            assert selected_choice is not None
             for key, value in config_dict.items():
                 # Remove the selected_choice string from the parameter
                 # so we can query in the object for it
@@ -510,27 +524,27 @@ def test_get_set_config_space(self):
                 # parameters are dynamic, so they exist in config
                 parameters = vars(network_head_choice.choice)
                 parameters.update(vars(network_head_choice.choice)['config'])
-                self.assertIn(key, parameters)
-                self.assertEqual(value, parameters[key])
+                assert key in parameters
+                assert value == parameters[key]
 
     def test_add_network_head(self):
         """Makes sure that a component can be added to the CS"""
         # No third party components to start with
-        self.assertEqual(len(base_network_head_choice._addons.components), 0)
+        assert len(base_network_head_choice._addons.components) == 0
 
         # Then make sure the head can be added
         base_network_head_choice.add_head(DummyHead)
-        self.assertEqual(len(base_network_head_choice._addons.components), 1)
+        assert len(base_network_head_choice._addons.components) == 1
 
         cs = NetworkHeadChoice(dataset_properties={}). \
            get_hyperparameter_search_space(dataset_properties={"task_type": "tabular_classification"})
-        self.assertIn("DummyHead", str(cs))
+        assert "DummyHead" in str(cs)
 
         # clear addons
        base_network_head_choice._addons = ThirdPartyComponents(NetworkHeadComponent)


-class NetworkInitializerTest(unittest.TestCase):
+class TestNetworkInitializer:
     def test_every_network_initializer_is_valid(self):
         """
         Makes sure that every network_initializer is a valid estimator.
@@ -542,7 +556,7 @@ def test_every_network_initializer_is_valid(self):
         network_initializer_choice = NetworkInitializerChoice(dataset_properties={})
 
         # Make sure all components are returned
-        self.assertEqual(len(network_initializer_choice.get_components().keys()), 5)
+        assert len(network_initializer_choice.get_components().keys()) == 5
 
         # For every optimizer in the components, make sure
         # that it complies with the scikit learn estimator.
@@ -556,7 +570,7 @@ def test_every_network_initializer_is_valid(self):
 
             # Make sure all keys are copied properly
             for k, v in estimator.get_params().items():
-                self.assertIn(k, estimator_clone_params)
+                assert k in estimator_clone_params
 
             # Make sure the params getter of estimator are honored
             klass = estimator.__class__
@@ -569,7 +583,7 @@ def test_every_network_initializer_is_valid(self):
             for name in new_object_params:
                 param1 = new_object_params[name]
                 param2 = params_set[name]
-                self.assertEqual(param1, param2)
+                assert param1 == param2
 
     def test_get_set_config_space(self):
         """Make sure that we can setup a valid choice in the network_initializer
@@ -578,10 +592,8 @@ def test_get_set_config_space(self):
         cs = network_initializer_choice.get_hyperparameter_search_space()
 
         # Make sure that all hyperparameters are part of the serach space
-        self.assertListEqual(
-            sorted(cs.get_hyperparameter('__choice__').choices),
-            sorted(list(network_initializer_choice.get_components().keys()))
-        )
+        assert sorted(cs.get_hyperparameter('__choice__').choices) == \
+            sorted(list(network_initializer_choice.get_components().keys()))
 
         # Make sure we can properly set some random configs
         # Whereas just one iteration will make sure the algorithm works,
@@ -592,8 +604,8 @@ def test_get_set_config_space(self):
             config_dict = copy.deepcopy(config.get_dictionary())
             network_initializer_choice.set_hyperparameters(config)
 
-            self.assertEqual(network_initializer_choice.choice.__class__,
-                             network_initializer_choice.get_components()[config_dict['__choice__']])
+            assert network_initializer_choice.choice.__class__ == \
+                network_initializer_choice.get_components()[config_dict['__choice__']]
 
             # Then check the choice configuration
             selected_choice = config_dict.pop('__choice__', None)
@@ -601,20 +613,16 @@ def test_get_set_config_space(self):
                 # Remove the selected_choice string from the parameter
                 # so we can query in the object for it
                 key = key.replace(selected_choice + ':', '')
-                self.assertIn(key, vars(network_initializer_choice.choice))
-                self.assertEqual(value, network_initializer_choice.choice.__dict__[key])
+                assert key in vars(network_initializer_choice.choice)
+                assert value == network_initializer_choice.choice.__dict__[key]
 
     def test_network_initializer_add(self):
         """Makes sure that a component can be added to the CS"""
         # No third party components to start with
-        self.assertEqual(len(network_initializer_components._addons.components), 0)
+        assert len(network_initializer_components._addons.components) == 0
 
         # Then make sure the network_initializer can be added and query'ed
         network_initializer_components.add_network_initializer(DummyNetworkInitializer)
-        self.assertEqual(len(network_initializer_components._addons.components), 1)
+        assert len(network_initializer_components._addons.components) == 1
         cs = NetworkInitializerChoice(dataset_properties={}).get_hyperparameter_search_space()
-        self.assertIn('DummyNetworkInitializer', str(cs))
-
-
-if __name__ == '__main__':
-    unittest.main()
+        assert 'DummyNetworkInitializer' in str(cs)
diff --git a/test/test_pipeline/components/test_setup_image_augmenter.py b/test/test_pipeline/components/setup/test_setup_image_augmenter.py
similarity index 100%
rename from test/test_pipeline/components/test_setup_image_augmenter.py
rename to test/test_pipeline/components/setup/test_setup_image_augmenter.py
diff --git a/test/test_pipeline/components/test_setup_networks.py b/test/test_pipeline/components/setup/test_setup_networks.py
similarity index 100%
rename from test/test_pipeline/components/test_setup_networks.py
rename to test/test_pipeline/components/setup/test_setup_networks.py
diff --git a/test/test_pipeline/components/test_setup_preprocessing_node.py b/test/test_pipeline/components/setup/test_setup_preprocessing_node.py
similarity index 100%
rename from test/test_pipeline/components/test_setup_preprocessing_node.py
rename to test/test_pipeline/components/setup/test_setup_preprocessing_node.py
diff --git a/test/test_pipeline/components/test_setup_traditional_classification.py b/test/test_pipeline/components/setup/test_setup_traditional_classification.py
similarity index 100%
rename from test/test_pipeline/components/test_setup_traditional_classification.py
rename to test/test_pipeline/components/setup/test_setup_traditional_classification.py
diff --git a/test/test_pipeline/components/training/__init__.py b/test/test_pipeline/components/training/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/test_pipeline/components/test_feature_data_loader.py b/test/test_pipeline/components/training/test_feature_data_loader.py
similarity index 100%
rename from test/test_pipeline/components/test_feature_data_loader.py
rename to test/test_pipeline/components/training/test_feature_data_loader.py
diff --git a/test/test_pipeline/components/test_image_data_loader.py b/test/test_pipeline/components/training/test_image_data_loader.py
similarity index 100%
rename from test/test_pipeline/components/test_image_data_loader.py
rename to test/test_pipeline/components/training/test_image_data_loader.py
diff --git a/test/test_pipeline/components/test_training.py b/test/test_pipeline/components/training/test_training.py
similarity index 98%
rename from test/test_pipeline/components/test_training.py
rename to test/test_pipeline/components/training/test_training.py
index 081c6bcaa..e17dfce3d 100644
--- a/test/test_pipeline/components/test_training.py
+++ b/test/test_pipeline/components/training/test_training.py
@@ -27,7 +27,7 @@
 )
 
 sys.path.append(os.path.dirname(__file__))
-from base import BaseTraining  # noqa (E402: module level import not at top of file)
+from test.test_pipeline.components.base import BaseTraining  # noqa (E402: module level import not at top of file)
 
 
 class BaseDataLoaderTest(unittest.TestCase):
@@ -128,6 +128,7 @@ def test_evaluate(self):
         Makes sure we properly evaluate data, returning a proper loss
         and metric
         """
+
         (trainer,
          model,
         optimizer,
@@ -156,7 +157,6 @@ def test_evaluate(self):
 
 
 class StandardTrainerTest(BaseTraining, unittest.TestCase):
-
     def test_regression_epoch_training(self):
         (trainer,
          _,
diff --git a/test/test_pipeline/test_losses.py b/test/test_pipeline/test_losses.py
index 6cc669161..9f23f3e9f 100644
--- a/test/test_pipeline/test_losses.py
+++ b/test/test_pipeline/test_losses.py
@@ -2,9 +2,9 @@
 import torch
 from torch import nn
+from torch.nn.modules.loss import _Loss as Loss
 
-from autoPyTorch.constants import STRING_TO_OUTPUT_TYPES
-from autoPyTorch.pipeline.components.training.losses import get_loss_instance
+from autoPyTorch.pipeline.components.training.losses import get_loss, losses
 from autoPyTorch.utils.implementations import get_loss_weight_strategy
 
 
@@ -14,7 +14,7 @@
                                          'continuous'])
 def test_get_no_name(output_type):
     dataset_properties = {'task_type': 'tabular_classification', 'output_type': output_type}
-    loss = get_loss_instance(dataset_properties)
+    loss = get_loss(dataset_properties)
     assert isinstance(loss(), nn.Module)
 
 
@@ -23,7 +23,7 @@ def test_get_name(output_type_name):
     output_type, name = output_type_name
     dataset_properties = {'task_type': 'tabular_classification', 'output_type': output_type}
-    loss = get_loss_instance(dataset_properties, name)()
+    loss = get_loss(dataset_properties, name)()
     assert isinstance(loss, nn.Module)
     assert str(loss) == f"{name}()"
 
@@ -32,29 +32,37 @@ def test_get_name_error():
     dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'multiclass'}
     name = 'BCELoss'
     with pytest.raises(ValueError, match=r"Invalid name entered for task [a-z]+_[a-z]+, "):
-        get_loss_instance(dataset_properties, name)
+        get_loss(dataset_properties, name)
 
 
 @pytest.mark.parametrize('weighted', [True, False])
-def test_losses(weighted):
-    list_properties = [{'task_type': 'tabular_classification', 'output_type': 'multiclass'},
-                       {'task_type': 'tabular_classification', 'output_type': 'binary'},
-                       {'task_type': 'tabular_regression', 'output_type': 'continuous'}]
-    pred_cross_entropy = torch.randn(4, 4, requires_grad=True)
-    list_predictions = [pred_cross_entropy, torch.empty(4).random_(2), torch.randn(4)]
-    list_names = [None, 'BCEWithLogitsLoss', None]
-    list_targets = [torch.empty(4, dtype=torch.long).random_(4), torch.empty(4).random_(2), torch.randn(4)]
-    labels = [torch.empty(100, dtype=torch.long).random_(4), torch.empty(100, dtype=torch.long).random_(2), None]
-    for dataset_properties, pred, target, name, label in zip(list_properties, list_predictions,
-                                                              list_targets, list_names, labels):
-        loss = get_loss_instance(dataset_properties=dataset_properties, name=name)
-        weights = None
-        if bool(weighted) and 'classification' in dataset_properties['task_type']:
-            strategy = get_loss_weight_strategy(output_type=STRING_TO_OUTPUT_TYPES[dataset_properties['output_type']])
-            weights = strategy(y=label)
-            weights = torch.from_numpy(weights)
-            weights = weights.type(torch.FloatTensor)
-            kwargs = {'pos_weight': weights} if 'binary' in dataset_properties['output_type'] else {'weight': weights}
-        loss = loss() if weights is None else loss(**kwargs)
-        score = loss(pred, target)
-        assert isinstance(score, torch.Tensor)
+@pytest.mark.parametrize('loss_details', ['loss_cross_entropy_multiclass',
+                                          'loss_cross_entropy_binary',
+                                          'loss_bce',
+                                          'loss_mse'], indirect=True)
+def test_losses(weighted, loss_details):
+    dataset_properties, predictions, name, targets, labels = loss_details
+    loss = get_loss(dataset_properties=dataset_properties, name=name)
+    weights = None
+    if bool(weighted) and 'classification' in dataset_properties['task_type']:
+        strategy = get_loss_weight_strategy(loss)
+        weights = strategy(y=labels)
+        weights = torch.from_numpy(weights)
+        weights = weights.type(torch.FloatTensor)
+        kwargs = {'pos_weight': weights} if loss.__name__ == 'BCEWithLogitsLoss' else {'weight': weights}
+    loss = loss() if weights is None else loss(**kwargs)
+    score = loss(predictions, targets)
+    assert isinstance(score, torch.Tensor)
+    # Ensure it is a one element tensor
+    assert len(score.size()) == 0
+
+
+def test_loss_dict():
+    assert 'classification' in losses.keys()
+    assert 'regression' in losses.keys()
+    for task in losses.values():
+        for loss in task.values():
+            assert 'module' in loss.keys()
+            assert isinstance(loss['module'](), Loss)
+            assert 'supported_output_types' in loss.keys()
+            assert isinstance(loss['supported_output_types'], list)
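For reference, the weighted classification path exercised by the rewritten test above, written as a standalone snippet (editorial illustration, not part of the patch; inputs are random and only public helpers from this patch are used):

    import torch

    from autoPyTorch.pipeline.components.training.losses import get_loss
    from autoPyTorch.utils.implementations import get_loss_weight_strategy

    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
    loss_cls = get_loss(dataset_properties, name='BCEWithLogitsLoss')

    labels = torch.empty(20, dtype=torch.long).random_(2)
    weights = torch.from_numpy(get_loss_weight_strategy(loss_cls)(y=labels)).type(torch.FloatTensor)
    criterion = loss_cls(pos_weight=weights)

    predictions, targets = torch.empty(4).random_(2), torch.empty(4).random_(2)
    score = criterion(predictions, targets)
    assert score.size() == torch.Size([])  # scalar, one-element tensor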