This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Adding Active Label Cleaning code #559

Merged 23 commits on Sep 21, 2021
8 changes: 7 additions & 1 deletion .gitignore
@@ -158,4 +158,10 @@ Tests/ML/test_outputs
# This file contains the run recovery ID of the most recent job
most_recent_run.txt
# The default folder that contains downloaded Tensorboard files
tensorboard_runs
tensorboard_runs

# InnerEye-DataQuality
InnerEye-DataQuality/cifar-10-python.tar.gz
InnerEye-DataQuality/name_stats_scoring.png
InnerEye-DataQuality/cifar-10-batches-py
InnerEye-DataQuality/logs
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ jobs that run in AzureML.
ensemble) using the parameter `model_id`.
- ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Added a parameter `pretraining_dataset_id` to
`NIH_COVID_BYOL` to specify the name of the SSL training dataset.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Added the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise-robust models, running label cleaning simulations and loading our label cleaning benchmark datasets.

### Changed
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
5 changes: 5 additions & 0 deletions InnerEye-DataQuality/InnerEyeDataQuality/__init__.py
@@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

115 changes: 115 additions & 0 deletions InnerEye-DataQuality/InnerEyeDataQuality/algorithms/graph.py
@@ -0,0 +1,115 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import logging
import time
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import scipy
from scipy.special import softmax
from sklearn.neighbors import kneighbors_graph


@dataclass(init=True, frozen=True)
class GraphParameters:
"""Class for setting graph connectivity and diffusion parameters."""
n_neighbors: int
diffusion_alpha: float
cg_solver_max_iter: int
diffusion_batch_size: Optional[int]
distance_kernel: str # {'euclidean' or 'cosine'}


def _get_affinity_matrix(embeddings: np.ndarray,
n_neighbors: int,
distance_kernel: str = 'cosine') -> scipy.sparse.csr.csr_matrix:
"""
:param embeddings: Input sample embeddings (n_samples x n_embedding_dim)
:param n_neighbors: Number of neighbors in the KNN Graph
:param distance_kernel: Distance kernel to compute sample similarities {'euclidean' or 'cosine'}
"""

# Build a k-NN graph using the embeddings
if distance_kernel == 'euclidean':
sigma = embeddings.shape[1]
knn_dist_graph = kneighbors_graph(embeddings, n_neighbors, mode='distance', metric='euclidean', n_jobs=-1)
knn_dist_graph.data = np.exp(-1.0 * np.asarray(knn_dist_graph.data) ** 2 / (2.0 * sigma ** 2))
elif distance_kernel == 'cosine':
knn_dist_graph = kneighbors_graph(embeddings, n_neighbors, mode='distance', metric='cosine', n_jobs=-1)
knn_dist_graph.data = 1.0 - np.asarray(knn_dist_graph.data) / 2.0
else:
raise ValueError(f"Unknown sample distance kernel {distance_kernel}")

return knn_dist_graph


def build_connectivity_graph(normalised: bool = True, **affinity_kwargs: Any) -> np.ndarray:
"""
    Builds the graph connectivity matrix. Returns the symmetrically normalised adjacency matrix
    (D^-1/2 W D^-1/2) if `normalised` is set to True, otherwise the unnormalised graph Laplacian (D - W).
    :param normalised: If set to True, the graph adjacency is normalised with the degree matrix
    :param affinity_kwargs: Arguments required to construct an affinity matrix
        (weights representing similarity between points)
"""

# Build a symmetric adjacency matrix
A = _get_affinity_matrix(**affinity_kwargs)
W = 0.5 * (A + A.T)
if normalised:
# Normalize the similarity graph
W = W - scipy.sparse.diags(W.diagonal())
D = W.sum(axis=1)
D[D == 0] = 1
D_sqrt_inv = np.array(1. / np.sqrt(D))
D_sqrt_inv = scipy.sparse.diags(D_sqrt_inv.reshape(-1))
L_norm = D_sqrt_inv * W * D_sqrt_inv
return L_norm
else:
num_samples = W.shape[0]
D = W.sum(axis=1)
D = np.diag(np.asarray(D).reshape(num_samples, ))
L = D - W
return L


def label_diffusion(inv_laplacian: np.ndarray,
labels: np.ndarray,
query_batch_ids: np.ndarray,
class_priors: Optional[np.ndarray] = None,
diffusion_normalizing_factor: float = 0.01) -> np.ndarray:
"""
    :param inv_laplacian: inverse graph Laplacian (n_samples x n_samples)
    :param labels: label distribution for each sample (n_samples x n_classes); rows must sum to one
    :param query_batch_ids: the ids of the query samples whose labels are re-estimated
        (all remaining samples are treated as labelled)
    :param class_priors: prior distribution of each class [n_classes,]
    :param diffusion_normalizing_factor: factor to normalize the diffused labels
"""
diffusion_start = time.time()

# Input number of nodes and classes
n_samples = labels.shape[0]
n_classes = labels.shape[1]

# Find the labelled set of nodes in the graph
all_idx = np.array(range(n_samples))
labeled_idx = np.setdiff1d(all_idx, query_batch_ids.flatten())
assert (np.all(np.isin(query_batch_ids, all_idx)))
assert (np.allclose(np.sum(labels, axis=1), np.ones(shape=(labels.shape[0])), rtol=1e-3))

# Initialize the y vector for each class (eq 5 from the paper, normalized with the class size)
# and apply label propagation
y = np.zeros((n_samples, n_classes))
y[labeled_idx] = labels[labeled_idx] / np.sum(labels[labeled_idx], axis=0, keepdims=True)
if class_priors is not None:
y = y * class_priors
Z = np.matmul(inv_laplacian[query_batch_ids, :], y)

# Normalise the diffused logits
output = softmax(Z / diffusion_normalizing_factor, axis=1)
# output = Z / Z.sum(axis=1)
    logging.debug(f"Graph diffusion time: {time.time() - diffusion_start:.2f} secs")

return output
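A minimal usage sketch for the functions above (not part of the PR diff): it wires `GraphParameters`, `build_connectivity_graph` and `label_diffusion` together on random toy data. Forming the inverse diffusion operator as a dense `(I - alpha * S)^-1` is an assumption made only for this small example; the `cg_solver_max_iter` and `diffusion_batch_size` parameters suggest the actual pipeline uses an iterative, batched solver instead.

```python
# Illustrative sketch only: random embeddings, one-hot labels, and a dense matrix
# inverse in place of whatever solver the real pipeline uses.
import numpy as np
import scipy.sparse

rng = np.random.RandomState(0)
n_samples, n_dim, n_classes = 200, 32, 3
embeddings = rng.rand(n_samples, n_dim)
labels = np.eye(n_classes)[rng.randint(0, n_classes, n_samples)]  # one-hot, rows sum to 1

params = GraphParameters(n_neighbors=10,
                         diffusion_alpha=0.95,
                         cg_solver_max_iter=10,
                         diffusion_batch_size=None,
                         distance_kernel='cosine')

# Symmetrically normalised adjacency S = D^-1/2 W D^-1/2
S = build_connectivity_graph(normalised=True,
                             embeddings=embeddings,
                             n_neighbors=params.n_neighbors,
                             distance_kernel=params.distance_kernel)

# Dense inverse of the diffusion operator, purely for this toy-sized example
inv_laplacian = np.linalg.inv((scipy.sparse.eye(n_samples) - params.diffusion_alpha * S).toarray())

query_batch_ids = np.arange(10)  # samples whose labels we want to re-estimate
diffused = label_diffusion(inv_laplacian, labels, query_batch_ids)
print(diffused.shape)  # (10, n_classes), each row a softmax over classes
```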
60 changes: 60 additions & 0 deletions InnerEye-DataQuality/InnerEyeDataQuality/configs/README.md
@@ -0,0 +1,60 @@
## Config arguments for model training:
All possible config arguments are defined in [model_config.py](InnerEyeDataQuality/configs/models/model_config.py). Here you will find a summary of the most important config arguments:
* If you want to train a model with co_teaching, you will need to set `train.use_co_teaching: True` in your config.
* If you want to finetune from a pretrained SSL checkpoint:
* You will need to set `train.use_self_supervision: True` to tell the code to load a pretrained checkpoint.
    * You will need to update the `train.self_supervision.checkpoints: [PATH_TO_SSL]` field with the checkpoints to use for initialization of your model. Note that if you want to train a co-teaching model in combination with SSL pretrained initialization, your list of checkpoints needs to be of length 2.
    * You can also choose whether to freeze the encoder during finetuning with the `train.self_supervision.freeze_encoder` field.
* If you want to train a model using ELR, you can set `train.use_elr: True` in your config (see the sketch below for how these fields nest together).
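The bullet points above describe individual fields; the following hedged sketch (not taken from the repository) shows how they nest when expressed as a yacs-style config. Only the field names come from the bullets; the values and the surrounding defaults are illustrative placeholders.

```python
# Hedged sketch: field names follow the bullets above, everything else is illustrative.
from yacs.config import CfgNode

cfg = CfgNode({
    'train': CfgNode({
        'use_co_teaching': True,          # co-teaching on/off
        'use_elr': False,                 # ELR on/off
        'use_self_supervision': True,     # initialise from SSL checkpoints
        'self_supervision': CfgNode({
            'checkpoints': ['PATH_TO_SSL_1', 'PATH_TO_SSL_2'],  # two entries when co-teaching
            'freeze_encoder': False,      # keep the encoder trainable during finetuning
        }),
    }),
})
print(cfg.train.self_supervision.checkpoints)  # -> ['PATH_TO_SSL_1', 'PATH_TO_SSL_2']
```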

### CIFAR10H
We provide configurations to run experiments on CIFAR10H with 15% and 30% noise rates, respectively.
* Configs for 15% noise rate experiments can be found in [configs/models/cifar10h_noise_15](InnerEyeDataQuality/configs/models/cifar10h_noise_15). In detail, this folder contains configs for:
* vanilla resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml)
* co-teaching resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml)
* SSL + linear head training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml)
* Configs for 30% noise rate experiments can be found in [configs/models/cifar10h_noise_30](InnerEyeDataQuality/configs/models/cifar10h_noise_30). In detail, this folder contains configs for:
* vanilla resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml)
* co-teaching resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml)
* SSL + linear head training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml)
* ELR training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_elr.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_elr.yaml)
* Examples of configs for models used in the model selection benchmark experiment can be found in the [configs/models/benchmark_3_idn](InnerEyeDataQuality/configs/models/benchmark_3_idn)

### Noisy Chest-Xray
*Note:* To run any model on this dataset, you will first need to make sure you have the dataset ready on your machine (see the Benchmark datasets > Noisy Chest-Xray > Pre-requisite section).

* We provide configurations corresponding to the experiments on the NoisyChestXray dataset with 13% noise rate in the [configs/models/cxr](InnerEyeDataQuality/configs/models/cxr) folder:
* Vanilla resnet training: [InnerEyeDataQuality/configs/models/cxr/resnet.yaml](InnerEyeDataQuality/configs/models/cxr/resnet.yaml)
* Co-teaching resnet training: [InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml](InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml)
    * Co-teaching from a pretrained SSL checkpoint: [InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml](InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml)
<br/><br/>

## Config arguments for label cleaning simulation:

#### More details about the selector config
Here is the structure of a selector config, with details about each field (an illustrative sketch follows the list):

* `selector:`
* `type`: Which selector to use. There are several options available:
* `PosteriorBasedSelectorJoint`: Using the ranking function proposed in the manuscript CE(posteriors, labels) - H(posteriors)
* `PosteriorBasedSelector`: Using CE(posteriors, labels) as the ranking function
        * `GraphBasedSelector`: Graph-based selection of the next samples, based on the embeddings of each sample.
* `BaldSelector`: Selection of the next sample with the BALD objective.
* `model_name`: The name that will be used for the legend of the simulation plot
* `model_config_path`: Path to the config file used to train the selection model.
* `use_active_relabelling`: Whether to turn on the active component of the active learning framework. If set to True, the model will be retrained regularly during the selection process.
    * `output_directory`: Optional field where you can specify the output directory in which to store the results.
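Putting the fields together, here is an illustrative sketch of a complete selector block (not a config shipped with the PR): the block is built as a plain Python dict mirroring the fields above and dumped to YAML to show the equivalent file layout. The model name and paths are placeholders.

```python
# Illustrative only: the selector fields described above, with made-up name and paths.
import yaml

selector_config = {
    'selector': {
        'type': 'PosteriorBasedSelector',              # one of the selector types listed above
        'model_name': 'Co-teaching ResNet',            # legend entry on the simulation plot
        'model_config_path': 'configs/models/cxr/resnet_coteaching.yaml',
        'use_active_relabelling': False,               # retrain the model during selection?
        'output_directory': 'outputs/selector_runs',   # optional
    }
}
print(yaml.safe_dump(selector_config, sort_keys=False))  # shows the equivalent YAML layout
```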


#### Off-the-shelf simulation configs
* We provide the configs for various selectors for the CIFAR10H experiments in the [configs/selection/cifar10h_noise_15](InnerEyeDataQuality/configs/selection/cifar10h_noise_15) and in the [configs/selection/cifar10h_noise_30](InnerEyeDataQuality/configs/selection/cifar10h_noise_30) folders.
* And likewise for the NoisyChestXray dataset, you can find a set of selector configs in the [configs/selection/cxr](InnerEyeDataQuality/configs/selection/cxr) folder.
<br/><br/>

## Configs for self-supervised model training:

CIFAR10H: To pretrain embeddings with contrastive learning on CIFAR10H you can use the
[cifar10h_byol.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/cifar10h_byol.yaml) or the [cifar10h_simclr.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/cifar10h_simclr.yaml) config files.

Chest X-ray: Provided that you have downloaded the dataset as explained in the Benchmark Datasets > Other Chest Xray Datasets > NIH Datasets > Pre-requisites section, you can easily launch an unsupervised pretraining run on the full NIH dataset using the [nih_byol.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/nih_byol.yaml) or the [nih_simclr.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/nih_simclr.yaml)
configs. Don't forget to update the `dataset_dir` field of your config to reflect your local path.
48 changes: 48 additions & 0 deletions InnerEye-DataQuality/InnerEyeDataQuality/configs/config_node.py
@@ -0,0 +1,48 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Dict, List, Optional, Union

import yacs.config


class ConfigNode(yacs.config.CfgNode):
def __init__(self, init_dict: Optional[Dict] = None, key_list: Optional[List] = None, new_allowed: bool = False):
super().__init__(init_dict, key_list, new_allowed)

def __str__(self) -> str:
def _indent(s_: str, num_spaces: int) -> Union[str, List[str]]:
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s) # type: ignore
s = first + '\n' + s # type: ignore
return s

r = ''
s = []
for k, v in self.items():
separator = '\n' if isinstance(v, ConfigNode) else ' '
if isinstance(v, str) and not v:
v = '\'\''
attr_str = f'{str(k)}:{separator}{str(v)}'
attr_str = _indent(attr_str, 2) # type: ignore
s.append(attr_str)
r += '\n'.join(s)
return r

def as_dict(self) -> Dict:
def convert_to_dict(node: ConfigNode) -> Dict:
if not isinstance(node, ConfigNode):
return node
else:
dic = dict()
for k, v in node.items():
dic[k] = convert_to_dict(v)
return dic

return convert_to_dict(self)
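A minimal usage sketch for the class above (not part of the diff), assuming `yacs` is installed and `ConfigNode` is imported from this module: build a small config from a plain nested dict and exercise the custom `__str__` and `as_dict` helpers.

```python
# Minimal sketch: nested plain dicts are converted to ConfigNode instances by yacs.
cfg = ConfigNode({
    'train': {
        'use_co_teaching': True,
        'self_supervision': {'freeze_encoder': True, 'checkpoints': []},
    },
})
print(cfg)            # nested blocks are indented by two spaces via the custom __str__
print(cfg.as_dict())  # plain nested Python dicts, converted recursively
```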