diff --git a/CHANGELOG.md b/CHANGELOG.md index 35f86f040..3cecfce8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ via the `--inference_on_val_set` flag. gets uploaded to AzureML, by skipping all test folders. - ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameter `extra_downloaded_run_id` has been renamed to `pretraining_run_checkpoints`. +- ([#526](https://github.com/microsoft/InnerEye-DeepLearning/pull/526)) Updated Covid config to use a multiclass + formulation. Moved functions `create_metric_computers` and `compute_and_log_metrics` from `ScalarLightning` to + `ScalarModelBase`. ### Fixed - ([#537](https://github.com/microsoft/InnerEye-DeepLearning/pull/537)) Print warning if inference is disabled but comparison requested. @@ -54,6 +57,9 @@ in inference-only runs when using lightning containers. - ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline. - ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and `weights_url` can no longer be used to initialize a training run, only inference runs. +- ([#526](https://github.com/microsoft/InnerEye-DeepLearning/pull/526)) Removed `get_posthoc_label_transform` in + class `ScalarModelBase`. Instead, functions `get_loss_function` and `compute_and_log_metrics` in + `ScalarModelBase` can be implemented to compute the loss and metrics in a task-specific manner. ### Deprecated diff --git a/InnerEye/ML/configs/classification/CovidHierarchicalModel.py b/InnerEye/ML/configs/classification/CovidModel.py similarity index 73% rename from InnerEye/ML/configs/classification/CovidHierarchicalModel.py rename to InnerEye/ML/configs/classification/CovidModel.py index f98a27a84..96df7196a 100644 --- a/InnerEye/ML/configs/classification/CovidHierarchicalModel.py +++ b/InnerEye/ML/configs/classification/CovidModel.py @@ -1,21 +1,23 @@ -import codecs import logging -import pickle import random import math from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, List import PIL +import numpy as np import pandas as pd import param import torch + from PIL import Image +from torch.nn import ModuleList, ModuleDict from pytorch_lightning import LightningModule from torchvision.transforms import Compose from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path +from InnerEye.Common.metrics_constants import LoggingColumns from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName @@ -28,21 +30,19 @@ from InnerEye.ML.deep_learning_config import LRSchedulerType, MultiprocessingStartMethod, \ OptimizerType +from InnerEye.ML.lightning_metrics import Accuracy05 +from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode from InnerEye.ML.model_testing import MODEL_OUTPUT_CSV -from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType -from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty - +from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase from InnerEye.ML.utils.run_recovery import RunRecovery from InnerEye.ML.utils.split_dataset import DatasetSplits - -from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID -from InnerEye.Common import fixed_paths 
as fixed_paths_innereye +from InnerEye.ML.metrics_dict import MetricsDict, DataframeLogger -class CovidHierarchicalModel(ScalarModelBase): +class CovidModel(ScalarModelBase): """ Model to train a CovidDataset model from scratch or finetune from SSL-pretrained model. @@ -50,7 +50,7 @@ class CovidHierarchicalModel(ScalarModelBase): --pretraining_run_recovery_id=id_of_your_ssl_model, this will download the checkpoints of the run to your machine and load the corresponding pretrained model. - To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use hte + To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use the --name_of_checkpoint argument. """ use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with SSL." @@ -64,8 +64,7 @@ class CovidHierarchicalModel(ScalarModelBase): "is assumed to contain unique ids.") def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any): - super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'], - loss_type=ScalarLoss.CustomClassification, + super().__init__(loss_type=ScalarLoss.CustomClassification, class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'], max_num_gpus=1, azure_dataset_id=covid_dataset_id, @@ -84,7 +83,7 @@ def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any): l_rate_step_gamma=1.0, l_rate_multi_step_milestones=None, should_validate=False) # validate only after adding kwargs - self.num_classes = 3 + self.num_classes = 4 self.add_and_validate(kwargs) def validate(self) -> None: @@ -192,39 +191,53 @@ def _get_ssl_checkpoint_path(self) -> Path: def pre_process_dataset_dataframe(self) -> None: pass - @staticmethod - def get_posthoc_label_transform() -> Callable: - import torch - - def multiclass_to_hierarchical_labels(classes: torch.Tensor) -> torch.Tensor: - classes = classes.clone() - cvx03vs12 = classes[..., 1] + classes[..., 2] - cvx0vs3 = classes[..., 3] - cvx1vs2 = classes[..., 2] - cvx0vs3[cvx03vs12 == 1] = float('nan') # CVX0vs3 only gets gradient for CVX03 - cvx1vs2[cvx03vs12 == 0] = float('nan') # CVX1vs2 only gets gradient for CVX12 - return torch.stack([cvx03vs12, cvx0vs3, cvx1vs2], -1) - - return multiclass_to_hierarchical_labels - @staticmethod def get_loss_function() -> Callable: import torch import torch.nn.functional as F - def nan_bce_with_logits(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - """Compute BCE with logits, ignoring NaN values""" - valid = labels.isfinite() - losses = F.binary_cross_entropy_with_logits(output[valid], labels[valid], reduction='none') - return losses.sum() / labels.shape[0] - - return nan_bce_with_logits + def custom_loss(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + labels = torch.argmax(labels, dim=-1) + return F.cross_entropy(input=output, target=labels, reduction="sum") + + return custom_loss + + def get_post_loss_logits_normalization_function(self) -> Callable: + return torch.nn.Softmax() + + def create_metric_computers(self) -> ModuleDict: + return ModuleDict({MetricsDict.DEFAULT_HUE_KEY: ModuleList([Accuracy05()])}) + + def compute_and_log_metrics(self, + logits: torch.Tensor, + targets: torch.Tensor, + subject_ids: List[str], + is_training: bool, + metrics: ModuleDict, + logger: DataframeLogger, + current_epoch: int, + data_split: ModelExecutionMode) -> None: + posteriors = self.get_post_loss_logits_normalization_function()(logits) + labels = torch.argmax(targets, dim=-1) + metric = 
metrics[MetricsDict.DEFAULT_HUE_KEY][0] + metric(posteriors, labels) + + per_subject_outputs = zip(subject_ids, posteriors.tolist(), targets.tolist()) + for subject, model_output, target in per_subject_outputs: + for i in range(len(self.target_names)): + logger.add_record({ + LoggingColumns.Epoch.value: current_epoch, + LoggingColumns.Patient.value: subject, + LoggingColumns.Hue.value: self.target_names[i], + LoggingColumns.ModelOutput.value: model_output[i], + LoggingColumns.Label.value: target[i], + LoggingColumns.DataSplit.value: data_split.value + }) def generate_custom_report(self, report_dir: Path, model_proc: ModelProcessing) -> Path: """ - Generate a custom report for the CovidDataset Hierarchical model. At the moment, this report will read the - file model_output.csv generated for the training, validation or test sets and compute a 4 class accuracy - and confusion matrix based on this. + Generate a custom report for the Covid model. This report will read the file model_output.csv generated for + the training, validation or test sets and compute the multiclass accuracy based on this. :param report_dir: Directory report is to be written to :param model_proc: Whether this is a single or ensemble model (model_output.csv will be located in different paths for single vs ensemble runs.) @@ -234,24 +247,37 @@ def get_output_csv_path(mode: ModelExecutionMode) -> Path: p = get_best_epoch_results_path(mode=mode, model_proc=model_proc) return self.outputs_folder / p / MODEL_OUTPUT_CSV + def get_labels_and_predictions(df: pd.DataFrame) -> pd.DataFrame: + labels = [] + predictions = [] + for target in self.target_names: + target_df = df[df[LoggingColumns.Hue.value] == target] + predictions.append(target_df[LoggingColumns.ModelOutput.value]) + labels.append(target_df[LoggingColumns.Label.value]) + + return pd.DataFrame.from_dict({LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]], + LoggingColumns.ModelOutput.value: [np.argmax(predictions)], + LoggingColumns.Label.value: [np.argmax(labels)]}) + + def get_accuracy(df: pd.DataFrame) -> float: + df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_labels_and_predictions).reset_index( + drop=True) + return (df[LoggingColumns.ModelOutput.value] == df[LoggingColumns.Label.value]).mean() # type: ignore + train_metrics = get_output_csv_path(ModelExecutionMode.TRAIN) val_metrics = get_output_csv_path(ModelExecutionMode.VAL) test_metrics = get_output_csv_path(ModelExecutionMode.TEST) - notebook_params = \ - { - 'innereye_path': str(fixed_paths_innereye.repository_root_directory()), - 'train_metrics_csv': str_or_empty(train_metrics), - 'val_metrics_csv': str_or_empty(val_metrics), - 'test_metrics_csv': str_or_empty(test_metrics), - "config": codecs.encode(pickle.dumps(self), "base64").decode(), - "is_crossval_report": False - } - template = Path(__file__).absolute().parent.parent / "reports" / "CovidHierarchicalModelReport.ipynb" - return generate_notebook(template, - notebook_params=notebook_params, - result_notebook=report_dir / get_ipynb_report_name( - f"{self.model_category.value}_hierarchical")) + msg = f"Multiclass Accuracy Train: {get_accuracy(pd.read_csv(train_metrics))}\n" if train_metrics.exists() else "" + msg += f"Multiclass Accuracy Val: {get_accuracy(pd.read_csv(val_metrics))}\n" if val_metrics.exists() else "" + msg += f"Multiclass Accuracy Test: {get_accuracy(pd.read_csv(test_metrics))}\n" if test_metrics.exists() else "" + + report = report_dir / "report.txt" + report.write_text(msg) + + 
logging.info(msg) + + return report class DicomPreparation: diff --git a/InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb b/InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb deleted file mode 100644 index 6e87854d9..000000000 --- a/InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb +++ /dev/null @@ -1,160 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "%%javascript\n", - "IPython.OutputArea.prototype._should_scroll = function(lines) {\n", - " return false;\n", - "}\n", - "// Stops auto-scrolling so entire output is visible: see https://stackoverflow.com/a/41646403" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": { - "tags": [ - "parameters" - ] - }, - "outputs": [], - "source": [ - "# Default parameter values. They will be overwritten by papermill notebook parameters.\n", - "# This cell must carry the tag \"parameters\" in its metadata.\n", - "from pathlib import Path\n", - "import pickle\n", - "import codecs\n", - "\n", - "innereye_path = Path.cwd().parent.parent.parent.parent\n", - "train_metrics_csv = \"\"\n", - "val_metrics_csv = \"\"\n", - "test_metrics_csv = \"\"\n", - "config = \"\"\n", - "is_crossval_report = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "\n", - "if str(innereye_path) not in sys.path:\n", - " sys.path.append(str(innereye_path))\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "\n", - "config = pickle.loads(codecs.decode(config.encode(), \"base64\"))\n", - "\n", - "from InnerEye.ML.common import ModelExecutionMode\n", - "from InnerEye.ML.reports.notebook_report import print_header\n", - "from InnerEye.ML.configs.reports.covid_hierarchical_model_report import print_metrics_from_csv\n", - "\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "plt.rcParams['figure.figsize'] = (20, 10)\n", - "\n", - "#convert params to Path\n", - "train_metrics_csv = Path(train_metrics_csv)\n", - "val_metrics_csv = Path(val_metrics_csv)\n", - "test_metrics_csv = Path(test_metrics_csv)" - ] - }, - { - "cell_type": "markdown", - "id": "4", - "metadata": {}, - "source": [ - "# Metrics\n", - "## Train Set" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "if train_metrics_csv.is_file():\n", - " print_metrics_from_csv(csv_to_set_optimal_threshold=train_metrics_csv,\n", - " csv_to_compute_metrics=train_metrics_csv,\n", - " config=config, is_crossval_report=is_crossval_report)" - ] - }, - { - "cell_type": "markdown", - "id": "6", - "metadata": {}, - "source": [ - "## Validation Set" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [ - "if val_metrics_csv.is_file():\n", - " print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n", - " csv_to_compute_metrics=val_metrics_csv,\n", - " config=config, is_crossval_report=is_crossval_report)" - ] - }, - { - "cell_type": "markdown", - "id": "8", - "metadata": {}, - "source": [ - "## Test Set" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n", - " print_metrics_from_csv(csv_to_set_optimal_threshold=val_metrics_csv,\n", - " 
csv_to_compute_metrics=test_metrics_csv,\n", - " config=config, is_crossval_report=is_crossval_report)" - ] - } - ], - "metadata": { - "celltoolbar": "Tags", - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/InnerEye/ML/configs/reports/covid_hierarchical_model_report.py b/InnerEye/ML/configs/reports/covid_hierarchical_model_report.py deleted file mode 100644 index f46e797e6..000000000 --- a/InnerEye/ML/configs/reports/covid_hierarchical_model_report.py +++ /dev/null @@ -1,104 +0,0 @@ -import pandas as pd -import numpy as np - -from pathlib import Path -from sklearn.metrics import accuracy_score, confusion_matrix -from typing import Dict - -from InnerEye.Common.metrics_constants import LoggingColumns -from InnerEye.ML.reports.classification_report import get_labels_and_predictions_from_dataframe, LabelsAndPredictions -from InnerEye.ML.reports.notebook_report import print_table -from InnerEye.ML.scalar_config import ScalarModelBase - -TARGET_NAMES = ['CVX03vs12', 'CVX0vs3', 'CVX1vs2'] -MULTICLASS_HUE_NAME = "Multiclass" - - -def get_label_from_label_dict(label_dict: Dict[str, float]) -> int: - """ - Converts strings CVX03vs12, CVX1vs2, CVX0vs3 to the corresponding class as int. - """ - if label_dict['CVX03vs12'] == 0: - assert np.isnan(label_dict['CVX1vs2']) - if label_dict['CVX0vs3'] == 0: - label = 0 - elif label_dict['CVX0vs3'] == 1: - label = 3 - else: - raise ValueError("CVX0vs3 should be 0 or 1.") - elif label_dict['CVX03vs12'] == 1: - assert np.isnan(label_dict['CVX0vs3']) - if label_dict['CVX1vs2'] == 0: - label = 1 - elif label_dict['CVX1vs2'] == 1: - label = 2 - else: - raise ValueError("CVX1vs2 should be 0 or 1.") - else: - raise ValueError("CVX03vs12 should be 0 or 1.") - return label - - -def get_model_prediction_by_probabilities(output_dict: Dict[str, float]) -> int: - """ - Based on the values for CVX03vs12, CVX0vs3 and CVX1vs2 predicted by the model, predict the CVX scores as followed: - score(CVX0) = [1 - score(CVX03vs12)][1 - score(CVX0vs3)] - score(CVX1) = score(CVX03vs12)[1 - score(CVX1vs2)] - score(CVX2) = score(CVX03vs12)score(CVX1vs2) - score(CVX3) = [1 - score(CVX03vs12)]score(CVX0vs3) - """ - cvx0 = (1 - output_dict['CVX03vs12']) * (1 - output_dict['CVX0vs3']) - cvx3 = (1 - output_dict['CVX03vs12']) * output_dict['CVX0vs3'] - cvx1 = output_dict['CVX03vs12'] * (1 - output_dict['CVX1vs2']) - cvx2 = output_dict['CVX03vs12'] * output_dict['CVX1vs2'] - return np.argmax([cvx0, cvx1, cvx2, cvx3]) - - -def get_dataframe_with_covid_labels(metrics_df: pd.DataFrame) -> pd.DataFrame: - def get_CVX_labels(df: pd.DataFrame) -> pd.DataFrame: - """ - Given a dataframe (with only one subject) with the model outputs for CVX03vs12, CVX0vs3 and CVX1vs2, - returns a corresponding dataframe with scores for CVX0, CVX1, CVX2 and CVX3 for this subject. See - `get_model_prediction_by_probabilities` for details on mapping the model output to CVX labels. 
- """ - df_by_hue = df[df[LoggingColumns.Hue.value].isin(TARGET_NAMES)].set_index(LoggingColumns.Hue.value) - model_output = get_model_prediction_by_probabilities(df_by_hue[LoggingColumns.ModelOutput.value].to_dict()) - label = get_label_from_label_dict(df_by_hue[LoggingColumns.Label.value].to_dict()) - - return pd.DataFrame.from_dict({LoggingColumns.Patient.value: [df.iloc[0][LoggingColumns.Patient.value]], - LoggingColumns.ModelOutput.value: [model_output], - LoggingColumns.Label.value: [label]}) - - df = metrics_df.copy() - # Group by subject, and for each subject, convert the CVX03vs12, CVX0vs3 and CVX1vs2 predictions to CVX labels. - df = df.groupby(LoggingColumns.Patient.value, as_index=False).apply(get_CVX_labels).reset_index(drop=True) - df[LoggingColumns.Hue.value] = [MULTICLASS_HUE_NAME] * len(df) - return df - - -def get_labels_and_predictions_covid_labels(csv: Path) -> LabelsAndPredictions: - metrics_df = pd.read_csv(csv) - df = get_dataframe_with_covid_labels(metrics_df=metrics_df) - return get_labels_and_predictions_from_dataframe(df) - - -def print_metrics_from_csv(csv_to_set_optimal_threshold: Path, - csv_to_compute_metrics: Path, - config: ScalarModelBase, - is_crossval_report: bool) -> None: - assert config.target_names == TARGET_NAMES - - predictions_to_compute_metrics = get_labels_and_predictions_covid_labels( - csv=csv_to_compute_metrics) - - acc = accuracy_score(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs) - rows = [[f"{acc:.4f}"]] - print_table(rows, header=["Multiclass Accuracy"]) - - conf_matrix = confusion_matrix(predictions_to_compute_metrics.labels, predictions_to_compute_metrics.model_outputs) - rows = [] - header = ["", "CVX0 predicted", "CVX1 predicted", "CVX2 predicted", "CVX3 predicted"] - for i in range(conf_matrix.shape[0]): - line = [f"CVX{i} GT"] + list(conf_matrix[i]) - rows.append(line) - print_table(rows, header=header) diff --git a/InnerEye/ML/configs/ssl/CovidContainers.py b/InnerEye/ML/configs/ssl/CovidContainers.py index 2941b1b39..92f3521e8 100644 --- a/InnerEye/ML/configs/ssl/CovidContainers.py +++ b/InnerEye/ML/configs/ssl/CovidContainers.py @@ -33,4 +33,5 @@ def __init__(self, linear_head_augmentation_config=path_linear_head_augmentation_cxr, online_evaluator_lr=1e-5, linear_head_batch_size=64, + pl_find_unused_parameters=True, **kwargs) diff --git a/InnerEye/ML/lightning_models.py b/InnerEye/ML/lightning_models.py index 367aed834..f8ed4fd97 100644 --- a/InnerEye/ML/lightning_models.py +++ b/InnerEye/ML/lightning_models.py @@ -2,11 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List import torch from pytorch_lightning.utilities import move_data_to_device -from torch.nn import ModuleDict, ModuleList from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, TRAIN_PREFIX, VALIDATION_PREFIX @@ -15,19 +14,16 @@ from InnerEye.ML.dataset.sample import CroppedSample from InnerEye.ML.dataset.scalar_sample import ScalarItem from InnerEye.ML.lightning_base import InnerEyeLightning -from InnerEye.ML.lightning_metrics import Accuracy05, AccuracyAtOptimalThreshold, AreaUnderPrecisionRecallCurve, \ - AreaUnderRocCurve, BinaryCrossEntropyWithLogits, ExplainedVariance, FalseNegativeRateOptimalThreshold, \ - FalsePositiveRateOptimalThreshold, MeanAbsoluteError, MeanSquaredError, MetricForMultipleStructures, \ - OptimalThreshold, ScalarMetricsBase +from InnerEye.ML.lightning_metrics import MetricForMultipleStructures from InnerEye.ML.metrics import compute_dice_across_patches -from InnerEye.ML.metrics_dict import DataframeLogger, MetricsDict, SequenceMetricsDict +from InnerEye.ML.metrics_dict import DataframeLogger, MetricsDict from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.scalar_config import ScalarModelBase from InnerEye.ML.sequence_config import SequenceModelBase from InnerEye.ML.utils import image_util, metrics_util, model_util from InnerEye.ML.utils.dataset_util import DatasetExample, store_and_upload_example from InnerEye.ML.utils.model_util import get_scalar_model_inputs_and_labels -from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss, get_masked_model_outputs_and_labels +from InnerEye.ML.utils.sequence_utils import apply_sequence_model_loss from pytorch_lightning import Trainer SUBJECT_OUTPUT_PER_RANK_PREFIX = f"{SUBJECT_METRICS_FILE_NAME}.rank" @@ -190,16 +186,14 @@ def __init__(self, config: ScalarModelBase, *args: Any, **kwargs: Any) -> None: super().__init__(config, *args, **kwargs) self.model = config.create_model() raw_loss = model_util.create_scalar_loss_function(config) - self.posthoc_label_transform = config.get_posthoc_label_transform() if isinstance(config, SequenceModelBase): self.loss_fn = lambda model_output, loss: apply_sequence_model_loss(raw_loss, model_output, loss) self.target_indices = config.get_target_indices() - self.target_names = [SequenceMetricsDict.get_hue_name_from_target_index(p) - for p in config.sequence_target_positions] else: self.loss_fn = raw_loss self.target_indices = [] - self.target_names = config.target_names + + self.target_names = config.target_names self.is_classification_model = config.is_classification_model self.use_mean_teacher_model = config.compute_mean_teacher_model self.is_binary_classification_or_regression = True if len(config.class_names) == 1 else False @@ -207,42 +201,15 @@ def __init__(self, config: ScalarModelBase, *args: Any, **kwargs: Any) -> None: self.loss_type = config.loss_type # These two fields store the PyTorch Lightning Metrics objects that will compute metrics on validation # and training set, in particular ones that are not possible to compute from a single minibatch (AUC and alike) - self.train_metric_computers = self.create_metric_computers() - self.val_metric_computers = self.create_metric_computers() - + self.train_metric_computers = config.create_metric_computers() + self.val_metric_computers = 
config.create_metric_computers() + self.compute_and_log_metrics = config.compute_and_log_metrics # if config.compute_grad_cam: # model_to_evaluate = self.train_val_params.mean_teacher_model if \ # config.compute_mean_teacher_model else self.train_val_params.model # self.guided_grad_cam = VisualizationMaps(model_to_evaluate, config) # config.visualization_folder.mkdir(exist_ok=True) - def create_metric_computers(self) -> ModuleDict: - """ - Gets a set of objects that compute all the metrics for the type of model that is being trained, - across all prediction targets (sequence positions when using a sequence model). - :return: A dictionary mapping from names of prediction targets to a list of metric computers. - """ - # The metric computers should be stored in an object that derives from torch.Module, - # so that they are picked up when moving the whole LightningModule to GPU. - # https://github.com/PyTorchLightning/pytorch-lightning/issues/4713 - return ModuleDict({p: self._get_metrics_computers() for p in self.target_names}) - - def _get_metrics_computers(self) -> ModuleList: - """ - Gets the objects that compute metrics for the present kind of models, for a single prediction target. - """ - if self.is_classification_model: - return ModuleList([Accuracy05(), - AccuracyAtOptimalThreshold(), - OptimalThreshold(), - FalsePositiveRateOptimalThreshold(), - FalseNegativeRateOptimalThreshold(), - AreaUnderRocCurve(), - AreaUnderPrecisionRecallCurve(), - BinaryCrossEntropyWithLogits()]) - else: - return ModuleList([MeanAbsoluteError(), MeanSquaredError(), ExplainedVariance()]) - def forward(self, *model_inputs: torch.Tensor) -> torch.Tensor: # type: ignore """ Runs a list of model input tensors through the model and returns the results. @@ -283,7 +250,6 @@ def training_or_validation_step(self, """ model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, self.target_indices, sample) labels = model_inputs_and_labels.labels - labels = self.posthoc_label_transform(labels) if is_training: logits = self.model(*model_inputs_and_labels.model_inputs) else: @@ -292,62 +258,23 @@ def training_or_validation_step(self, subject_ids = model_inputs_and_labels.subject_ids loss = self.loss_fn(logits, labels) self.write_loss(is_training, loss) - self.compute_and_log_metrics(logits, labels, subject_ids, is_training) + metrics = self.train_metric_computers if is_training else self.val_metric_computers + logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore + data_split = ModelExecutionMode.TRAIN if is_training else ModelExecutionMode.VAL + self.compute_and_log_metrics(logits=logits, + targets=labels, + subject_ids=subject_ids, + is_training=is_training, + metrics=metrics, + logger=logger, + current_epoch=self.current_epoch, + data_split=data_split) self.log_on_epoch(name=MetricType.SUBJECT_COUNT, value=len(model_inputs_and_labels.subject_ids), is_training=is_training, reduce_fx=sum) return loss - def compute_and_log_metrics(self, - logits: torch.Tensor, - targets: torch.Tensor, - subject_ids: List[str], - is_training: bool) -> None: - """ - Computes all the metrics for a given (logits, labels) pair, and writes them to the loggers. - :param logits: The model output before normalization. - :param targets: The expected model outputs. - :param subject_ids: The subject IDs for the present minibatch. - :param is_training: If True, write the metrics as training metrics, otherwise as validation metrics. 
- :return: - """ - metrics = self.train_metric_computers if is_training else self.val_metric_computers - per_subject_outputs: List[Tuple[str, str, torch.Tensor, torch.Tensor]] = [] - for i, (prediction_target, metric_list) in enumerate(metrics.items()): - # mask the model outputs and labels if required - masked = get_masked_model_outputs_and_labels( - logits[:, i, ...], targets[:, i, ...], subject_ids) - # compute metrics on valid masked tensors only - if masked is not None: - _logits = masked.model_outputs.data - _posteriors = self.logits_to_posterior(_logits) - # Classification metrics expect labels as integers, but they are float throughout the rest of the code - labels_dtype = torch.int if self.is_classification_model else _posteriors.dtype - _labels = masked.labels.data.to(dtype=labels_dtype) - _subject_ids = masked.subject_ids - assert _subject_ids is not None - for metric in metric_list: - if isinstance(metric, ScalarMetricsBase) and metric.compute_from_logits: - metric(_logits, _labels) - else: - metric(_posteriors, _labels) - per_subject_outputs.extend( - zip(_subject_ids, [prediction_target] * len(_subject_ids), _posteriors.tolist(), _labels.tolist())) - # Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current - # rank in distributed training, and will be aggregated after training. - logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore - data_split = ModelExecutionMode.TRAIN if is_training else ModelExecutionMode.VAL - for subject, prediction_target, model_output, label in per_subject_outputs: - logger.add_record({ - LoggingColumns.Epoch.value: self.current_epoch, - LoggingColumns.Patient.value: subject, - LoggingColumns.Hue.value: prediction_target, - LoggingColumns.ModelOutput.value: model_output, - LoggingColumns.Label.value: label, - LoggingColumns.DataSplit.value: data_split.value - }) - def training_or_validation_epoch_end(self, is_training: bool) -> None: """ Writes all training or validation metrics that were aggregated over the epoch to the loggers. diff --git a/InnerEye/ML/metrics_dict.py b/InnerEye/ML/metrics_dict.py index 96617e1ef..fd2dc8fad 100644 --- a/InnerEye/ML/metrics_dict.py +++ b/InnerEye/ML/metrics_dict.py @@ -21,7 +21,6 @@ from InnerEye.Common.metrics_constants import INTERNAL_TO_LOGGING_COLUMN_NAMES, LoggingColumns, MetricType, \ MetricTypeOrStr, SEQUENCE_POSITION_HUE_NAME_PREFIX from InnerEye.ML.common import ModelExecutionMode -from InnerEye.ML.scalar_config import DEFAULT_KEY from InnerEye.ML.utils.metrics_util import binary_classification_accuracy, mean_absolute_error, \ mean_squared_error, r2_score @@ -29,6 +28,8 @@ T = TypeVar('T', np.ndarray, float) MetricsPerExecutionModeAndEpoch = Dict[ModelExecutionMode, Dict[Union[int, str], 'ScalarMetricsDict']] +DEFAULT_KEY = "Default" + def average_metric_values(values: List[float], skip_nan_when_averaging: bool) -> float: """ diff --git a/InnerEye/ML/model_testing.py b/InnerEye/ML/model_testing.py index e75e5107a..617d3f346 100644 --- a/InnerEye/ML/model_testing.py +++ b/InnerEye/ML/model_testing.py @@ -420,8 +420,6 @@ def classification_model_test(config: ScalarModelBase, :param model_proc: whether we are testing an ensemble or single model :return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs. 
""" - posthoc_label_transform = config.get_posthoc_label_transform() - pipeline = create_inference_pipeline(config=config, checkpoint_paths=checkpoint_paths) if pipeline is None: @@ -450,7 +448,6 @@ def classification_model_test(config: ScalarModelBase, result = pipeline.predict(sample) model_output = result.posteriors label = result.labels.to(device=model_output.device) - label = posthoc_label_transform(label) sample_id = result.subject_ids[0] if output_logger: for i in range(len(config.target_names)): diff --git a/InnerEye/ML/reports/notebook_report.py b/InnerEye/ML/reports/notebook_report.py index 5a7e49c22..cbb122c85 100644 --- a/InnerEye/ML/reports/notebook_report.py +++ b/InnerEye/ML/reports/notebook_report.py @@ -170,7 +170,7 @@ def generate_classification_crossval_notebook(result_notebook: Path, 'innereye_path': str(fixed_paths.repository_root_directory()), 'train_metrics_csv': "", 'val_metrics_csv': str_or_empty(crossval_metrics), - 'test_metrics_csv': str_or_empty(crossval_metrics), + 'test_metrics_csv': "", "config": codecs.encode(pickle.dumps(config), "base64").decode(), "is_crossval_report": True } diff --git a/InnerEye/ML/scalar_config.py b/InnerEye/ML/scalar_config.py index bde92dcc9..b156b94e4 100644 --- a/InnerEye/ML/scalar_config.py +++ b/InnerEye/ML/scalar_config.py @@ -8,20 +8,27 @@ import pandas as pd import param +import torch from azureml.core import ScriptRunConfig from azureml.train.hyperdrive import HyperDriveConfig +from torch.nn import ModuleDict, ModuleList + from InnerEye.Common.common_util import print_exception from InnerEye.Common.generic_parsing import ListOrDictParam +from InnerEye.Common.metrics_constants import LoggingColumns from InnerEye.Common.type_annotations import TupleInt3 from InnerEye.ML.common import ModelExecutionMode, OneHotEncoderBase from InnerEye.ML.deep_learning_config import ModelCategory +from InnerEye.ML.lightning_metrics import Accuracy05, AccuracyAtOptimalThreshold, AreaUnderPrecisionRecallCurve, \ + AreaUnderRocCurve, BinaryCrossEntropyWithLogits, ExplainedVariance, FalseNegativeRateOptimalThreshold, \ + FalsePositiveRateOptimalThreshold, MeanAbsoluteError, MeanSquaredError, OptimalThreshold, ScalarMetricsBase +from InnerEye.ML.metrics_dict import DEFAULT_KEY, DataframeLogger from InnerEye.ML.model_config_base import ModelConfigBase, ModelTransformsPerExecutionMode from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER from InnerEye.ML.utils.split_dataset import DatasetSplits - -DEFAULT_KEY = "Default" +from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels class AggregationType(Enum): @@ -125,7 +132,9 @@ class ScalarModelBase(ModelConfigBase): "reporting results. If provided, the length of this list must match the " "number of model outputs (and of transformed labels, if defined; see " "get_posthoc_label_transform()). By default, this inherits the value of " - "class_names at initialisation.") + "class_names at initialisation. 
This will be ignored in sequence models, " + "as target_names are determined automatically based on" + "sequence_target_positions") aggregation_type: AggregationType = param.ClassSelector(default=AggregationType.Average, class_=AggregationType, doc="The type of global pooling aggregation to use between" " the encoder and the classifier.") @@ -362,14 +371,6 @@ def get_label_transform(self) -> Union[Callable, List[Callable]]: """ return LabelTransformation.identity - def get_posthoc_label_transform(self) -> Callable: - """ - Return a transformation to apply to the labels after they are loaded, for computing losses, metrics, and - reports. The transformed labels refer to the config's target_names, if defined (class_names, otherwise). - If not overriden, this method does not change the loaded labels. - """ - return lambda x: x # no-op by default - def read_dataset_into_dataframe_and_pre_process(self) -> None: assert self.local_dataset is not None file_path = self.local_dataset / self.dataset_csv @@ -503,6 +504,87 @@ def get_scalar_item_transform(self) -> ModelTransformsPerExecutionMode: val=ScalarItemAugmentation(image_transform.val, segmentation_transform.val), test=ScalarItemAugmentation(image_transform.test, segmentation_transform.test)) + def create_metric_computers(self) -> ModuleDict: + """ + Gets a set of objects that compute all the metrics for the type of model that is being trained, + across all prediction targets (sequence positions when using a sequence model). + :return: A dictionary mapping from names of prediction targets to a list of metric computers. + """ + # The metric computers should be stored in an object that derives from torch.Module, + # so that they are picked up when moving the whole LightningModule to GPU. + # https://github.com/PyTorchLightning/pytorch-lightning/issues/4713 + return ModuleDict({p: self._get_metrics_computers() for p in self.target_names}) + + def _get_metrics_computers(self) -> ModuleList: + """ + Gets the objects that compute metrics for the present kind of models, for a single prediction target. + """ + if self.is_classification_model: + return ModuleList([Accuracy05(), + AccuracyAtOptimalThreshold(), + OptimalThreshold(), + FalsePositiveRateOptimalThreshold(), + FalseNegativeRateOptimalThreshold(), + AreaUnderRocCurve(), + AreaUnderPrecisionRecallCurve(), + BinaryCrossEntropyWithLogits()]) + else: + return ModuleList([MeanAbsoluteError(), MeanSquaredError(), ExplainedVariance()]) + + def compute_and_log_metrics(self, + logits: torch.Tensor, + targets: torch.Tensor, + subject_ids: List[str], + is_training: bool, + metrics: ModuleDict, + logger: DataframeLogger, + current_epoch: int, + data_split: ModelExecutionMode) -> None: + """ + Computes all the metrics for a given (logits, labels) pair, and writes them to the loggers. + :param logits: The model output before normalization. + :param targets: The expected model outputs. + :param subject_ids: The subject IDs for the present minibatch. + :param is_training: If True, write the metrics as training metrics, otherwise as validation metrics. + :param metrics: A dictionary mapping from names of prediction targets to a list of metric computers, + as returned by create_metric_computers. + :param logger: An object of type DataframeLogger which can be be used for logging within this function. + :param current_epoch: Current epoch number. + :param data_split: ModelExecutionMode object indicating if this is the train or validation split. 
+ :return: + """ + per_subject_outputs: List[Tuple[str, str, torch.Tensor, torch.Tensor]] = [] + for i, (prediction_target, metric_list) in enumerate(metrics.items()): + # mask the model outputs and labels if required + masked = get_masked_model_outputs_and_labels( + logits[:, i, ...], targets[:, i, ...], subject_ids) + # compute metrics on valid masked tensors only + if masked is not None: + _logits = masked.model_outputs.data + _posteriors = self.get_post_loss_logits_normalization_function()(_logits) + # Classification metrics expect labels as integers, but they are float throughout the rest of the code + labels_dtype = torch.int if self.is_classification_model else _posteriors.dtype + _labels = masked.labels.data.to(dtype=labels_dtype) + _subject_ids = masked.subject_ids + assert _subject_ids is not None + for metric in metric_list: + if isinstance(metric, ScalarMetricsBase) and metric.compute_from_logits: + metric(_logits, _labels) + else: + metric(_posteriors, _labels) + per_subject_outputs.extend( + zip(_subject_ids, [prediction_target] * len(_subject_ids), _posteriors.tolist(), _labels.tolist())) + # Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current + # rank in distributed training, and will be aggregated after training. + for subject, prediction_target, model_output, label in per_subject_outputs: + logger.add_record({ + LoggingColumns.Epoch.value: current_epoch, + LoggingColumns.Patient.value: subject, + LoggingColumns.Hue.value: prediction_target, + LoggingColumns.ModelOutput.value: model_output, + LoggingColumns.Label.value: label, + LoggingColumns.DataSplit.value: data_split.value + }) def get_non_image_features_dict(default_channels: List[str], specific_channels: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]: diff --git a/InnerEye/ML/sequence_config.py b/InnerEye/ML/sequence_config.py index 7fad8d249..9122623e3 100644 --- a/InnerEye/ML/sequence_config.py +++ b/InnerEye/ML/sequence_config.py @@ -12,6 +12,7 @@ from InnerEye.Common.metrics_constants import LoggingColumns from InnerEye.ML.common import ModelExecutionMode from InnerEye.ML.deep_learning_config import TemperatureScalingConfig +from InnerEye.ML.metrics_dict import SequenceMetricsDict from InnerEye.ML.scalar_config import ScalarModelBase from InnerEye.ML.utils.split_dataset import DatasetSplits @@ -65,6 +66,10 @@ def __init__(self, **params: Any): logging.info(f"Temperature scaling will be performed on the " f"validation set using the config: {self.temperature_scaling_config}") + def validate(self) -> None: + self.target_names = [SequenceMetricsDict.get_hue_name_from_target_index(p) + for p in self.sequence_target_positions] + def get_target_indices(self) -> List[int]: """ Computes the zero based array indices inside of a sequence of items diff --git a/Tests/ML/configs/utils/test_hierarchical_covid_model_report.py b/Tests/ML/configs/utils/test_hierarchical_covid_model_report.py deleted file mode 100644 index 1dea1bff5..000000000 --- a/Tests/ML/configs/utils/test_hierarchical_covid_model_report.py +++ /dev/null @@ -1,22 +0,0 @@ -import pandas as pd -from math import nan - -from InnerEye.Common.metrics_constants import LoggingColumns -from InnerEye.ML.configs.reports.covid_hierarchical_model_report import MULTICLASS_HUE_NAME, \ - get_dataframe_with_covid_labels - - -def test_get_dataframe_with_covid_labels() -> None: - - df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], - LoggingColumns.Hue.value: 
['CVX03vs12', 'CVX0vs3', 'CVX1vs2'] * 4, - LoggingColumns.Label.value: [0, 0, nan, 0, 1, nan, 1, nan, 0, 1, nan, 1], - LoggingColumns.ModelOutput.value: [0.1, 0.1, 0.5, 0.1, 0.9, 0.5, 0.9, 0.9, 0.9, 0.1, 0.2, 0.1]}) - expected_df = pd.DataFrame.from_dict({LoggingColumns.Patient.value: [1, 2, 3, 4], - LoggingColumns.ModelOutput.value: [0, 3, 2, 0], - LoggingColumns.Label.value: [0, 3, 1, 2], - LoggingColumns.Hue.value: [MULTICLASS_HUE_NAME] * 4 - }) - - multiclass_df = get_dataframe_with_covid_labels(df) - assert expected_df.equals(multiclass_df) diff --git a/Tests/ML/models/architectures/sequential/test_rnn_classifier.py b/Tests/ML/models/architectures/sequential/test_rnn_classifier.py index f49c3cfd4..6febbc15c 100644 --- a/Tests/ML/models/architectures/sequential/test_rnn_classifier.py +++ b/Tests/ML/models/architectures/sequential/test_rnn_classifier.py @@ -437,7 +437,7 @@ def test_run_ml_with_multi_label_sequence_model(test_output_dirs: OutputFolderFo when it is started via run_ml. """ logging_to_stdout() - config = ToyMultiLabelSequenceModel(should_validate=False) + config = ToyMultiLabelSequenceModel() assert config.get_target_indices() == [1, 2, 3] expected_prediction_targets = [f"{SEQUENCE_POSITION_HUE_NAME_PREFIX} {x}" for x in ["01", "02", "03"]] diff --git a/Tests/ML/test_metrics.py b/Tests/ML/test_metrics.py index 22230a8cf..701374a6b 100644 --- a/Tests/ML/test_metrics.py +++ b/Tests/ML/test_metrics.py @@ -18,7 +18,6 @@ from InnerEye.ML.configs.classification.DummyClassification import DummyClassification from InnerEye.ML.configs.regression.DummyRegression import DummyRegression from InnerEye.ML.lightning_metrics import AverageWithoutNan, MetricForMultipleStructures, ScalarMetricsBase -from InnerEye.ML.lightning_models import ScalarLightning from InnerEye.ML.metrics_dict import MetricsDict, get_column_name_for_logging @@ -164,8 +163,8 @@ def test_get_column_name_for_logging() -> None: def test_classification_metrics() -> None: - classification_module = ScalarLightning(DummyClassification()) - metrics = classification_module._get_metrics_computers() + config = DummyClassification() + metrics = config._get_metrics_computers() logits = [torch.tensor([2.1972, 1.3863, 0.4055]), torch.tensor([-0.8473, 2.1972, -0.4055])] posteriors = [torch.sigmoid(logit) for logit in logits] labels = [torch.tensor([1, 1, 0]), torch.tensor([0, 0, 0])] @@ -203,8 +202,8 @@ def test_classification_metrics() -> None: def test_regression_metrics() -> None: - regression_module = ScalarLightning(DummyRegression()) - metrics = regression_module._get_metrics_computers() + config = DummyRegression() + metrics = config._get_metrics_computers() outputs = [torch.tensor([1., 2., 1.]), torch.tensor([4., 0., 2.])] labels = [torch.tensor([1., 1., 0.]), torch.tensor([2., 0., 2.])] for output, label in zip(outputs, labels):
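
For readers of the diff above, a minimal standalone sketch may help to see how the new multiclass formulation behaves end to end: the loss mirrors `custom_loss` from `CovidModel.get_loss_function` (one-hot labels reduced to class indices, cross-entropy with sum reduction), and the accuracy mirrors the per-subject argmax reduction used in `generate_custom_report`. This is an illustration only, not code from the PR: the toy data is invented, and the CSV column names are simplified stand-ins for the repository's `LoggingColumns` values.

```python
# Illustrative sketch (not part of the PR) of the CovidModel multiclass formulation.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F


def custom_loss(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over the 4 CVX classes; one-hot labels become class indices."""
    labels = torch.argmax(labels, dim=-1)
    return F.cross_entropy(input=output, target=labels, reduction="sum")


# One minibatch of 2 subjects, 4 classes (CVX0..CVX3).
logits = torch.tensor([[2.0, 0.1, 0.1, 0.1],
                       [0.1, 0.1, 3.0, 0.1]])
one_hot_labels = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                               [0.0, 0.0, 1.0, 0.0]])
print(custom_loss(logits, one_hot_labels))  # summed cross-entropy for the batch

# The custom report reduces model_output.csv (one row per subject and per-class "Hue")
# to a single predicted / true class per subject via argmax, then averages accuracy.
# Column names here are simplified; the real CSV uses LoggingColumns values.
df = pd.DataFrame({
    "Patient": [1, 1, 1, 1, 2, 2, 2, 2],
    "Hue": ["CVX0", "CVX1", "CVX2", "CVX3"] * 2,
    "ModelOutput": [0.7, 0.1, 0.1, 0.1, 0.2, 0.2, 0.5, 0.1],
    "Label": [1, 0, 0, 0, 0, 0, 0, 1],
})


def collapse(group: pd.DataFrame) -> pd.Series:
    # argmax over the four per-class rows gives the predicted and true class index.
    return pd.Series({"pred": int(np.argmax(group["ModelOutput"].values)),
                      "true": int(np.argmax(group["Label"].values))})


per_subject = df.groupby("Patient").apply(collapse)
print((per_subject["pred"] == per_subject["true"]).mean())  # multiclass accuracy (0.5 here)
```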