Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Update Covid configs #526

Merged
merged 24 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ via the `--inference_on_val_set` flag.
gets uploaded to AzureML, by skipping all test folders.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameter `extra_downloaded_run_id` has been
renamed to `pretraining_run_checkpoints`.
- ([#526](https://github.com/microsoft/InnerEye-DeepLearning/pull/526)) Updated Covid config to use a multiclass
formulation. Moved functions `create_metric_computers` and `compute_and_log_metrics` from `ScalarLightning` to
`ScalarModelBase`.

### Fixed
- ([#537](https://github.com/microsoft/InnerEye-DeepLearning/pull/537)) Print warning if inference is disabled but comparison requested.
Expand All @@ -54,6 +57,9 @@ in inference-only runs when using lightning containers.
- ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and
`weights_url` can no longer be used to initialize a training run, only inference runs.
- ([#526](https://github.com/microsoft/InnerEye-DeepLearning/pull/526)) Removed `get_posthoc_label_transform` in
class `ScalarModelBase`. Instead, functions `get_loss_function` and `compute_and_log_metrics` in
`ScalarModelBase` can be implemented to compute the loss and metrics in a task-specific manner.

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import codecs
import logging
import pickle
import random
import math
from pathlib import Path

from typing import Any, Callable
from typing import Any, Callable, List

import PIL
import numpy as np
import pandas as pd
import param
import torch

from PIL import Image
from torch.nn import ModuleList, ModuleDict
from pytorch_lightning import LightningModule
from torchvision.transforms import Compose

from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path
from InnerEye.Common.metrics_constants import LoggingColumns

from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName

Expand All @@ -28,29 +30,27 @@
from InnerEye.ML.deep_learning_config import LRSchedulerType, MultiprocessingStartMethod, \
OptimizerType

from InnerEye.ML.lightning_metrics import Accuracy05
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.model_testing import MODEL_OUTPUT_CSV

from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImagingFeatureType
from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty

from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.utils.split_dataset import DatasetSplits

from InnerEye.ML.configs.ssl.CovidContainers import COVID_DATASET_ID
from InnerEye.Common import fixed_paths as fixed_paths_innereye
from InnerEye.ML.metrics_dict import MetricsDict, DataframeLogger


class CovidHierarchicalModel(ScalarModelBase):
class CovidModel(ScalarModelBase):
"""
Model to train a CovidDataset model from scratch or finetune from SSL-pretrained model.

For AML you need to provide the run_id of your SSL training job as a command line argument
--pretraining_run_recovery_id=id_of_your_ssl_model, this will download the checkpoints of the run to your
machine and load the corresponding pretrained model.

To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use hte
To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use the
--name_of_checkpoint argument.
"""
use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with SSL."
Expand All @@ -64,8 +64,7 @@ class CovidHierarchicalModel(ScalarModelBase):
"is assumed to contain unique ids.")

def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any):
super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'],
loss_type=ScalarLoss.CustomClassification,
super().__init__(loss_type=ScalarLoss.CustomClassification,
class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'],
max_num_gpus=1,
azure_dataset_id=covid_dataset_id,
Expand All @@ -84,7 +83,7 @@ def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any):
l_rate_step_gamma=1.0,
l_rate_multi_step_milestones=None,
should_validate=False) # validate only after adding kwargs
self.num_classes = 3
self.num_classes = 4
self.add_and_validate(kwargs)

def validate(self) -> None:
Expand Down Expand Up @@ -192,39 +191,53 @@ def _get_ssl_checkpoint_path(self) -> Path:
def pre_process_dataset_dataframe(self) -> None:
    """
    Perform no pre-processing of the dataset dataframe: this implementation does nothing and returns None.

    NOTE(review): presumably this overrides a hook on the base class so that no dataframe
    pre-processing is applied for this model - confirm against ScalarModelBase.
    """
    return None

@staticmethod
def get_posthoc_label_transform() -> Callable:
import torch

def multiclass_to_hierarchical_labels(classes: torch.Tensor) -> torch.Tensor:
classes = classes.clone()
cvx03vs12 = classes[..., 1] + classes[..., 2]
cvx0vs3 = classes[..., 3]
cvx1vs2 = classes[..., 2]
cvx0vs3[cvx03vs12 == 1] = float('nan') # CVX0vs3 only gets gradient for CVX03
cvx1vs2[cvx03vs12 == 0] = float('nan') # CVX1vs2 only gets gradient for CVX12
return torch.stack([cvx03vs12, cvx0vs3, cvx1vs2], -1)

return multiclass_to_hierarchical_labels

@staticmethod
def get_loss_function() -> Callable:
import torch
import torch.nn.functional as F

def nan_bce_with_logits(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Compute BCE with logits, ignoring NaN values"""
valid = labels.isfinite()
losses = F.binary_cross_entropy_with_logits(output[valid], labels[valid], reduction='none')
return losses.sum() / labels.shape[0]

return nan_bce_with_logits
def custom_loss(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
labels = torch.argmax(labels, dim=-1)
return F.cross_entropy(input=output, target=labels, reduction="sum")

return custom_loss

def get_post_loss_logits_normalization_function(self) -> Callable:
    """
    Return the module that turns raw logits into per-class posteriors.

    The softmax is taken explicitly over the last dimension. Constructing `Softmax` without `dim`
    makes PyTorch emit a deprecation warning and infer the dimension from the input rank, which
    selects dimension 0 for 3-dimensional input. Using `dim=-1` gives the same result for
    2-dimensional (batch, classes) logits and matches the `argmax(..., dim=-1)` applied to the
    targets elsewhere in this class.

    :return: A `torch.nn.Softmax` module normalising over the last dimension.
    """
    return torch.nn.Softmax(dim=-1)

def create_metric_computers(self) -> ModuleDict:
    """
    Create the metric modules used by this model.

    :return: A ModuleDict with a single entry, keyed by MetricsDict.DEFAULT_HUE_KEY, holding a
        ModuleList that contains one Accuracy05 metric.
    """
    default_hue_metrics = ModuleList([Accuracy05()])
    return ModuleDict({MetricsDict.DEFAULT_HUE_KEY: default_hue_metrics})

def compute_and_log_metrics(self,
                            logits: torch.Tensor,
                            targets: torch.Tensor,
                            subject_ids: List[str],
                            is_training: bool,
                            metrics: ModuleDict,
                            logger: DataframeLogger,
                            current_epoch: int,
                            data_split: ModelExecutionMode) -> None:
    """
    Update the metric for one batch and write one log record per subject and target name.

    :param logits: Model outputs; they are normalised with the function returned by
        get_post_loss_logits_normalization_function before being used.
    :param targets: Labels for the batch. They are reduced with argmax over the last dimension
        for the metric (presumably one-hot encoded - confirm), and logged unreduced.
    :param subject_ids: One identifier per sample in the batch.
    :param is_training: Not used by this implementation.
    :param metrics: Metric collection; only the first metric stored under
        MetricsDict.DEFAULT_HUE_KEY is updated.
    :param logger: Receives one record per (subject, target name) pair.
    :param current_epoch: Epoch number written to each record.
    :param data_split: Execution mode whose value is written to each record.
    """
    normalize = self.get_post_loss_logits_normalization_function()
    posteriors = normalize(logits)
    class_indices = targets.argmax(dim=-1)
    default_metric = metrics[MetricsDict.DEFAULT_HUE_KEY][0]
    default_metric(posteriors, class_indices)

    batch_rows = zip(subject_ids, posteriors.tolist(), targets.tolist())
    for subject, subject_posteriors, subject_targets in batch_rows:
        for index, hue in enumerate(self.target_names):
            record = {
                LoggingColumns.Epoch.value: current_epoch,
                LoggingColumns.Patient.value: subject,
                LoggingColumns.Hue.value: hue,
                LoggingColumns.ModelOutput.value: subject_posteriors[index],
                LoggingColumns.Label.value: subject_targets[index],
                LoggingColumns.DataSplit.value: data_split.value
            }
            logger.add_record(record)

def generate_custom_report(self, report_dir: Path, model_proc: ModelProcessing) -> Path:
"""
Generate a custom report for the CovidDataset Hierarchical model. At the moment, this report will read the
file model_output.csv generated for the training, validation or test sets and compute a 4 class accuracy
and confusion matrix based on this.
Generate a custom report for the Covid model. This report will read the file model_output.csv generated for
the training, validation or test sets and compute the multiclass accuracy based on this.
:param report_dir: Directory report is to be written to
:param model_proc: Whether this is a single or ensemble model (model_output.csv will be located in different
paths for single vs ensemble runs.)
Expand All @@ -234,24 +247,37 @@ def get_output_csv_path(mode: ModelExecutionMode) -> Path:
p = get_best_epoch_results_path(mode=mode, model_proc=model_proc)
return self.outputs_folder / p / MODEL_OUTPUT_CSV

def get_labels_and_predictions(df: pd.DataFrame) -> pd.DataFrame:
    """
    Collapse the per-target rows in `df` into a single row with the predicted and true class index.

    For each name in self.target_names, the rows whose hue equals that name are selected; the
    argmax over the collected model outputs and over the collected labels gives the predicted
    and true index respectively. The patient id is taken from the first row of `df`.
    NOTE(review): assumes `df` holds the rows of one patient with one row per target name - confirm.

    :param df: Rows read from the model output file.
    :return: A one-row dataframe with patient, model output index and label index columns.
    """
    hue_values = df[LoggingColumns.Hue.value]
    rows_per_target = [df[hue_values == name] for name in self.target_names]
    predictions = [rows[LoggingColumns.ModelOutput.value] for rows in rows_per_target]
    labels = [rows[LoggingColumns.Label.value] for rows in rows_per_target]
    patient = df.iloc[0][LoggingColumns.Patient.value]

    summary = {
        LoggingColumns.Patient.value: [patient],
        LoggingColumns.ModelOutput.value: [np.argmax(predictions)],
        LoggingColumns.Label.value: [np.argmax(labels)]
    }
    return pd.DataFrame.from_dict(summary)

def get_accuracy(df: pd.DataFrame) -> float:
    """
    Compute the fraction of patients whose predicted class index equals the label index.

    The rows are grouped by patient, each group is reduced to one row by
    get_labels_and_predictions, and the mean of the per-patient matches is returned.

    :param df: Rows read from the model output file.
    :return: Mean of the element-wise equality between predicted and true class index.
    """
    grouped = df.groupby(LoggingColumns.Patient.value, as_index=False)
    per_patient = grouped.apply(get_labels_and_predictions).reset_index(drop=True)
    is_correct = per_patient[LoggingColumns.ModelOutput.value] == per_patient[LoggingColumns.Label.value]
    return is_correct.mean()  # type: ignore

train_metrics = get_output_csv_path(ModelExecutionMode.TRAIN)
val_metrics = get_output_csv_path(ModelExecutionMode.VAL)
test_metrics = get_output_csv_path(ModelExecutionMode.TEST)

notebook_params = \
{
'innereye_path': str(fixed_paths_innereye.repository_root_directory()),
'train_metrics_csv': str_or_empty(train_metrics),
'val_metrics_csv': str_or_empty(val_metrics),
'test_metrics_csv': str_or_empty(test_metrics),
"config": codecs.encode(pickle.dumps(self), "base64").decode(),
"is_crossval_report": False
}
template = Path(__file__).absolute().parent.parent / "reports" / "CovidHierarchicalModelReport.ipynb"
return generate_notebook(template,
notebook_params=notebook_params,
result_notebook=report_dir / get_ipynb_report_name(
f"{self.model_category.value}_hierarchical"))
msg = f"Multiclass Accuracy Train: {get_accuracy(pd.read_csv(train_metrics))}\n" if train_metrics.exists() else ""
msg += f"Multiclass Accuracy Val: {get_accuracy(pd.read_csv(val_metrics))}\n" if val_metrics.exists() else ""
msg += f"Multiclass Accuracy Test: {get_accuracy(pd.read_csv(test_metrics))}\n" if test_metrics.exists() else ""

report = report_dir / "report.txt"
report.write_text(msg)

logging.info(msg)

return report


class DicomPreparation:
Expand Down
160 changes: 0 additions & 160 deletions InnerEye/ML/configs/reports/CovidHierarchicalModelReport.ipynb

This file was deleted.

Loading