diff --git a/docs/en/api/core.rst b/docs/en/api/core.rst
index 0f83483ca35..83e1dbf42f0 100644
--- a/docs/en/api/core.rst
+++ b/docs/en/api/core.rst
@@ -48,6 +48,7 @@ Hook
    ClassNumCheckHook
    PreciseBNHook
    CosineAnnealingCooldownLrUpdaterHook
+   MMClsWandbHook
 
 Optimizers
 
diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py
index a2247283b3b..131f9d756fa 100644
--- a/mmcls/apis/train.py
+++ b/mmcls/apis/train.py
@@ -8,9 +8,8 @@
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                          build_optimizer, build_runner, get_dist_info)
-from mmcv.runner.hooks import DistEvalHook, EvalHook
 
-from mmcls.core import DistOptimizerHook
+from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
 from mmcls.datasets import build_dataloader, build_dataset
 from mmcls.utils import get_root_logger
 
diff --git a/mmcls/core/evaluation/__init__.py b/mmcls/core/evaluation/__init__.py
index 1e641a65011..dd4e57ccf0a 100644
--- a/mmcls/core/evaluation/__init__.py
+++ b/mmcls/core/evaluation/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .eval_hooks import DistEvalHook, EvalHook
 from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
                            precision_recall_f1, recall, support)
 from .mean_ap import average_precision, mAP
@@ -6,5 +7,6 @@
 
 __all__ = [
     'precision', 'recall', 'f1_score', 'support', 'average_precision', 'mAP',
-    'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1'
+    'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1',
+    'EvalHook', 'DistEvalHook'
 ]
diff --git a/mmcls/core/evaluation/eval_hooks.py b/mmcls/core/evaluation/eval_hooks.py
new file mode 100644
index 00000000000..412eab4fa9a
--- /dev/null
+++ b/mmcls/core/evaluation/eval_hooks.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import torch.distributed as dist
+from mmcv.runner import DistEvalHook as BaseDistEvalHook
+from mmcv.runner import EvalHook as BaseEvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+class EvalHook(BaseEvalHook):
+    """Non-Distributed evaluation hook.
+
+    Compared with the ``EvalHook`` in MMCV, this hook will save the latest
+    evaluation results as an attribute for other hooks to use (like
+    `MMClsWandbHook`).
+    """
+
+    def __init__(self, dataloader, **kwargs):
+        super(EvalHook, self).__init__(dataloader, **kwargs)
+        self.latest_results = None
+
+    def _do_evaluate(self, runner):
+        """Perform evaluation and save the checkpoint."""
+        results = self.test_fn(runner.model, self.dataloader)
+        self.latest_results = results
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        # `key_score` may be `None` if the eval metric is unavailable, in
+        # which case saving the best checkpoint is skipped.
+        if self.save_best and key_score:
+            self._save_ckpt(runner, key_score)
+
+
+class DistEvalHook(BaseDistEvalHook):
+    """Distributed evaluation hook.
+
+    Compared with the ``DistEvalHook`` in MMCV, this hook will save the
+    latest evaluation results as an attribute for other hooks to use (like
+    `MMClsWandbHook`).
+    """
+
+    def __init__(self, dataloader, **kwargs):
+        super(DistEvalHook, self).__init__(dataloader, **kwargs)
+        self.latest_results = None
+
+    def _do_evaluate(self, runner):
+        """Perform evaluation and save the checkpoint."""
+        # Synchronization of BatchNorm's buffers (running_mean
+        # and running_var) is not supported by PyTorch's DDP,
+        # which may cause models on different ranks to behave
+        # inconsistently, so we broadcast the BatchNorm buffers
+        # of rank 0 to the other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+        results = self.test_fn(
+            runner.model,
+            self.dataloader,
+            tmpdir=tmpdir,
+            gpu_collect=self.gpu_collect)
+        self.latest_results = results
+        if runner.rank == 0:
+            print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            # `key_score` may be `None` if the eval metric is unavailable,
+            # in which case saving the best checkpoint is skipped.
+            if self.save_best and key_score:
+                self._save_ckpt(runner, key_score)
diff --git a/mmcls/core/hook/__init__.py b/mmcls/core/hook/__init__.py
index 2c2dbdfc0e8..4212dcf9ccb 100644
--- a/mmcls/core/hook/__init__.py
+++ b/mmcls/core/hook/__init__.py
@@ -2,8 +2,9 @@
 from .class_num_check_hook import ClassNumCheckHook
 from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
 from .precise_bn_hook import PreciseBNHook
+from .wandblogger_hook import MMClsWandbHook
 
 __all__ = [
     'ClassNumCheckHook', 'PreciseBNHook',
-    'CosineAnnealingCooldownLrUpdaterHook'
+    'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
 ]
diff --git a/mmcls/core/hook/wandblogger_hook.py b/mmcls/core/hook/wandblogger_hook.py
new file mode 100644
index 00000000000..61ccfe90d6e
--- /dev/null
+++ b/mmcls/core/hook/wandblogger_hook.py
@@ -0,0 +1,340 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import numpy as np
+from mmcv.runner import HOOKS, BaseRunner
+from mmcv.runner.dist_utils import master_only
+from mmcv.runner.hooks.checkpoint import CheckpointHook
+from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
+from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
+
+
+@HOOKS.register_module()
+class MMClsWandbHook(WandbLoggerHook):
+    """Enhanced Wandb logger hook for classification.
+
+    Compared with the :class:`mmcv.runner.WandbLoggerHook`, this hook can not
+    only automatically log all information in ``log_buffer`` but also log
+    the following extra information.
+
+    - **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
+      every checkpoint interval will be saved as W&B Artifacts. This depends on
+      the :class:`mmcv.runner.CheckpointHook` whose priority is higher than
+      this hook. Please refer to
+      https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
+      about model versioning with W&B Artifacts.
+
+    - **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
+      checkpoint artifact will have metadata associated with it. The metadata
+      contains the evaluation metrics computed on validation data with that
+      checkpoint along with the current epoch/iter. It depends on
+      :class:`EvalHook` whose priority is higher than this hook.
+
+    - **Evaluation**: At every interval, this hook logs the model prediction as
+      interactive W&B Tables. The number of samples logged is given by
+      ``num_eval_images``. Currently, this hook logs the predicted labels along
+      with the ground truth at every evaluation interval. This depends on the
+      :class:`EvalHook` whose priority is higher than this hook. Also note that
+      the data is just logged once and subsequent evaluation tables use a
+      reference to the logged data to save memory usage. Please refer to
+      https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
+
+    Here is a config example:
+
+    .. code:: python
+
+        checkpoint_config = dict(interval=10)
+
+        # To log checkpoint metadata, the interval of checkpoint saving should
+        # be divisible by the interval of evaluation.
+        evaluation = dict(interval=5)
+
+        log_config = dict(
+            ...
+            hooks=[
+                ...
+                dict(type='MMClsWandbHook',
+                     init_kwargs={
+                         'entity': "YOUR_ENTITY",
+                         'project': "YOUR_PROJECT_NAME"
+                     },
+                     log_checkpoint=True,
+                     log_checkpoint_metadata=True,
+                     num_eval_images=100)
+            ])
+
+    Args:
+        init_kwargs (dict): A dict passed to wandb.init to initialize
+            a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
+            for possible key-value pairs.
+        interval (int): Logging interval (every k iterations). Defaults to 10.
+        log_checkpoint (bool): Save the checkpoint at every checkpoint interval
+            as W&B Artifacts. Use this for model versioning where each version
+            is a checkpoint. Defaults to False.
+        log_checkpoint_metadata (bool): Log the evaluation metrics computed
+            on the validation data with the checkpoint, along with the current
+            epoch as metadata for that checkpoint.
+            Defaults to False.
+        num_eval_images (int): The number of validation images to be logged.
+            If zero, the evaluation won't be logged. Defaults to 100.
+    """
+
+    def __init__(self,
+                 init_kwargs=None,
+                 interval=10,
+                 log_checkpoint=False,
+                 log_checkpoint_metadata=False,
+                 num_eval_images=100,
+                 **kwargs):
+        super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)
+
+        self.log_checkpoint = log_checkpoint
+        self.log_checkpoint_metadata = (
+            log_checkpoint and log_checkpoint_metadata)
+        self.num_eval_images = num_eval_images
+        self.log_evaluation = (num_eval_images > 0)
+        self.ckpt_hook: CheckpointHook = None
+        self.eval_hook: EvalHook = None
+
+    @master_only
+    def before_run(self, runner: BaseRunner):
+        super(MMClsWandbHook, self).before_run(runner)
+
+        # Inspect CheckpointHook and EvalHook
+        for hook in runner.hooks:
+            if isinstance(hook, CheckpointHook):
+                self.ckpt_hook = hook
+            if isinstance(hook, (EvalHook, DistEvalHook)):
+                self.eval_hook = hook
+
+        # Check conditions to log checkpoint
+        if self.log_checkpoint:
+            if self.ckpt_hook is None:
+                self.log_checkpoint = False
+                self.log_checkpoint_metadata = False
+                runner.logger.warning(
+                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` '
+                    'is required, please check hooks in the runner.')
+            else:
+                self.ckpt_interval = self.ckpt_hook.interval
+
+        # Check conditions to log evaluation
+        if self.log_evaluation or self.log_checkpoint_metadata:
+            if self.eval_hook is None:
+                self.log_evaluation = False
+                self.log_checkpoint_metadata = False
+                runner.logger.warning(
+                    'To log evaluation or checkpoint metadata in '
+                    'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
+                    'is required, please check whether the validation '
+                    'is enabled.')
+            else:
+                self.eval_interval = self.eval_hook.interval
+                self.val_dataset = self.eval_hook.dataloader.dataset
+                if (self.log_evaluation
+                        and self.num_eval_images > len(self.val_dataset)):
+                    runner.logger.warning(
+                        f'The num_eval_images ({self.num_eval_images}) is '
+                        'greater than the total number of validation samples '
+                        f'({len(self.val_dataset)}). The complete validation '
+                        'dataset will be logged.')
+                    self.num_eval_images = len(self.val_dataset)
+
+        # Check conditions to log checkpoint metadata
+        if self.log_checkpoint_metadata:
+            assert self.ckpt_interval % self.eval_interval == 0, \
+                'To log checkpoint metadata in MMClsWandbHook, the interval ' \
+                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
+                'divisible by the interval of evaluation ' \
+                f'({self.eval_interval}).'
+
+        # Initialize evaluation table
+        if self.log_evaluation:
+            # Initialize data table
+            self._init_data_table()
+            # Add ground truth to the data table
+            self._add_ground_truth()
+            # Log ground truth data
+            self._log_data_table()
+
+    @master_only
+    def after_train_epoch(self, runner):
+        super(MMClsWandbHook, self).after_train_epoch(runner)
+
+        if not self.by_epoch:
+            return
+
+        # Save checkpoint and metadata
+        if (self.log_checkpoint
+                and self.every_n_epochs(runner, self.ckpt_interval)
+                or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
+            if self.log_checkpoint_metadata and self.eval_hook:
+                metadata = {
+                    'epoch': runner.epoch + 1,
+                    **self._get_eval_results()
+                }
+            else:
+                metadata = None
+            aliases = [f'epoch_{runner.epoch+1}', 'latest']
+            model_path = osp.join(self.ckpt_hook.out_dir,
+                                  f'epoch_{runner.epoch+1}.pth')
+            self._log_ckpt_as_artifact(model_path, aliases, metadata)
+
+        # Save prediction table
+        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
+            results = self.eval_hook.latest_results
+            # Initialize evaluation table
+            self._init_pred_table()
+            # Add predictions to evaluation table
+            self._add_predictions(results, runner.epoch + 1)
+            # Log the evaluation table
+            self._log_eval_table(runner.epoch + 1)
+
+    @master_only
+    def after_train_iter(self, runner):
+        if self.get_mode(runner) == 'train':
+            # An ugly patch. The iter-based eval hook will call the
+            # `after_train_iter` method of all logger hooks before evaluation.
+            # Return early here to skip that call; calling the super method
+            # first would clear the log_buffer.
+            return super(MMClsWandbHook, self).after_train_iter(runner)
+        else:
+            super(MMClsWandbHook, self).after_train_iter(runner)
+
+        if self.by_epoch:
+            return
+
+        # Save checkpoint and metadata
+        if (self.log_checkpoint
+                and self.every_n_iters(runner, self.ckpt_interval)
+                or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
+            if self.log_checkpoint_metadata and self.eval_hook:
+                metadata = {
+                    'iter': runner.iter + 1,
+                    **self._get_eval_results()
+                }
+            else:
+                metadata = None
+            aliases = [f'iter_{runner.iter+1}', 'latest']
+            model_path = osp.join(self.ckpt_hook.out_dir,
+                                  f'iter_{runner.iter+1}.pth')
+            self._log_ckpt_as_artifact(model_path, aliases, metadata)
+
+        # Save prediction table
+        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
+            results = self.eval_hook.latest_results
+            # Initialize evaluation table
+            self._init_pred_table()
+            # Log predictions
+            self._add_predictions(results, runner.iter + 1)
+            # Log the table
+            self._log_eval_table(runner.iter + 1)
+
+    @master_only
+    def after_run(self, runner):
+        self.wandb.finish()
+
+    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
+        """Log model checkpoint as W&B Artifact.
+
+        Args:
+            model_path (str): Path of the checkpoint to log.
+            aliases (list): List of the aliases associated with this artifact.
+            metadata (dict, optional): Metadata associated with this artifact.
+ """ + model_artifact = self.wandb.Artifact( + f'run_{self.wandb.run.id}_model', type='model', metadata=metadata) + model_artifact.add_file(model_path) + self.wandb.log_artifact(model_artifact, aliases=aliases) + + def _get_eval_results(self): + """Get model evaluation results.""" + results = self.eval_hook.latest_results + eval_results = self.val_dataset.evaluate( + results, logger='silent', **self.eval_hook.eval_kwargs) + return eval_results + + def _init_data_table(self): + """Initialize the W&B Tables for validation data.""" + columns = ['image_name', 'image', 'ground_truth'] + self.data_table = self.wandb.Table(columns=columns) + + def _init_pred_table(self): + """Initialize the W&B Tables for model evaluation.""" + columns = ['epoch'] if self.by_epoch else ['iter'] + columns += ['image_name', 'image', 'ground_truth', 'prediction' + ] + list(self.val_dataset.CLASSES) + self.eval_table = self.wandb.Table(columns=columns) + + def _add_ground_truth(self): + # Get image loading pipeline + from mmcls.datasets.pipelines import LoadImageFromFile + img_loader = None + for t in self.val_dataset.pipeline.transforms: + if isinstance(t, LoadImageFromFile): + img_loader = t + + CLASSES = self.val_dataset.CLASSES + self.eval_image_indexs = np.arange(len(self.val_dataset)) + # Set seed so that same validation set is logged each time. + np.random.seed(42) + np.random.shuffle(self.eval_image_indexs) + self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images] + + for idx in self.eval_image_indexs: + img_info = self.val_dataset.data_infos[idx] + if img_loader is not None: + img_info = img_loader(img_info) + # Get image and convert from BGR to RGB + image = img_info['img'][..., ::-1] + else: + # For CIFAR dataset. + image = img_info['img'] + image_name = img_info.get('filename', f'img_{idx}') + gt_label = img_info.get('gt_label').item() + + self.data_table.add_data(image_name, self.wandb.Image(image), + CLASSES[gt_label]) + + def _add_predictions(self, results, idx): + table_idxs = self.data_table_ref.get_index() + assert len(table_idxs) == len(self.eval_image_indexs) + + for ndx, eval_image_index in enumerate(self.eval_image_indexs): + result = results[eval_image_index] + + self.eval_table.add_data( + idx, self.data_table_ref.data[ndx][0], + self.data_table_ref.data[ndx][1], + self.data_table_ref.data[ndx][2], + self.val_dataset.CLASSES[np.argmax(result)], *tuple(result)) + + def _log_data_table(self): + """Log the W&B Tables for validation data as artifact and calls + `use_artifact` on it so that the evaluation table can use the reference + of already uploaded images. + + This allows the data to be uploaded just once. + """ + data_artifact = self.wandb.Artifact('val', type='dataset') + data_artifact.add(self.data_table, 'val_data') + + self.wandb.run.use_artifact(data_artifact) + data_artifact.wait() + + self.data_table_ref = data_artifact.get('val_data') + + def _log_eval_table(self, idx): + """Log the W&B Tables for model evaluation. + + The table will be logged multiple times creating new version. Use this + to compare models at different intervals interactively. + """ + pred_artifact = self.wandb.Artifact( + f'run_{self.wandb.run.id}_pred', type='evaluation') + pred_artifact.add(self.eval_table, 'eval_data') + if self.by_epoch: + aliases = ['latest', f'epoch_{idx}'] + else: + aliases = ['latest', f'iter_{idx}'] + self.wandb.run.log_artifact(pred_artifact, aliases=aliases)