diff --git a/.github/workflows/ci-lingvo.yml b/.github/workflows/ci-lingvo.yml index ab7ab24822..631f3f539f 100644 --- a/.github/workflows/ci-lingvo.yml +++ b/.github/workflows/ci-lingvo.yml @@ -50,7 +50,7 @@ jobs: sudo apt-get update sudo apt-get -y -q install ffmpeg libavcodec-extra python -m pip install --upgrade pip setuptools wheel - pip install -q -r <(sed '/^scipy/d;/^matplotlib/d;/^pandas/d;/^statsmodels/d;/^numba/d;/^jax/d;/^h5py/d;/^Pillow/d;/^pytest/d;/^pytest-mock/d;/^torch/d;/^torchaudio/d;/^torchvision/d;/^xgboost/d;/^requests/d;/^tensorflow/d;/^keras/d;/^kornia/d;/^librosa/d;/^tqdm/d' requirements_test.txt) + pip install -q -r <(sed '/^scipy/d;/^matplotlib/d;/^pandas/d;/^statsmodels/d;/^numba/d;/^jax/d;/^h5py/d;/^Pillow/d;/^pytest/d;/^pytest-mock/d;/^torch/d;/^torchaudio/d;/^torchvision/d;/^xgboost/d;/^requests/d;/^tensorflow/d;/^keras/d;/^kornia/d;/^librosa/d;/^tqdm/d;/^timm/d' requirements_test.txt) pip install scipy==1.5.4 pip install matplotlib==3.3.4 pip install pandas==1.1.5 diff --git a/art/estimators/certification/__init__.py b/art/estimators/certification/__init__.py index 92e79a0233..83a69eb514 100644 --- a/art/estimators/certification/__init__.py +++ b/art/estimators/certification/__init__.py @@ -6,7 +6,6 @@ from art.estimators.certification.randomized_smoothing.numpy import NumpyRandomizedSmoothing from art.estimators.certification.randomized_smoothing.tensorflow import TensorFlowV2RandomizedSmoothing from art.estimators.certification.randomized_smoothing.pytorch import PyTorchRandomizedSmoothing -from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin from art.estimators.certification.derandomized_smoothing.pytorch import PyTorchDeRandomizedSmoothing from art.estimators.certification.derandomized_smoothing.tensorflow import TensorFlowV2DeRandomizedSmoothing from art.estimators.certification.object_seeker.object_seeker import ObjectSeekerMixin diff --git a/art/estimators/certification/derandomized_smoothing/__init__.py b/art/estimators/certification/derandomized_smoothing/__init__.py index 1eea6eb3da..69753f4f39 100644 --- a/art/estimators/certification/derandomized_smoothing/__init__.py +++ b/art/estimators/certification/derandomized_smoothing/__init__.py @@ -1,6 +1,5 @@ """ DeRandomized smoothing estimators. """ -from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin from art.estimators.certification.derandomized_smoothing.pytorch import PyTorchDeRandomizedSmoothing from art.estimators.certification.derandomized_smoothing.tensorflow import TensorFlowV2DeRandomizedSmoothing diff --git a/art/estimators/certification/derandomized_smoothing/ablators/__init__.py b/art/estimators/certification/derandomized_smoothing/ablators/__init__.py new file mode 100644 index 0000000000..23715d4aba --- /dev/null +++ b/art/estimators/certification/derandomized_smoothing/ablators/__init__.py @@ -0,0 +1,12 @@ +""" +This module contains the ablators for the certified smoothing approaches. 
+""" +import importlib + +from art.estimators.certification.derandomized_smoothing.ablators.tensorflow import ColumnAblator, BlockAblator + +if importlib.util.find_spec("torch") is not None: + from art.estimators.certification.derandomized_smoothing.ablators.pytorch import ( + ColumnAblatorPyTorch, + BlockAblatorPyTorch, + ) diff --git a/art/estimators/certification/derandomized_smoothing/ablators/ablate.py b/art/estimators/certification/derandomized_smoothing/ablators/ablate.py new file mode 100644 index 0000000000..3970b5b862 --- /dev/null +++ b/art/estimators/certification/derandomized_smoothing/ablators/ablate.py @@ -0,0 +1,90 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This module implements the abstract base class for the ablators. +""" +from __future__ import absolute_import, division, print_function, unicode_literals + +from abc import ABC, abstractmethod +from typing import Optional, Tuple, Union, TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + # pylint: disable=C0412 + import tensorflow as tf + import torch + + +class BaseAblator(ABC): + """ + Base class defining the methods used for the ablators. + """ + + @abstractmethod + def __call__( + self, x: np.ndarray, column_pos: Optional[Union[int, list]] = None, row_pos: Optional[Union[int, list]] = None + ) -> np.ndarray: + """ + Ablate the image x at location specified by "column_pos" for the case of column ablation or at the location + specified by "column_pos" and "row_pos" in the case of block ablation. + + :param x: input image. + :param column_pos: column position to specify where to retain the image + :param row_pos: row position to specify where to retain the image. Not used for ablation type "column". + """ + raise NotImplementedError + + @abstractmethod + def certify( + self, pred_counts: np.ndarray, size_to_certify: int, label: Union[np.ndarray, "tf.Tensor"] + ) -> Union[Tuple["tf.Tensor", "tf.Tensor", "tf.Tensor"], Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]]: + """ + Checks if based on the predictions supplied the classifications over the ablated datapoints result in a + certified prediction against a patch attack of size size_to_certify. + + :param pred_counts: The cumulative predictions of the classifier over the ablation locations. + :param size_to_certify: The size of the patch to check against. 
+ :param label: ground truth labels + """ + raise NotImplementedError + + @abstractmethod + def ablate(self, x: np.ndarray, column_pos: int, row_pos: int) -> Union[np.ndarray, "torch.Tensor"]: + """ + Ablate the image x at location specified by "column_pos" for the case of column ablation or at the location + specified by "column_pos" and "row_pos" in the case of block ablation. + + :param x: input image. + :param column_pos: column position to specify where to retain the image + :param row_pos: row position to specify where to retain the image. Not used for ablation type "column". + """ + raise NotImplementedError + + @abstractmethod + def forward( + self, x: np.ndarray, column_pos: Optional[int] = None, row_pos: Optional[int] = None + ) -> Union[np.ndarray, "torch.Tensor"]: + """ + Ablate batch of data at locations specified by column_pos and row_pos + + :param x: input image. + :param column_pos: column position to specify where to retain the image + :param row_pos: row position to specify where to retain the image. Not used for ablation type "column". + """ + raise NotImplementedError diff --git a/art/estimators/certification/derandomized_smoothing/ablators/pytorch.py b/art/estimators/certification/derandomized_smoothing/ablators/pytorch.py new file mode 100644 index 0000000000..1f1ad1aeec --- /dev/null +++ b/art/estimators/certification/derandomized_smoothing/ablators/pytorch.py @@ -0,0 +1,401 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This module implements Certified Patch Robustness via Smoothed Vision Transformers + +| Paper link Accepted version: + https://openaccess.thecvf.com/content/CVPR2022/papers/Salman_Certified_Patch_Robustness_via_Smoothed_Vision_Transformers_CVPR_2022_paper.pdf + +| Paper link Arxiv version (more detail): https://arxiv.org/pdf/2110.07719.pdf +""" + +from typing import Optional, Union, Tuple +import random + +import numpy as np +import torch + +from art.estimators.certification.derandomized_smoothing.ablators.ablate import BaseAblator + + +class UpSamplerPyTorch(torch.nn.Module): + """ + Resizes datasets to the specified size. 
+ Usually for upscaling datasets like CIFAR to ImageNet format + """ + + def __init__(self, input_size: int, final_size: int) -> None: + """ + Creates an upsampler to make the supplied data match the pre-trained ViT format + + :param input_size: Size of the current input data + :param final_size: Desired final size + """ + super().__init__() + self.upsample = torch.nn.Upsample(scale_factor=final_size / input_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the upsampler. + + :param x: Input data + :return: The upsampled input data + """ + return self.upsample(x) + + +class ColumnAblatorPyTorch(torch.nn.Module, BaseAblator): + """ + Pure PyTorch implementation of stripe/column ablation. + """ + + def __init__( + self, + ablation_size: int, + channels_first: bool, + mode: str, + to_reshape: bool, + ablation_mode: str = "column", + original_shape: Optional[Tuple] = None, + output_shape: Optional[Tuple] = None, + algorithm: str = "salman2021", + device_type: str = "gpu", + ): + """ + Creates a column ablator + + :param ablation_size: The size of the column we will retain. + :param channels_first: If the input is in channels first format. Currently required to be True. + :param mode: If we are running the algorithm using a CNN or ViT. + :param to_reshape: If the input requires reshaping. + :param ablation_mode: The type of ablation to perform. + :param original_shape: Original shape of the input. + :param output_shape: Input shape expected by the ViT. Usually means upscaling the input to 224 x 224. + :param algorithm: Either 'salman2021' or 'levine2020'. + :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`. + """ + super().__init__() + + self.ablation_size = ablation_size + self.channels_first = channels_first + self.to_reshape = to_reshape + self.add_ablation_mask = False + self.additional_channels = False + self.algorithm = algorithm + self.original_shape = original_shape + self.ablation_mode = ablation_mode + + if self.algorithm == "levine2020": + self.additional_channels = True + if self.algorithm == "salman2021" and mode == "ViT": + self.add_ablation_mask = True + + if device_type == "cpu" or not torch.cuda.is_available(): + self.device = torch.device("cpu") + else: # pragma: no cover + cuda_idx = torch.cuda.current_device() + self.device = torch.device(f"cuda:{cuda_idx}") + + if original_shape is not None and output_shape is not None: + self.upsample = UpSamplerPyTorch(input_size=original_shape[1], final_size=output_shape[1]) + + def ablate( + self, x: Union[torch.Tensor, np.ndarray], column_pos: int, row_pos: Optional[int] = None + ) -> torch.Tensor: + """ + Ablates the input column wise + + :param x: Input data + :param column_pos: location to start the retained column. NB, if ablation_mode is "row" then this will + be used to act on the rows through transposing the image in ColumnAblatorPyTorch.forward + :param row_pos: Unused. + :return: The ablated input with 0s where the ablation occurred + """ + k = self.ablation_size + + if isinstance(x, np.ndarray): + x = torch.from_numpy(x).to(self.device) + + if column_pos + k > x.shape[-1]: + x[:, :, :, (column_pos + k) % x.shape[-1] : column_pos] = 0.0 + else: + x[:, :, :, :column_pos] = 0.0 + x[:, :, :, column_pos + k :] = 0.0 + return x + + def forward( + self, x: Union[torch.Tensor, np.ndarray], column_pos: Optional[int] = None, row_pos=None + ) -> torch.Tensor: + """ + Forward pass through the ablator.
We insert a new channel to keep track of the ablation location. + + :param x: Input data + :param column_pos: The start position of the ablation + :param row_pos: Unused. + :return: The ablated input with an extra channel indicating the location of the ablation + """ + if row_pos is not None: + raise ValueError("Use column_pos for a ColumnAblator. The row_pos argument is unused") + + if self.original_shape is not None and x.shape[1] != self.original_shape[0] and self.algorithm == "salman2021": + raise ValueError(f"Ablator expected {self.original_shape[0]} input channels. Received shape of {x.shape[1]}") + + if isinstance(x, np.ndarray): + x = torch.from_numpy(x).to(self.device) + + if self.add_ablation_mask: + ones = torch.ones_like(x[:, 0:1, :, :]).to(self.device) + x = torch.cat([x, ones], dim=1) + + if self.additional_channels: + x = torch.cat([x, 1.0 - x], dim=1) + + if self.original_shape is not None and x.shape[1] != self.original_shape[0] and self.additional_channels: + raise ValueError( + f"Ablator expected {self.original_shape[0]} input channels. Received shape of {x.shape[1]}" + ) + + if self.ablation_mode == "row": + x = torch.transpose(x, 3, 2) + + if column_pos is None: + column_pos = random.randint(0, x.shape[3]) + + ablated_x = self.ablate(x, column_pos=column_pos) + + if self.ablation_mode == "row": + ablated_x = torch.transpose(ablated_x, 3, 2) + + if self.to_reshape: + ablated_x = self.upsample(ablated_x) + return ablated_x + + def certify( + self, + pred_counts: Union[torch.Tensor, np.ndarray], + size_to_certify: int, + label: Union[torch.Tensor, np.ndarray], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Performs certification of the predictions + + :param pred_counts: The model predictions over the ablated data. + :param size_to_certify: The patch size we wish to check certification against + :param label: The ground truth labels + :return: A tuple consisting of: the certified predictions, + the predictions which were certified and also correct, + and the most predicted class across the different ablations on the input. + """ + + if isinstance(pred_counts, np.ndarray): + pred_counts = torch.from_numpy(pred_counts).to(self.device) + + if isinstance(label, np.ndarray): + label = torch.from_numpy(label).to(self.device) + + num_of_classes = pred_counts.shape[-1] + + # NB! argmax and kthvalue handle ties between predicted counts differently. + # The original implementation: https://github.com/MadryLab/smoothed-vit/blob/main/src/utils/smoothing.py#L98 + # uses argmax for the model predictions + # (later called y_smoothed https://github.com/MadryLab/smoothed-vit/blob/main/src/utils/smoothing.py#L230) + # and kthvalue for the certified predictions. + # To be consistent with the original implementation we also follow this here.
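+ # Intuition for the margin check below: a patch of width size_to_certify can overlap at most + # (size_to_certify + self.ablation_size - 1) of the column ablation positions, so in the worst case it removes + # that many counts from the top class and adds that many to a runner-up; a margin greater than twice this + # value therefore certifies the prediction.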
+ top_predicted_class_argmax = torch.argmax(pred_counts, dim=1) + + top_class_counts, top_predicted_class = pred_counts.kthvalue(num_of_classes, dim=1) + second_class_counts, second_predicted_class = pred_counts.kthvalue(num_of_classes - 1, dim=1) + + cert = (top_class_counts - second_class_counts) > 2 * (size_to_certify + self.ablation_size - 1) + + if self.algorithm == "levine2020": + tie_break_certs = ( + (top_class_counts - second_class_counts) == 2 * (size_to_certify + self.ablation_size - 1) + ) & (top_predicted_class < second_predicted_class) + cert = torch.logical_or(cert, tie_break_certs) + + cert_and_correct = cert & (label == top_predicted_class) + + return cert, cert_and_correct, top_predicted_class_argmax + + +class BlockAblatorPyTorch(torch.nn.Module, BaseAblator): + """ + Pure PyTorch implementation of block ablation. + """ + + def __init__( + self, + ablation_size: int, + channels_first: bool, + mode: str, + to_reshape: bool, + original_shape: Optional[Tuple] = None, + output_shape: Optional[Tuple] = None, + algorithm: str = "salman2021", + device_type: str = "gpu", + ): + """ + Creates a block ablator + + :param ablation_size: The size of the block we will retain. + :param channels_first: If the input is in channels first format. Currently required to be True. + :param mode: If we are running the algorithm using a CNN or ViT. + :param to_reshape: If the input requires reshaping. + :param original_shape: Original shape of the input. + :param output_shape: Input shape expected by the ViT. Usually means upscaling the input to 224 x 224. + :param algorithm: Either 'salman2021' or 'levine2020'. + :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`. + """ + super().__init__() + + self.ablation_size = ablation_size + self.channels_first = channels_first + self.to_reshape = to_reshape + self.add_ablation_mask = False + self.additional_channels = False + self.algorithm = algorithm + self.original_shape = original_shape + + if self.algorithm == "levine2020": + self.additional_channels = True + if self.algorithm == "salman2021" and mode == "ViT": + self.add_ablation_mask = True + + if device_type == "cpu" or not torch.cuda.is_available(): + self.device = torch.device("cpu") + else: # pragma: no cover + cuda_idx = torch.cuda.current_device() + self.device = torch.device(f"cuda:{cuda_idx}") + + if original_shape is not None and output_shape is not None: + self.upsample = UpSamplerPyTorch(input_size=original_shape[1], final_size=output_shape[1]) + + def ablate(self, x: Union[torch.Tensor, np.ndarray], column_pos: int, row_pos: int) -> torch.Tensor: + """ + Ablates the input block wise + + :param x: Input data + :param column_pos: The start position of the ablation + :param row_pos: The row start position of the ablation + :return: The ablated input with 0s where the ablation occurred + """ + + if isinstance(x, np.ndarray): + x = torch.from_numpy(x).to(self.device) + + k = self.ablation_size + # Column ablations + if column_pos + k > x.shape[-1]: + x[:, :, :, (column_pos + k) % x.shape[-1] : column_pos] = 0.0 + else: + x[:, :, :, :column_pos] = 0.0 + x[:, :, :, column_pos + k :] = 0.0 + + # Row ablations + if row_pos + k > x.shape[-2]: + x[:, :, (row_pos + k) % x.shape[-2] : row_pos, :] = 0.0 + else: + x[:, :, :row_pos, :] = 0.0 + x[:, :, row_pos + k :, :] = 0.0 + return x + + def forward( + self, x: Union[torch.Tensor, np.ndarray], column_pos: Optional[int] = None, row_pos: Optional[int] = None + ) -> torch.Tensor: + """ + Forward pass through the
ablator. We insert a new channel to keep track of the ablation location. + + :param x: Input data + :param column_pos: The start position of the ablation + :param row_pos: The row start position of the ablation + :return: The ablated input with an extra channel indicating the location of the ablation + """ + if self.original_shape is not None and x.shape[1] != self.original_shape[0] and self.algorithm == "salman2021": + raise ValueError(f"Ablator expected {self.original_shape[0]} input channels. Received shape of {x.shape[1]}") + + if column_pos is None: + column_pos = random.randint(0, x.shape[3]) + + if row_pos is None: + row_pos = random.randint(0, x.shape[2]) + + if isinstance(x, np.ndarray): + x = torch.from_numpy(x).to(self.device) + + if self.add_ablation_mask: + ones = torch.ones_like(x[:, 0:1, :, :]).to(self.device) + x = torch.cat([x, ones], dim=1) + + if self.additional_channels: + x = torch.cat([x, 1.0 - x], dim=1) + + if self.original_shape is not None and x.shape[1] != self.original_shape[0] and self.additional_channels: + raise ValueError(f"Ablator expected {self.original_shape[0]} input channels. Received shape of {x.shape[1]}") + + ablated_x = self.ablate(x, column_pos=column_pos, row_pos=row_pos) + + if self.to_reshape: + ablated_x = self.upsample(ablated_x) + return ablated_x + + def certify( + self, + pred_counts: Union[torch.Tensor, np.ndarray], + size_to_certify: int, + label: Union[torch.Tensor, np.ndarray], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Performs certification of the predictions + + :param pred_counts: The model predictions over the ablated data. + :param size_to_certify: The patch size we wish to check certification against + :param label: The ground truth labels + :return: A tuple consisting of: the certified predictions, + the predictions which were certified and also correct, + and the most predicted class across the different ablations on the input. + """ + + if isinstance(pred_counts, np.ndarray): + pred_counts = torch.from_numpy(pred_counts).to(self.device) + + if isinstance(label, np.ndarray): + label = torch.from_numpy(label).to(self.device) + + # NB! argmax and kthvalue handle ties between predicted counts differently. + # The original implementation: https://github.com/MadryLab/smoothed-vit/blob/main/src/utils/smoothing.py#L145 + # uses argmax for the model predictions + # (later called y_smoothed https://github.com/MadryLab/smoothed-vit/blob/main/src/utils/smoothing.py#L230) + # and kthvalue for the certified predictions. + # To be consistent with the original implementation we also follow this here.
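+ # Intuition for the margin check below: with block ablation a patch of side size_to_certify overlaps the + # retained block for at most (size_to_certify + self.ablation_size - 1) ** 2 of the possible (row, column) + # start positions, hence the squared threshold used in the certification condition.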
+ top_predicted_class_argmax = torch.argmax(pred_counts, dim=1) + + num_of_classes = pred_counts.shape[-1] + + top_class_counts, top_predicted_class = pred_counts.kthvalue(num_of_classes, dim=1) + second_class_counts, second_predicted_class = pred_counts.kthvalue(num_of_classes - 1, dim=1) + + cert = (top_class_counts - second_class_counts) > 2 * (size_to_certify + self.ablation_size - 1) ** 2 + + cert_and_correct = cert & (label == top_predicted_class) + + if self.algorithm == "levine2020": + tie_break_certs = ( + (top_class_counts - second_class_counts) == 2 * (size_to_certify + self.ablation_size - 1) ** 2 + ) & (top_predicted_class < second_predicted_class) + cert = torch.logical_or(cert, tie_break_certs) + return cert, cert_and_correct, top_predicted_class_argmax diff --git a/art/estimators/certification/derandomized_smoothing/derandomized_smoothing.py b/art/estimators/certification/derandomized_smoothing/ablators/tensorflow.py similarity index 51% rename from art/estimators/certification/derandomized_smoothing/derandomized_smoothing.py rename to art/estimators/certification/derandomized_smoothing/ablators/tensorflow.py index 42a31ca418..e4b927358e 100644 --- a/art/estimators/certification/derandomized_smoothing/derandomized_smoothing.py +++ b/art/estimators/certification/derandomized_smoothing/ablators/tensorflow.py @@ -23,176 +23,16 @@ from __future__ import absolute_import, division, print_function, unicode_literals -from abc import ABC, abstractmethod -from typing import Optional, Union, TYPE_CHECKING +from typing import Optional, Union, Tuple, TYPE_CHECKING import random import numpy as np -if TYPE_CHECKING: - from art.utils import ABLATOR_TYPE - - -class DeRandomizedSmoothingMixin(ABC): - """ - Implementation of (De)Randomized Smoothing applied to classifier predictions as introduced - in Levine et al. (2020). - - | Paper link: https://arxiv.org/abs/2002.10733 - """ - - def __init__( - self, - ablation_type: str, - ablation_size: int, - threshold: float, - logits: bool, - channels_first: bool, - *args, - **kwargs, - ) -> None: - """ - Create a derandomized smoothing wrapper. - - :param ablation_type: The type of ablations to perform. Currently must be either "column", "row", or "block" - :param ablation_size: Size of the retained image patch. - An int specifying the width of the column for column ablation - Or an int specifying the height/width of a square for block ablation - :param threshold: The minimum threshold to count a prediction. - :param logits: if the model returns logits or normalized probabilities - :param channels_first: If the channels are first or last. - """ - super().__init__(*args, **kwargs) # type: ignore - self.ablation_type = ablation_type - self.logits = logits - self.threshold = threshold - self._channels_first = channels_first - if TYPE_CHECKING: - self.ablator: ABLATOR_TYPE # pylint: disable=used-before-assignment - - if self.ablation_type in {"column", "row"}: - row_ablation_mode = self.ablation_type == "row" - self.ablator = ColumnAblator( - ablation_size=ablation_size, channels_first=self._channels_first, row_ablation_mode=row_ablation_mode - ) - elif self.ablation_type == "block": - self.ablator = BlockAblator(ablation_size=ablation_size, channels_first=self._channels_first) - else: - raise ValueError("Ablation type not supported. Must be either column or block") - - def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray: - """ - Perform prediction for a batch of inputs. 
- - :param x: Input samples. - :param batch_size: Size of batches. - :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode. - :return: Array of predictions of shape `(nb_inputs, nb_classes)`. - """ - raise NotImplementedError - - def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs) -> np.ndarray: - """ - Performs cumulative predictions over every ablation location - - :param x: Unablated image - :param batch_size: the batch size for the prediction - :param training_mode: if to run the classifier in training mode - :return: cumulative predictions after sweeping over all the ablation configurations. - """ - if self._channels_first: - columns_in_data = x.shape[-1] - rows_in_data = x.shape[-2] - else: - columns_in_data = x.shape[-2] - rows_in_data = x.shape[-3] - - if self.ablation_type in {"column", "row"}: - if self.ablation_type == "column": - ablate_over_range = columns_in_data - else: - # image will be transposed, so loop over the number of rows - ablate_over_range = rows_in_data - - for ablation_start in range(ablate_over_range): - ablated_x = self.ablator.forward(np.copy(x), column_pos=ablation_start) - if ablation_start == 0: - preds = self._predict_classifier( - ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs - ) - else: - preds += self._predict_classifier( - ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs - ) - elif self.ablation_type == "block": - for xcorner in range(rows_in_data): - for ycorner in range(columns_in_data): - ablated_x = self.ablator.forward(np.copy(x), row_pos=xcorner, column_pos=ycorner) - if ycorner == 0 and xcorner == 0: - preds = self._predict_classifier( - ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs - ) - else: - preds += self._predict_classifier( - ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs - ) - return preds - - -class BaseAblator(ABC): - """ - Base class defining the methods used for the ablators. - """ - - @abstractmethod - def __call__( - self, x: np.ndarray, column_pos: Optional[Union[int, list]] = None, row_pos: Optional[Union[int, list]] = None - ) -> np.ndarray: - """ - Ablate the image x at location specified by "column_pos" for the case of column ablation or at the location - specified by "column_pos" and "row_pos" in the case of block ablation. - - :param x: input image. - :param column_pos: column position to specify where to retain the image - :param row_pos: row position to specify where to retain the image. Not used for ablation type "column". - """ - raise NotImplementedError - - @abstractmethod - def certify(self, preds: np.ndarray, size_to_certify: int): - """ - Checks if based on the predictions supplied the classifications over the ablated datapoints result in a - certified prediction against a patch attack of size size_to_certify. - - :param preds: The cumulative predictions of the classifier over the ablation locations. - :param size_to_certify: The size of the patch to check against. - """ - raise NotImplementedError - - @abstractmethod - def ablate(self, x: np.ndarray, column_pos: int, row_pos: int) -> np.ndarray: - """ - Ablate the image x at location specified by "column_pos" for the case of column ablation or at the location - specified by "column_pos" and "row_pos" in the case of block ablation. +from art.estimators.certification.derandomized_smoothing.ablators.ablate import BaseAblator - :param x: input image. 
- :param column_pos: column position to specify where to retain the image - :param row_pos: row position to specify where to retain the image. Not used for ablation type "column". - """ - raise NotImplementedError - - @abstractmethod - def forward( - self, x: np.ndarray, column_pos: Optional[Union[int, list]] = None, row_pos: Optional[Union[int, list]] = None - ) -> np.ndarray: - """ - Ablate batch of data at locations specified by column_pos and row_pos - - :param x: input image. - :param column_pos: column position to specify where to retain the image - :param row_pos: row position to specify where to retain the image. Not used for ablation type "column". - """ - raise NotImplementedError +if TYPE_CHECKING: + # pylint: disable=C0412 + import tensorflow as tf class ColumnAblator(BaseAblator): @@ -230,27 +70,50 @@ def __call__( """ return self.forward(x=x, column_pos=column_pos) - def certify(self, preds: np.ndarray, size_to_certify: int) -> np.ndarray: + def certify( + self, pred_counts: "tf.Tensor", size_to_certify: int, label: Union[np.ndarray, "tf.Tensor"] + ) -> Tuple["tf.Tensor", "tf.Tensor", "tf.Tensor"]: """ Checks if based on the predictions supplied the classifications over the ablated datapoints result in a certified prediction against a patch attack of size size_to_certify. :param preds: The cumulative predictions of the classifier over the ablation locations. :param size_to_certify: The size of the patch to check against. - :return: Array of bools indicating if a point is certified against the given patch dimensions. + :param label: Ground truth labels + :return: A tuple consisting of: the certified predictions, + the predictions which were certified and also correct, + and the most predicted class across the different ablations on the input. """ - indices = np.argsort(-preds, axis=1, kind="stable") - values = np.take_along_axis(np.copy(preds), indices, axis=1) + import tensorflow as tf - num_affected_classifications = size_to_certify + self.ablation_size - 1 + result = tf.math.top_k(pred_counts, k=2) - margin = values[:, 0] - values[:, 1] + top_predicted_class, second_predicted_class = result.indices[:, 0], result.indices[:, 1] + top_class_counts, second_class_counts = result.values[:, 0], result.values[:, 1] - certs = margin > 2 * num_affected_classifications - tie_break_certs = (margin == 2 * num_affected_classifications) & (indices[:, 0] < indices[:, 1]) - return np.logical_or(certs, tie_break_certs) + certs = (top_class_counts - second_class_counts) > 2 * (size_to_certify + self.ablation_size - 1) - def ablate(self, x: np.ndarray, column_pos: int, row_pos=None) -> np.ndarray: + tie_break_certs = ( + (top_class_counts - second_class_counts) == 2 * (size_to_certify + self.ablation_size - 1) + ) & (top_predicted_class < second_predicted_class) + cert = tf.math.logical_or(certs, tie_break_certs) + + # NB, newer versions of pylint do not require the disable. 
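+ # The label may be supplied one-hot encoded or as class indices; both cases are handled below when + # comparing against the top predicted class.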
+ if label.ndim > 1: + cert_and_correct = cert & ( + tf.math.argmax(label, axis=1) + == tf.cast( # pylint: disable=E1120, E1123 + top_predicted_class, dtype=tf.math.argmax(label, axis=1).dtype + ) + ) + else: + cert_and_correct = cert & ( + label == tf.cast(top_predicted_class, dtype=label.dtype) # pylint: disable=E1120, E1123 + ) + + return cert, cert_and_correct, top_predicted_class + + def ablate(self, x: np.ndarray, column_pos: int, row_pos: Optional[int] = None) -> np.ndarray: """ Ablates the image only retaining a column starting at "pos" of width "self.ablation_size" @@ -348,24 +211,47 @@ def __call__( """ return self.forward(x=x, row_pos=row_pos, column_pos=column_pos) - def certify(self, preds: np.ndarray, size_to_certify: int) -> np.ndarray: + def certify( + self, pred_counts: Union["tf.Tensor", np.ndarray], size_to_certify: int, label: Union[np.ndarray, "tf.Tensor"] + ) -> Tuple["tf.Tensor", "tf.Tensor", "tf.Tensor"]: """ Checks if based on the predictions supplied the classifications over the ablated datapoints result in a certified prediction against a patch attack of size size_to_certify. - :param preds: The cumulative predictions of the classifier over the ablation locations. + :param pred_counts: The cumulative predictions of the classifier over the ablation locations. :param size_to_certify: The size of the patch to check against. - :return: Array of bools indicating if a point is certified against the given patch dimensions. - """ - indices = np.argsort(-preds, axis=1, kind="stable") - values = np.take_along_axis(np.copy(preds), indices, axis=1) - margin = values[:, 0] - values[:, 1] - - num_affected_classifications = (size_to_certify + self.ablation_size - 1) ** 2 + :param label: Ground truth labels + :return: A tuple consisting of: the certified predictions, + the predictions which were certified and also correct, + and the most predicted class across the different ablations on the input. + """ + import tensorflow as tf + + result = tf.math.top_k(pred_counts, k=2) + + top_predicted_class, second_predicted_class = result.indices[:, 0], result.indices[:, 1] + top_class_counts, second_class_counts = result.values[:, 0], result.values[:, 1] + + certs = (top_class_counts - second_class_counts) > 2 * (size_to_certify + self.ablation_size - 1) ** 2 + tie_break_certs = ( + (top_class_counts - second_class_counts) == 2 * (size_to_certify + self.ablation_size - 1) ** 2 + ) & (top_predicted_class < second_predicted_class) + cert = tf.math.logical_or(certs, tie_break_certs) + + # NB, newer versions of pylint do not require the disable. + if label.ndim > 1: + cert_and_correct = cert & ( + tf.math.argmax(label, axis=1) + == tf.cast( # pylint: disable=E1120, E1123 + top_predicted_class, dtype=tf.math.argmax(label, axis=1).dtype + ) + ) + else: + cert_and_correct = cert & ( + label == tf.cast(top_predicted_class, dtype=label.dtype) # pylint: disable=E1120, E1123 + ) - certs = margin > 2 * num_affected_classifications - tie_break_certs = (margin == 2 * num_affected_classifications) & (indices[:, 0] < indices[:, 1]) - return np.logical_or(certs, tie_break_certs) + return cert, cert_and_correct, top_predicted_class def forward( self, @@ -415,40 +301,17 @@ def ablate(self, x: np.ndarray, column_pos: int, row_pos: int) -> np.ndarray: :return: Data ablated at all locations aside from the specified block. 
""" k = self.ablation_size - num_of_image_columns = x.shape[3] - num_of_image_rows = x.shape[2] - - if row_pos + k > x.shape[2] and column_pos + k > x.shape[3]: - start_of_ablation = column_pos + k - num_of_image_columns - x[:, :, :, start_of_ablation:column_pos] = 0.0 - - start_of_ablation = row_pos + k - num_of_image_rows - x[:, :, start_of_ablation:row_pos, :] = 0.0 - - # only the row wraps - elif row_pos + k > x.shape[2] and column_pos + k <= x.shape[3]: - x[:, :, :, :column_pos] = 0.0 - x[:, :, :, column_pos + k :] = 0.0 - - start_of_ablation = row_pos + k - num_of_image_rows - x[:, :, start_of_ablation:row_pos, :] = 0.0 - - # only column wraps - elif row_pos + k <= x.shape[2] and column_pos + k > x.shape[3]: - start_of_ablation = column_pos + k - num_of_image_columns - x[:, :, :, start_of_ablation:column_pos] = 0.0 - - x[:, :, :row_pos, :] = 0.0 - x[:, :, row_pos + k :, :] = 0.0 - - # neither wraps - elif row_pos + k <= x.shape[2] and column_pos + k <= x.shape[3]: + # Column ablations + if column_pos + k > x.shape[-1]: + x[:, :, :, (column_pos + k) % x.shape[-1] : column_pos] = 0.0 + else: x[:, :, :, :column_pos] = 0.0 x[:, :, :, column_pos + k :] = 0.0 + # Row ablations + if row_pos + k > x.shape[-2]: + x[:, :, (row_pos + k) % x.shape[-2] : row_pos, :] = 0.0 + else: x[:, :, :row_pos, :] = 0.0 x[:, :, row_pos + k :, :] = 0.0 - else: - raise ValueError(f"Ablation failed on row: {row_pos} and column: {column_pos} with size {k}") - return x diff --git a/art/estimators/certification/derandomized_smoothing/derandomized.py b/art/estimators/certification/derandomized_smoothing/derandomized.py new file mode 100644 index 0000000000..9e2ee6ca0d --- /dev/null +++ b/art/estimators/certification/derandomized_smoothing/derandomized.py @@ -0,0 +1,69 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This module implements (De)Randomized Smoothing certifications against adversarial patches. + +| Paper link: https://arxiv.org/abs/2110.07719 + +| Paper link: https://arxiv.org/abs/2002.10733 +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +from abc import ABC, abstractmethod +import numpy as np + + +class DeRandomizedSmoothingMixin(ABC): + """ + Mixin class for smoothed estimators. + """ + + def __init__( + self, + *args, + **kwargs, + ) -> None: + """ + Create a derandomized smoothing wrapper. 
+ """ + super().__init__(*args, **kwargs) # type: ignore + + @abstractmethod + def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray: + """ + Perform prediction for a batch of inputs. + + :param x: Input samples. + :param batch_size: Size of batches. + :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode. + :return: Array of predictions of shape `(nb_inputs, nb_classes)`. + """ + raise NotImplementedError + + @abstractmethod + def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs) -> np.ndarray: + """ + Performs cumulative predictions over every ablation location + + :param x: Unablated image + :param batch_size: the batch size for the prediction + :param training_mode: if to run the classifier in training mode + :return: cumulative predictions after sweeping over all the ablation configurations. + """ + raise NotImplementedError diff --git a/art/estimators/certification/derandomized_smoothing/pytorch.py b/art/estimators/certification/derandomized_smoothing/pytorch.py index 4a184b3666..cd3e53243b 100644 --- a/art/estimators/certification/derandomized_smoothing/pytorch.py +++ b/art/estimators/certification/derandomized_smoothing/pytorch.py @@ -16,13 +16,24 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """ -This module implements (De)Randomized Smoothing for Certifiable Defense against Patch Attacks +This module implements De-Randomized smoothing approaches PyTorch. + +(De)Randomized Smoothing for Certifiable Defense against Patch Attacks | Paper link: https://arxiv.org/abs/2002.10733 + +and + +Certified Patch Robustness via Smoothed Vision Transformers + +| Paper link Accepted version: + https://openaccess.thecvf.com/content/CVPR2022/papers/Salman_Certified_Patch_Robustness_via_Smoothed_Vision_Transformers_CVPR_2022_paper.pdf + +| Paper link Arxiv version (more detail): https://arxiv.org/pdf/2110.07719.pdf """ from __future__ import absolute_import, division, print_function, unicode_literals - +import importlib import logging from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING import random @@ -30,15 +41,16 @@ import numpy as np from tqdm import tqdm -from art.config import ART_NUMPY_DTYPE from art.estimators.classification.pytorch import PyTorchClassifier -from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin +from art.estimators.certification.derandomized_smoothing.derandomized import DeRandomizedSmoothingMixin from art.utils import check_and_transform_label_format if TYPE_CHECKING: # pylint: disable=C0412 import torch - + import torchvision + from timm.models.vision_transformer import VisionTransformer + from art.estimators.certification.derandomized_smoothing.vision_transformers.pytorch import PyTorchVisionTransformer from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE from art.defences.preprocessor import Preprocessor from art.defences.postprocessor import Postprocessor @@ -48,47 +60,64 @@ class PyTorchDeRandomizedSmoothing(DeRandomizedSmoothingMixin, PyTorchClassifier): """ - Implementation of (De)Randomized Smoothing applied to classifier predictions as introduced - in Levine et al. (2020). + Interface class for the two De-randomized smoothing approaches supported by ART for pytorch. 
- | Paper link: https://arxiv.org/abs/2002.10733 - """ + If a regular PyTorch neural network is fed in then (De)Randomized Smoothing as introduced in Levine et al. (2020) + is used. - estimator_params = PyTorchClassifier.estimator_params + ["ablation_type", "ablation_size", "threshold", "logits"] + Otherwise, if a timm vision transformer is fed in then Certified Patch Robustness via Smoothed Vision Transformers + as introduced in Salman et al. (2021) is used. + """ def __init__( self, - model: "torch.nn.Module", + model: Union[str, "VisionTransformer", "torch.nn.Module"], loss: "torch.nn.modules.loss._Loss", input_shape: Tuple[int, ...], nb_classes: int, - ablation_type: str, ablation_size: int, - threshold: float, - logits: bool, - optimizer: Optional["torch.optim.Optimizer"] = None, # type: ignore + algorithm: str = "salman2021", + ablation_type: str = "column", + replace_last_layer: Optional[bool] = None, + drop_tokens: bool = True, + load_pretrained: bool = True, + optimizer: Union[type, "torch.optim.Optimizer", None] = None, + optimizer_params: Optional[dict] = None, channels_first: bool = True, + threshold: Optional[float] = None, + logits: Optional[bool] = True, clip_values: Optional["CLIP_VALUES_TYPE"] = None, preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None, postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None, preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0), device_type: str = "gpu", + verbose: bool = True, + **kwargs, ): """ - Create a derandomized smoothing classifier. + Create a smoothed classifier. - :param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits - output should be preferred where possible to ensure attack efficiency. + :param model: Either a CNN or a ViT. For a ViT supply a string specifying which ViT architecture to load from + the timm library, or a vision transformer already created with the + PyTorch Image Models (timm) library. To run Levine et al. (2020) provide a regular PyTorch model. :param loss: The loss function for which to compute gradients for training. The target label must be raw - categorical, i.e. not converted to one-hot encoding. + categorical, i.e. not converted to one-hot encoding. :param input_shape: The shape of one input instance. :param nb_classes: The number of classes of the model. - :param ablation_type: The type of ablation to perform, must be either "column" or "block" - :param ablation_size: The size of the data portion to retain after ablation. Will be a column of size N for - "column" ablation type or a NxN square for ablation of type "block" - :param threshold: The minimum threshold to count a prediction. - :param logits: if the model returns logits or normalized probabilities + :param ablation_size: The size of the data portion to retain after ablation. + :param algorithm: Either 'salman2021' or 'levine2020'. For salman2021 we support ViTs and CNNs. For levine2020 + there is only CNN support. + :param replace_last_layer: ViT Specific. If to replace the last layer of the ViT with a fresh layer + matching the number of classes for the dataset to be examined. + Needed if going from the pre-trained ImageNet models to fine-tune + on a dataset like CIFAR. + :param drop_tokens: ViT Specific. If to drop the fully ablated tokens in the ViT + :param load_pretrained: ViT Specific. If to load a pretrained model matching the ViT name. + Will only affect the ViT if a string name is passed to model rather than a ViT directly.
:param optimizer: The optimizer used to train the classifier. + :param ablation_type: The type of ablation to perform. Either "column", "row", or "block" + :param threshold: Specific to Levine et al. The minimum threshold to count a prediction. + :param logits: Specific to Levine et al. If the model returns logits or normalized probabilities :param channels_first: Set channels first or last. :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and maximum values allowed for features. If floats are provided, these will be used as the range of all @@ -101,52 +130,304 @@ def __init__( be divided by the second one. :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`. """ - super().__init__( - model=model, - loss=loss, - input_shape=input_shape, - nb_classes=nb_classes, - optimizer=optimizer, - channels_first=channels_first, - clip_values=clip_values, - preprocessing_defences=preprocessing_defences, - postprocessing_defences=postprocessing_defences, - preprocessing=preprocessing, - device_type=device_type, - ablation_type=ablation_type, - ablation_size=ablation_size, - threshold=threshold, - logits=logits, - ) - def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray: import torch - x = x.astype(ART_NUMPY_DTYPE) - outputs = PyTorchClassifier.predict(self, x=x, batch_size=batch_size, training_mode=training_mode, **kwargs) + if not channels_first: + raise ValueError("Channels must be set to first") + logger.info("Running algorithm: %s", algorithm) + + # Default value for output shape + output_shape = input_shape + self.mode = None + if importlib.util.find_spec("timm") is not None and algorithm == "salman2021": + from timm.models.vision_transformer import VisionTransformer + + if isinstance(model, (VisionTransformer, str)): + import timm + from art.estimators.certification.derandomized_smoothing.vision_transformers.pytorch import ( + PyTorchVisionTransformer, + ) + + if replace_last_layer is None: + raise ValueError("If using ViTs please specify if the last layer should be replaced") + + # temporarily assign the original method to tmp_func + tmp_func = timm.models.vision_transformer._create_vision_transformer + + # overrride with ART's ViT creation function + timm.models.vision_transformer._create_vision_transformer = self.create_vision_transformer + if isinstance(model, str): + model = timm.create_model( + model, pretrained=load_pretrained, drop_tokens=drop_tokens, device_type=device_type + ) + if replace_last_layer: + model.head = torch.nn.Linear(model.head.in_features, nb_classes) + if isinstance(optimizer, type): + if optimizer_params is not None: + optimizer = optimizer(model.parameters(), **optimizer_params) + else: + raise ValueError("If providing an optimiser please also supply its parameters") + + elif isinstance(model, VisionTransformer): + pretrained_cfg = model.pretrained_cfg + supplied_state_dict = model.state_dict() + supported_models = self.get_models() + if pretrained_cfg["architecture"] not in supported_models: + raise ValueError( + "Architecture not supported. Use PyTorchDeRandomizedSmoothing.get_models() " + "to get the supported model architectures." 
+ ) + model = timm.create_model( + pretrained_cfg["architecture"], drop_tokens=drop_tokens, device_type=device_type + ) + model.load_state_dict(supplied_state_dict) + if replace_last_layer: + model.head = torch.nn.Linear(model.head.in_features, nb_classes) + + if optimizer is not None: + if not isinstance(optimizer, torch.optim.Optimizer): + raise ValueError("Optimizer error: must be a torch.optim.Optimizer instance") + + converted_optimizer: Union[torch.optim.Adam, torch.optim.SGD] + opt_state_dict = optimizer.state_dict() + if isinstance(optimizer, torch.optim.Adam): + logging.info("Converting Adam Optimiser") + converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + elif isinstance(optimizer, torch.optim.SGD): + logging.info("Converting SGD Optimiser") + converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + else: + raise ValueError("Optimiser not supported for conversion") + converted_optimizer.load_state_dict(opt_state_dict) + + self.to_reshape = False + if not isinstance(model, PyTorchVisionTransformer): + raise ValueError("Vision transformer is not of PyTorchViT. Error occurred in PyTorchViT creation.") + + if model.default_cfg["input_size"][0] != input_shape[0]: + raise ValueError( + f'ViT requires {model.default_cfg["input_size"][0]} channel input,' + f" but {input_shape[0]} channels were provided." + ) + + if model.default_cfg["input_size"] != input_shape: + if verbose: + logger.warning( + " ViT expects input shape of: (%i, %i, %i) but (%i, %i, %i) specified as the input shape." + " The input will be rescaled to (%i, %i, %i)", + *model.default_cfg["input_size"], + *input_shape, + *model.default_cfg["input_size"], + ) - if not self.logits: - return np.asarray((outputs >= self.threshold)) - return np.asarray( - (torch.nn.functional.softmax(torch.from_numpy(outputs), dim=1) >= self.threshold).type(torch.int) + self.to_reshape = True + output_shape = model.default_cfg["input_size"] + + # set the method back to avoid unexpected side effects later on should timm need to be reused. + timm.models.vision_transformer._create_vision_transformer = tmp_func + self.mode = "ViT" + else: + if isinstance(model, torch.nn.Module): + self.mode = "CNN" + output_shape = input_shape + self.to_reshape = False + + elif algorithm == "levine2020": + if ablation_type is None or threshold is None or logits is None: + raise ValueError( + "If using CNN please specify if the model returns logits, " + " the prediction threshold, and ablation type" + ) + self.mode = "CNN" + # input channels are internally doubled. 
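+ # The levine2020 ablator concatenates (1 - x) to the input so that ablated (zeroed) regions remain + # distinguishable from genuinely black pixels, hence the estimator is built with twice the original + # number of channels.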
+ input_shape = (input_shape[0] * 2, input_shape[1], input_shape[2]) + output_shape = input_shape + self.to_reshape = False + + if optimizer is None or isinstance(optimizer, torch.optim.Optimizer): + super().__init__( + model=model, + loss=loss, + input_shape=input_shape, + nb_classes=nb_classes, + optimizer=optimizer, + channels_first=channels_first, + clip_values=clip_values, + preprocessing_defences=preprocessing_defences, + postprocessing_defences=postprocessing_defences, + preprocessing=preprocessing, + device_type=device_type, + ) + else: + raise ValueError("Error occurred in optimizer creation") + + self.threshold = threshold + self.logits = logits + self.ablation_size = (ablation_size,) + self.algorithm = algorithm + self.ablation_type = ablation_type + if verbose: + logger.info(self.model) + + from art.estimators.certification.derandomized_smoothing.ablators.pytorch import ( + ColumnAblatorPyTorch, + BlockAblatorPyTorch, ) - def predict( - self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs - ) -> np.ndarray: # type: ignore + if TYPE_CHECKING: + self.ablator: Union[ColumnAblatorPyTorch, BlockAblatorPyTorch] + + if self.mode is None: + raise ValueError("Model type not recognized.") + + if ablation_type in {"column", "row"}: + self.ablator = ColumnAblatorPyTorch( + ablation_size=ablation_size, + channels_first=True, + ablation_mode=ablation_type, + to_reshape=self.to_reshape, + original_shape=input_shape, + output_shape=output_shape, + device_type=device_type, + algorithm=algorithm, + mode=self.mode, + ) + elif ablation_type == "block": + self.ablator = BlockAblatorPyTorch( + ablation_size=ablation_size, + channels_first=True, + to_reshape=self.to_reshape, + original_shape=input_shape, + output_shape=output_shape, + device_type=device_type, + algorithm=algorithm, + mode=self.mode, + ) + else: + raise ValueError(f"ablation_type of {ablation_type} not recognized. Must be either column, row, or block") + + @classmethod + def get_models(cls, generate_from_null: bool = False) -> List[str]: """ - Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations. + Return the supported model names to the user. - :param x: Input samples. - :param batch_size: Batch size. - :param training_mode: if to run the classifier in training mode - :return: Array of predictions of shape `(nb_inputs, nb_classes)`. + :param generate_from_null: If to re-check the creation of all the ViTs in timm from scratch. 
+ :return: A list of compatible models """ - return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=training_mode, **kwargs) + import timm + import torch - def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None: - x = x.astype(ART_NUMPY_DTYPE) - return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs) + supported_models = [ + "vit_base_patch8_224", + "vit_base_patch16_18x2_224", + "vit_base_patch16_224", + "vit_base_patch16_224_miil", + "vit_base_patch16_384", + "vit_base_patch16_clip_224", + "vit_base_patch16_clip_384", + "vit_base_patch16_gap_224", + "vit_base_patch16_plus_240", + "vit_base_patch16_rpn_224", + "vit_base_patch16_xp_224", + "vit_base_patch32_224", + "vit_base_patch32_384", + "vit_base_patch32_clip_224", + "vit_base_patch32_clip_384", + "vit_base_patch32_clip_448", + "vit_base_patch32_plus_256", + "vit_giant_patch14_224", + "vit_giant_patch14_clip_224", + "vit_gigantic_patch14_224", + "vit_gigantic_patch14_clip_224", + "vit_huge_patch14_224", + "vit_huge_patch14_clip_224", + "vit_huge_patch14_clip_336", + "vit_huge_patch14_xp_224", + "vit_large_patch14_224", + "vit_large_patch14_clip_224", + "vit_large_patch14_clip_336", + "vit_large_patch14_xp_224", + "vit_large_patch16_224", + "vit_large_patch16_384", + "vit_large_patch32_224", + "vit_large_patch32_384", + "vit_medium_patch16_gap_240", + "vit_medium_patch16_gap_256", + "vit_medium_patch16_gap_384", + "vit_small_patch16_18x2_224", + "vit_small_patch16_36x1_224", + "vit_small_patch16_224", + "vit_small_patch16_384", + "vit_small_patch32_224", + "vit_small_patch32_384", + "vit_tiny_patch16_224", + "vit_tiny_patch16_384", + ] + + if not generate_from_null: + return supported_models + + supported = [] + unsupported = [] + + models = timm.list_models("vit_*") + pbar = tqdm(models) + + # store in case not re-assigned in the model creation due to unsuccessful creation + tmp_func = timm.models.vision_transformer._create_vision_transformer # pylint: disable=W0212 + + for model in pbar: + pbar.set_description(f"Testing {model} creation") + try: + _ = cls( + model=model, + loss=torch.nn.CrossEntropyLoss(), + optimizer=torch.optim.SGD, + optimizer_params={"lr": 0.01}, + input_shape=(3, 32, 32), + nb_classes=10, + ablation_size=4, + load_pretrained=False, + replace_last_layer=True, + verbose=False, + ) + supported.append(model) + except (TypeError, AttributeError): + unsupported.append(model) + timm.models.vision_transformer._create_vision_transformer = tmp_func # pylint: disable=W0212 + + if supported != supported_models: + logger.warning( + "Difference between the generated and fixed model list. Although not necessarily " + "an error, this may point to the timm library being updated." 
+ ) + + return supported + + @staticmethod + def create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> "PyTorchVisionTransformer": + """ + Creates a vision transformer using PyTorchViT which controls the forward pass of the model + + :param variant: The name of the vision transformer to load + :param pretrained: If to load pre-trained weights + :return: A ViT with the required methods needed for ART + """ + + from timm.models._builder import build_model_with_cfg + from timm.models.vision_transformer import checkpoint_filter_fn + from art.estimators.certification.derandomized_smoothing.vision_transformers.pytorch import ( + PyTorchVisionTransformer, + ) + + return build_model_with_cfg( + PyTorchVisionTransformer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs, + ) def fit( # pylint: disable=W0221 self, @@ -157,10 +438,15 @@ def fit( # pylint: disable=W0221 training_mode: bool = True, drop_last: bool = False, scheduler: Optional[Any] = None, + update_batchnorm: bool = True, + batchnorm_update_epochs: int = 1, + transform: Optional["torchvision.transforms.transforms.Compose"] = None, + verbose: bool = True, **kwargs, ) -> None: """ Fit the classifier on the training set `(x, y)`. + :param x: Training data. :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of shape (nb_samples,). @@ -171,6 +457,13 @@ def fit( # pylint: disable=W0221 the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: ``False``) :param scheduler: Learning rate scheduler to run at the start of every epoch. + :param update_batchnorm: ViT specific argument. + If to run the training data through the model to update any batch norm statistics prior + to training. Useful on small datasets when using pre-trained ViTs. + :param batchnorm_update_epochs: ViT specific argument. How many times to forward pass over the training data + to pre-adjust the batchnorm statistics. + :param transform: ViT specific argument. Torchvision compose of relevant augmentation transformations to apply. + :param verbose: if to display training progress bars :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch and providing it takes no effect. 
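+
+        A minimal usage sketch (illustrative only; the ``art_model`` name and the argument values are assumed, not prescribed)::
+
+            # art_model is a PyTorchDeRandomizedSmoothing instance wrapping a ViT
+            art_model.fit(x_train, y_train, nb_epochs=30, update_batchnorm=True, verbose=True)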
""" @@ -187,14 +480,14 @@ def fit( # pylint: disable=W0221 # Apply preprocessing x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) + if update_batchnorm and self.mode == "ViT": # VIT specific + self.update_batchnorm(x_preprocessed, batch_size, nb_epochs=batchnorm_update_epochs) + # Check label shape y_preprocessed = self.reduce_labels(y_preprocessed) num_batch = len(x_preprocessed) / float(batch_size) - if drop_last: - num_batch = int(np.floor(num_batch)) - else: - num_batch = int(np.ceil(num_batch)) + num_batch = int(np.floor(num_batch)) if drop_last else int(np.ceil(num_batch)) ind = np.arange(len(x_preprocessed)) # Start training @@ -202,12 +495,21 @@ def fit( # pylint: disable=W0221 # Shuffle the examples random.shuffle(ind) + epoch_acc = [] + epoch_loss = [] + epoch_batch_sizes = [] + + pbar = tqdm(range(num_batch), disable=not verbose) + # Train for one epoch - for m in range(num_batch): - i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]) - i_batch = self.ablator.forward(i_batch) + for m in pbar: + i_batch = self.ablator.forward(np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])) + + if transform is not None and self.mode == "ViT": # VIT specific + i_batch = transform(i_batch) - i_batch = torch.from_numpy(i_batch).to(self._device) + if isinstance(i_batch, np.ndarray): + i_batch = torch.from_numpy(i_batch).to(self._device) o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device) # Zero the parameter gradients @@ -215,7 +517,7 @@ def fit( # pylint: disable=W0221 # Perform prediction try: - model_outputs = self._model(i_batch) + model_outputs = self.model(i_batch) except ValueError as err: if "Expected more than 1 value per channel when training" in str(err): logger.exception( @@ -224,8 +526,8 @@ def fit( # pylint: disable=W0221 ) raise err - # Form the loss function - loss = self._loss(model_outputs[-1], o_batch) + loss = self.loss(model_outputs, o_batch) + acc = self.get_accuracy(preds=model_outputs, labels=o_batch) # Do training if self._use_amp: # pragma: no cover @@ -237,7 +539,214 @@ def fit( # pylint: disable=W0221 else: loss.backward() - self._optimizer.step() + self.optimizer.step() + + epoch_acc.append(acc) + epoch_loss.append(loss.cpu().detach().numpy()) + epoch_batch_sizes.append(len(i_batch)) + + if verbose: + pbar.set_description( + f"Loss {np.average(epoch_loss, weights=epoch_batch_sizes):.3f} " + f"Acc {np.average(epoch_acc, weights=epoch_batch_sizes):.3f} " + ) if scheduler is not None: scheduler.step() + + @staticmethod + def get_accuracy(preds: Union[np.ndarray, "torch.Tensor"], labels: Union[np.ndarray, "torch.Tensor"]) -> np.ndarray: + """ + Helper function to get the accuracy during training. + + :param preds: model predictions. + :param labels: ground truth labels (not one hot). + :return: prediction accuracy. + """ + if not isinstance(preds, np.ndarray): + preds = preds.detach().cpu().numpy() + + if not isinstance(labels, np.ndarray): + labels = labels.detach().cpu().numpy() + + return np.sum(np.argmax(preds, axis=1) == labels) / len(labels) + + def update_batchnorm(self, x: np.ndarray, batch_size: int, nb_epochs: int = 1) -> None: + """ + Method to update the batchnorm of a neural network on small datasets when it was pre-trained + + :param x: Training data. + :param batch_size: Size of batches. 
+ :param nb_epochs: How many times to forward pass over the input data + """ + import torch + + if self.mode != "ViT": + raise ValueError("Accessing a ViT specific functionality while running in CNN mode") + + self.model.train() + + ind = np.arange(len(x)) + num_batch = int(len(x) / float(batch_size)) + + with torch.no_grad(): + for _ in tqdm(range(nb_epochs)): + for m in tqdm(range(num_batch)): + i_batch = self.ablator.forward( + np.copy(x[ind[m * batch_size : (m + 1) * batch_size]]), column_pos=random.randint(0, x.shape[3]) + ) + _ = self.model(i_batch) + + def eval_and_certify( + self, + x: np.ndarray, + y: np.ndarray, + size_to_certify: int, + batch_size: int = 128, + verbose: bool = True, + ) -> Tuple["torch.Tensor", "torch.Tensor"]: + """ + Evaluates the ViT's normal and certified performance over the supplied data. + + :param x: Evaluation data. + :param y: Evaluation labels. + :param size_to_certify: The size of the patch to certify against. + If not provided will default to the ablation size. + :param batch_size: batch size when evaluating. + :param verbose: If to display the progress bar + :return: The accuracy and certified accuracy over the dataset + """ + import torch + + self.model.eval() + y = check_and_transform_label_format(y, nb_classes=self.nb_classes) + + # Apply preprocessing + x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) + + # Check label shape + y_preprocessed = self.reduce_labels(y_preprocessed) + + num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size))) + pbar = tqdm(range(num_batch), disable=not verbose) + accuracy = torch.tensor(0.0).to(self._device) + cert_sum = torch.tensor(0.0).to(self._device) + n_samples = 0 + + with torch.no_grad(): + for m in pbar: + if m == (num_batch - 1): + i_batch = np.copy(x_preprocessed[m * batch_size :]) + o_batch = y_preprocessed[m * batch_size :] + else: + i_batch = np.copy(x_preprocessed[m * batch_size : (m + 1) * batch_size]) + o_batch = y_preprocessed[m * batch_size : (m + 1) * batch_size] + + pred_counts = np.zeros((len(i_batch), self.nb_classes)) + if self.ablation_type in {"column", "row"}: + for pos in range(i_batch.shape[-1]): + ablated_batch = self.ablator.forward(i_batch, column_pos=pos) + # Perform prediction + model_outputs = self.model(ablated_batch) + + if self.algorithm == "salman2021": + pred_counts[np.arange(0, len(i_batch)), model_outputs.argmax(dim=-1).cpu()] += 1 + else: + if self.logits: + model_outputs = torch.nn.functional.softmax(model_outputs, dim=1) + model_outputs = model_outputs >= self.threshold + pred_counts += model_outputs.cpu().numpy() + + else: + for column_pos in range(i_batch.shape[-1]): + for row_pos in range(i_batch.shape[-2]): + ablated_batch = self.ablator.forward(i_batch, column_pos=column_pos, row_pos=row_pos) + model_outputs = self.model(ablated_batch) + + if self.algorithm == "salman2021": + pred_counts[np.arange(0, len(i_batch)), model_outputs.argmax(dim=-1).cpu()] += 1 + else: + if self.logits: + model_outputs = torch.nn.functional.softmax(model_outputs, dim=1) + model_outputs = model_outputs >= self.threshold + pred_counts += model_outputs.cpu().numpy() + + _, cert_and_correct, top_predicted_class = self.ablator.certify( + pred_counts, size_to_certify=size_to_certify, label=o_batch + ) + cert_sum += torch.sum(cert_and_correct) + o_batch = torch.from_numpy(o_batch).to(self.device) + accuracy += torch.sum(top_predicted_class == o_batch) + n_samples += len(cert_and_correct) + + pbar.set_description(f"Normal Acc {accuracy / n_samples:.3f} " f"Cert 
Acc {cert_sum / n_samples:.3f}") + + return (accuracy / n_samples), (cert_sum / n_samples) + + def _predict_classifier( + self, x: Union[np.ndarray, "torch.Tensor"], batch_size: int, training_mode: bool, **kwargs + ) -> np.ndarray: + import torch + + if isinstance(x, torch.Tensor): + x_numpy = x.cpu().numpy() + + outputs = PyTorchClassifier.predict( + self, x=x_numpy, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + + if self.algorithm == "levine2020": + if not self.logits: + return np.asarray((outputs >= self.threshold)) + return np.asarray( + (torch.nn.functional.softmax(torch.from_numpy(outputs), dim=1) >= self.threshold).type(torch.int) + ) + return outputs + + def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs) -> np.ndarray: + """ + Performs cumulative predictions over every ablation location + + :param x: Unablated image + :param batch_size: the batch size for the prediction + :param training_mode: if to run the classifier in training mode + :return: cumulative predictions after sweeping over all the ablation configurations. + """ + if self._channels_first: + columns_in_data = x.shape[-1] + rows_in_data = x.shape[-2] + else: + columns_in_data = x.shape[-2] + rows_in_data = x.shape[-3] + + if self.ablation_type in {"column", "row"}: + if self.ablation_type == "column": + ablate_over_range = columns_in_data + else: + # image will be transposed, so loop over the number of rows + ablate_over_range = rows_in_data + + for ablation_start in range(ablate_over_range): + ablated_x = self.ablator.forward(np.copy(x), column_pos=ablation_start) + if ablation_start == 0: + preds = self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + else: + preds += self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + elif self.ablation_type == "block": + for xcorner in range(rows_in_data): + for ycorner in range(columns_in_data): + ablated_x = self.ablator.forward(np.copy(x), row_pos=xcorner, column_pos=ycorner) + if ycorner == 0 and xcorner == 0: + preds = self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + else: + preds += self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + + return preds diff --git a/art/estimators/certification/derandomized_smoothing/tensorflow.py b/art/estimators/certification/derandomized_smoothing/tensorflow.py index 504ddefda6..6cc958acb3 100644 --- a/art/estimators/certification/derandomized_smoothing/tensorflow.py +++ b/art/estimators/certification/derandomized_smoothing/tensorflow.py @@ -28,22 +28,21 @@ import numpy as np from tqdm import tqdm +from art.estimators.certification.derandomized_smoothing.derandomized import DeRandomizedSmoothingMixin from art.estimators.classification.tensorflow import TensorFlowV2Classifier -from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin from art.utils import check_and_transform_label_format if TYPE_CHECKING: # pylint: disable=C0412 import tensorflow as tf - - from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE + from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE, ABLATOR_TYPE from art.defences.preprocessor import Preprocessor from art.defences.postprocessor import Postprocessor logger = logging.getLogger(__name__) -class TensorFlowV2DeRandomizedSmoothing(DeRandomizedSmoothingMixin, 
TensorFlowV2Classifier): +class TensorFlowV2DeRandomizedSmoothing(TensorFlowV2Classifier, DeRandomizedSmoothingMixin): """ Implementation of (De)Randomized Smoothing applied to classifier predictions as introduced in Levine et al. (2020). @@ -106,6 +105,8 @@ def __init__( used for data preprocessing. The first value will be subtracted from the input. The input will then be divided by the second one. """ + # input channels are internally doubled for the certification algorithm. + input_shape = (input_shape[0], input_shape[1], input_shape[2] * 2) super().__init__( model=model, nb_classes=nb_classes, @@ -118,12 +119,31 @@ def __init__( preprocessing_defences=preprocessing_defences, postprocessing_defences=postprocessing_defences, preprocessing=preprocessing, - ablation_type=ablation_type, - ablation_size=ablation_size, - threshold=threshold, - logits=logits, ) + self.ablation_type = ablation_type + self.logits = logits + self.threshold = threshold + self._channels_first = channels_first + + from art.estimators.certification.derandomized_smoothing.ablators.tensorflow import ( + ColumnAblator, + BlockAblator, + ) + + if TYPE_CHECKING: + self.ablator: ABLATOR_TYPE # pylint: disable=used-before-assignment + + if self.ablation_type in {"column", "row"}: + row_ablation_mode = self.ablation_type == "row" + self.ablator = ColumnAblator( + ablation_size=ablation_size, channels_first=self._channels_first, row_ablation_mode=row_ablation_mode + ) + elif self.ablation_type == "block": + self.ablator = BlockAblator(ablation_size=ablation_size, channels_first=self._channels_first) + else: + raise ValueError("Ablation type not supported. Must be either column or block") + def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray: import tensorflow as tf @@ -134,10 +154,9 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo outputs = tf.nn.softmax(outputs) return np.asarray(outputs >= self.threshold).astype(int) - def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None: - return TensorFlowV2Classifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs) - - def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None: + def fit( # pylint: disable=W0221 + self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, verbose: bool = True, **kwargs + ) -> None: """ Fit the classifier on the training set `(x, y)`. @@ -146,6 +165,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in shape (nb_samples,). :param batch_size: Size of batches. :param nb_epochs: Number of epochs to use for training. + :param verbose: if to display training progress bars :param kwargs: Dictionary of framework-specific arguments. This parameter currently only supports "scheduler" which is an optional function that will be called at the end of every epoch to adjust the learning rate. 
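+
+        A minimal sketch of such a scheduler (illustrative only; assumes a tf.keras optimizer whose
+        ``learning_rate`` is a variable and an initial rate of 0.01, and the ``classifier`` name is assumed)::
+
+            def scheduler(epoch):
+                # halve the learning rate every ten epochs
+                classifier.optimizer.learning_rate.assign(0.01 * 0.5 ** (epoch // 10))
+
+            classifier.fit(x_train, y_train, nb_epochs=30, scheduler=scheduler)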
@@ -171,6 +191,7 @@ def train_step(model, images, labels): loss = self.loss_object(labels, predictions) gradients = tape.gradient(loss, model.trainable_variables) self.optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + return loss, predictions else: train_step = self._train_step @@ -186,27 +207,137 @@ def train_step(model, images, labels): if self._reduce_labels: y_preprocessed = np.argmax(y_preprocessed, axis=1) - for epoch in tqdm(range(nb_epochs)): + for epoch in tqdm(range(nb_epochs), desc="Epochs"): num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size))) + + epoch_acc = [] + epoch_loss = [] + epoch_batch_sizes = [] + + pbar = tqdm(range(num_batch), disable=not verbose) + ind = np.arange(len(x_preprocessed)) - for m in range(num_batch): + for m in pbar: i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]) labels = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]] images = self.ablator.forward(i_batch) - train_step(self.model, images, labels) + + if self._train_step is None: + loss, predictions = train_step(self.model, images, labels) + acc = np.sum(np.argmax(predictions.numpy(), axis=1) == np.argmax(labels, axis=1)) / len(labels) + epoch_acc.append(acc) + epoch_loss.append(loss.numpy()) + epoch_batch_sizes.append(len(i_batch)) + else: + train_step(self.model, images, labels) + + if verbose: + if self._train_step is None: + pbar.set_description( + f"Loss {np.average(epoch_loss, weights=epoch_batch_sizes):.3f} " + f"Acc {np.average(epoch_acc, weights=epoch_batch_sizes):.3f} " + ) + else: + pbar.set_description("Batches") if scheduler is not None: scheduler(epoch) - def predict( - self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs - ) -> np.ndarray: # type: ignore + def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs) -> np.ndarray: """ - Perform prediction of the given classifier for a batch of inputs + Performs cumulative predictions over every ablation location - :param x: Input samples. - :param batch_size: Batch size. + :param x: Unablated image + :param batch_size: the batch size for the prediction :param training_mode: if to run the classifier in training mode - :return: Array of predictions of shape `(nb_inputs, nb_classes)`. + :return: cumulative predictions after sweeping over all the ablation configurations. 
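+
+        A minimal usage sketch (illustrative only; the ``classifier`` name is assumed)::
+
+            # per-class counts of the ablation positions whose prediction score reached the
+            # threshold, with shape (nb_inputs, nb_classes)
+            pred_counts = classifier.predict(x_test)
+            top_classes = np.argmax(pred_counts, axis=1)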
+ """ + if self._channels_first: + columns_in_data = x.shape[-1] + rows_in_data = x.shape[-2] + else: + columns_in_data = x.shape[-2] + rows_in_data = x.shape[-3] + + if self.ablation_type in {"column", "row"}: + if self.ablation_type == "column": + ablate_over_range = columns_in_data + else: + # image will be transposed, so loop over the number of rows + ablate_over_range = rows_in_data + + for ablation_start in range(ablate_over_range): + ablated_x = self.ablator.forward(np.copy(x), column_pos=ablation_start) + if ablation_start == 0: + preds = self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + else: + preds += self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + elif self.ablation_type == "block": + for xcorner in range(rows_in_data): + for ycorner in range(columns_in_data): + ablated_x = self.ablator.forward(np.copy(x), row_pos=xcorner, column_pos=ycorner) + if ycorner == 0 and xcorner == 0: + preds = self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + else: + preds += self._predict_classifier( + ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs + ) + return preds + + def eval_and_certify( + self, + x: np.ndarray, + y: np.ndarray, + size_to_certify: int, + batch_size: int = 128, + verbose: bool = True, + ) -> Tuple["tf.Tensor", "tf.Tensor"]: + """ + Evaluates the normal and certified performance over the supplied data. + + :param x: Evaluation data. + :param y: Evaluation labels. + :param size_to_certify: The size of the patch to certify against. + If not provided will default to the ablation size. + :param batch_size: batch size when evaluating. + :param verbose: If to display the progress bar + :return: The accuracy and certified accuracy over the dataset """ - return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=training_mode, **kwargs) + import tensorflow as tf + + y = check_and_transform_label_format(y, nb_classes=self.nb_classes) + + # Apply preprocessing + x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True) + + num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size))) + pbar = tqdm(range(num_batch), disable=not verbose) + accuracy = tf.constant(np.array(0.0), dtype=tf.dtypes.int32) + cert_sum = tf.constant(np.array(0.0), dtype=tf.dtypes.int32) + n_samples = 0 + + for m in pbar: + if m == (num_batch - 1): + i_batch = np.copy(x_preprocessed[m * batch_size :]) + o_batch = y_preprocessed[m * batch_size :] + else: + i_batch = np.copy(x_preprocessed[m * batch_size : (m + 1) * batch_size]) + o_batch = y_preprocessed[m * batch_size : (m + 1) * batch_size] + + pred_counts = self.predict(i_batch) + + _, cert_and_correct, top_predicted_class = self.ablator.certify( + pred_counts, size_to_certify=size_to_certify, label=o_batch + ) + cert_sum += tf.math.reduce_sum(tf.where(cert_and_correct, 1, 0)) + accuracy += tf.math.reduce_sum(tf.where(top_predicted_class == np.argmax(o_batch, axis=-1), 1, 0)) + n_samples += len(cert_and_correct) + + pbar.set_description(f"Normal Acc {accuracy / n_samples:.3f} " f"Cert Acc {cert_sum / n_samples:.3f}") + return (accuracy / n_samples), (cert_sum / n_samples) diff --git a/art/estimators/certification/derandomized_smoothing/vision_transformers/__init__.py b/art/estimators/certification/derandomized_smoothing/vision_transformers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py b/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py new file mode 100644 index 0000000000..48f96eefab --- /dev/null +++ b/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py @@ -0,0 +1,196 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# PatchEmbed class adapted from the implementation in https://github.com/MadryLab/smoothed-vit +# +# Original License: +# +# MIT License +# +# Copyright (c) 2021 Madry Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE + +""" +Implements functionality for running Vision Transformers in ART +""" +from typing import Optional + +import torch +from timm.models.vision_transformer import VisionTransformer + + +class PatchEmbed(torch.nn.Module): + """ + Image to Patch Embedding + + Class adapted from the implementation in https://github.com/MadryLab/smoothed-vit + + Original License stated above. + """ + + def __init__(self, patch_size: int = 16, in_channels: int = 1, embed_dim: int = 768): + """ + Specifies the configuration for the convolutional layer. + + :param patch_size: The patch size used by the ViT. + :param in_channels: Number of input channels. + :param embed_dim: The embedding dimension used by the ViT. 
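+
+        For example (illustrative): with these defaults, calling ``create()`` builds a 16x16,
+        stride-16 convolution filled with ones, so a ``(N, 1, 224, 224)`` ablation mask passed
+        through ``forward`` becomes a ``(N, 196, 768)`` tensor indicating, per patch, how many
+        mask pixels are set; this is what lets the model identify fully ablated patches.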
+ """ + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + self.proj: Optional[torch.nn.Conv2d] = None + + def create(self, patch_size=None, embed_dim=None, device="cpu", **kwargs) -> None: # pylint: disable=W0613 + """ + Creates a convolution that mimics the embedding layer to be used for the ablation mask to + track where the image was ablated. + + :param patch_size: The patch size used by the ViT. + :param embed_dim: The embedding dimension used by the ViT. + :param device: Which device to set the emdedding layer to. + :param kwargs: Handles the remaining kwargs from the ViT configuration. + """ + + if patch_size is not None: + self.patch_size = patch_size + if embed_dim is not None: + self.embed_dim = embed_dim + + self.proj = torch.nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + w_shape = self.proj.weight.shape + self.proj.weight = torch.nn.Parameter(torch.ones(w_shape).to(device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the embedder. We are simply tracking the positions of the ablation mask so no gradients + are required. + + :param x: Input data corresponding to the ablation mask + :return: The embedded input + """ + if self.proj is not None: + with torch.no_grad(): + x = self.proj(x).flatten(2).transpose(1, 2) + return x + raise ValueError("Projection layer not yet created.") + + +class PyTorchVisionTransformer(VisionTransformer): + """ + Model-specific class to define the forward pass of the Vision Transformer (ViT) in PyTorch. + """ + + # Make as a class attribute to avoid being included in the + # state dictionaries of the ViT Model. + ablation_mask_embedder = PatchEmbed(in_channels=1) + + def __init__(self, **kwargs): + """ + Create a PyTorchVisionTransformer instance + + :param kwargs: keyword arguments required to create the mask embedder and the vision transformer class + """ + self.to_drop_tokens = kwargs["drop_tokens"] + + if kwargs["device_type"] == "cpu" or not torch.cuda.is_available(): + self.device = torch.device("cpu") + else: # pragma: no cover + cuda_idx = torch.cuda.current_device() + self.device = torch.device(f"cuda:{cuda_idx}") + + del kwargs["drop_tokens"] + del kwargs["device_type"] + + super().__init__(**kwargs) + self.ablation_mask_embedder.create(device=self.device, **kwargs) + + self.in_chans = kwargs["in_chans"] + self.img_size = kwargs["img_size"] + + @staticmethod + def drop_tokens(x: torch.Tensor, indexes: torch.Tensor) -> torch.Tensor: + """ + Drops the tokens which correspond to fully masked inputs + + :param x: Input data + :param indexes: positions to be ablated + :return: Input with tokens dropped where the input was fully ablated. + """ + x_no_cl, cls_token = x[:, 1:], x[:, 0:1] + shape = x_no_cl.shape + + # reshape to temporarily remove batch + x_no_cl = torch.reshape(x_no_cl, shape=(-1, shape[-1])) + indexes = torch.reshape(indexes, shape=(-1,)) + indexes = indexes.nonzero(as_tuple=True)[0] + x_no_cl = torch.index_select(x_no_cl, dim=0, index=indexes) + x_no_cl = torch.reshape(x_no_cl, shape=(shape[0], -1, shape[-1])) + return torch.cat((cls_token, x_no_cl), dim=1) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """ + The forward pass of the ViT. + + :param x: Input data. 
+ :return: The input processed by the ViT backbone + """ + + ablated_input = False + if x.shape[1] == self.in_chans + 1: + ablated_input = True + + if ablated_input: + x, ablation_mask = x[:, : self.in_chans], x[:, self.in_chans : self.in_chans + 1] + + x = self.patch_embed(x) + x = self._pos_embed(x) + + if self.to_drop_tokens and ablated_input: + ones = self.ablation_mask_embedder(ablation_mask) + to_drop = torch.sum(ones, dim=2) + indexes = torch.gt(torch.where(to_drop > 1, 1, 0), 0) + x = self.drop_tokens(x, indexes) + + x = self.norm_pre(x) + x = self.blocks(x) + return self.norm(x) diff --git a/notebooks/README.md b/notebooks/README.md index 7ab184e397..95806cbf65 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -296,6 +296,9 @@ demonstrates using interval bound propagation for certification of neural networ
ks.
+[smoothed_vision_transformers.ipynb](smoothed_vision_transformers.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/smoothed_vision_transformers.ipynb)] +Demonstrates training a neural network using smoothed vision transformers for certified performance against patch attacks. + + ## MNIST [fabric_for_deep_learning_adversarial_samples_fashion_mnist.ipynb](fabric_for_deep_learning_adversarial_samples_fashion_mnist.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/fabric_for_deep_learning_adversarial_samples_fashion_mnist.ipynb)] diff --git a/notebooks/smoothed_vision_transformers.ipynb b/notebooks/smoothed_vision_transformers.ipynb new file mode 100644 index 0000000000..325cecf976 --- /dev/null +++ b/notebooks/smoothed_vision_transformers.ipynb @@ -0,0 +1,1220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "58063edd", + "metadata": {}, + "source": [ + "# Certification of Vision Transformers" + ] + }, + { + "cell_type": "markdown", + "id": "0438abb9", + "metadata": {}, + "source": [ + "In this notebook we will go over how to use the PyTorchSmoothedViT tool and certify vision transformers against patch attacks!\n", + "\n", + "### Overview\n", + "\n", + "This method was introduced in Certified Patch Robustness via Smoothed Vision Transformers (https://arxiv.org/abs/2110.07719). The core technique is one of *image ablations*, where the image is blanked out except for certain regions. By ablating the input in different ways every time, we can obtain many predictions for a single input. Now, as we are ablating large parts of the image, the attacker's patch attack is also removed in many predictions. Based on factors like the size of the adversarial patch and the size of the retained part of the image, the attacker will only be able to influence a limited number of predictions. In fact, if the attacker has an $m \\times m$ patch attack and the retained part of the image is a column of width $s$, then the maximum number of predictions $\\Delta$ that could be affected is: \n", + "\n", + "$\\Delta = m + s - 1$
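\n", + "\n", + "For example (with illustrative numbers, not values from the paper): an attacker with a $5 \\times 5$ patch ($m=5$) facing column ablations of width $s=4$ can influence at most $\\Delta = 5 + 4 - 1 = 8$ predictions, since these are the only column positions at which the retained column overlaps the patch.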
\n", + "\n", + "Based on this relationship we can derive a simple but effective criterion: if we make many predictions for an image, the most frequently predicted class $c_t$ is predicted $k_t$ times, and the second most predicted class $c_{t-1}$ is predicted $k_{t-1}$ times, then we have a certified prediction for $c_t$ if: \n", + "\n", + "\n", + "$k_t - k_{t-1} > 2\\Delta$
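\n", + "\n", + "To see where the factor of two comes from (continuing the illustrative numbers above, $\\Delta = 8$): each influenced prediction can at worst remove one vote from $c_t$ and add one vote to $c_{t-1}$, shifting the margin by two, so certification in that example requires $k_t - k_{t-1} > 2 \\times 8 = 16$.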
\n", + "\n", + "Intuitively we are saying that even if $\\Delta$ predictions were adversarially influenced and those predictions were to change, the model would *still* have predicted class $c_t$.\n", + "\n", + "### What's special about Vision Transformers?\n", + "\n", + "The formulation above is very generic and can be applied to any neural network model; in fact, the original paper which proposed it (https://arxiv.org/abs/2110.07719) considered the case of convolutional neural networks. \n", + "\n", + "However, Vision Transformers (ViTs) are well suited to this task of predicting with image ablations for two key reasons: \n", + "\n", + "+ ViTs first tokenize the input into a series of image regions which get embedded and are then processed through the neural network. Thus, by considering the input as a set of tokens, we can drop the tokens which correspond to fully masked (i.e. ablated) regions, significantly saving on compute costs. \n", + "\n", + "+ Secondly, the ViT's self-attention layers share information globally at every layer. In contrast, convolutional neural networks build up their receptive field over a series of layers. Hence, ViTs can be more effective at classifying an image based on its small unablated regions.\n", + "\n", + "Let's see how to use these tools!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aeb27667", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import numpy as np\n", + "import torch\n", + "\n", + "sys.path.append(\"..\")\n", + "from torchvision import datasets\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# The core tool is PyTorchSmoothedViT which can be imported as follows:\n", + "from art.estimators.certification.derandomized_smoothing import PyTorchDeRandomizedSmoothing\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "80541a3a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "# Function to fetch the cifar-10 data\n", + "def get_cifar_data():\n", + " \"\"\"\n", + " Get CIFAR-10 data.\n", + " :return: cifar train/test data.\n", + " \"\"\"\n", + " train_set = datasets.CIFAR10('./data', train=True, download=True)\n", + " test_set = datasets.CIFAR10('./data', train=False, download=True)\n", + "\n", + " x_train = train_set.data.astype(np.float32)\n", + " y_train = np.asarray(train_set.targets)\n", + "\n", + " x_test = test_set.data.astype(np.float32)\n", + " y_test = np.asarray(test_set.targets)\n", + "\n", + " x_train = np.moveaxis(x_train, [3], [1])\n", + " x_test = np.moveaxis(x_test, [3], [1])\n", + "\n", + " x_train = x_train / 255.0\n", + " x_test = x_test / 255.0\n", + "\n", + " return (x_train, y_train), (x_test, y_test)\n", + "\n", + "\n", + "(x_train, y_train), (x_test, y_test) = get_cifar_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2ac0c5b3", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['vit_base_patch8_224',\n", + " 'vit_base_patch16_18x2_224',\n", + " 'vit_base_patch16_224',\n", + " 'vit_base_patch16_224_miil',\n", + " 'vit_base_patch16_384',\n", + " 'vit_base_patch16_clip_224',\n", + " 'vit_base_patch16_clip_384',\n", + " 'vit_base_patch16_gap_224',\n", + " 'vit_base_patch16_plus_240',\n", + " 'vit_base_patch16_rpn_224',\n", + " "
'vit_base_patch16_xp_224',\n", + " 'vit_base_patch32_224',\n", + " 'vit_base_patch32_384',\n", + " 'vit_base_patch32_clip_224',\n", + " 'vit_base_patch32_clip_384',\n", + " 'vit_base_patch32_clip_448',\n", + " 'vit_base_patch32_plus_256',\n", + " 'vit_giant_patch14_224',\n", + " 'vit_giant_patch14_clip_224',\n", + " 'vit_gigantic_patch14_224',\n", + " 'vit_gigantic_patch14_clip_224',\n", + " 'vit_huge_patch14_224',\n", + " 'vit_huge_patch14_clip_224',\n", + " 'vit_huge_patch14_clip_336',\n", + " 'vit_huge_patch14_xp_224',\n", + " 'vit_large_patch14_224',\n", + " 'vit_large_patch14_clip_224',\n", + " 'vit_large_patch14_clip_336',\n", + " 'vit_large_patch14_xp_224',\n", + " 'vit_large_patch16_224',\n", + " 'vit_large_patch16_384',\n", + " 'vit_large_patch32_224',\n", + " 'vit_large_patch32_384',\n", + " 'vit_medium_patch16_gap_240',\n", + " 'vit_medium_patch16_gap_256',\n", + " 'vit_medium_patch16_gap_384',\n", + " 'vit_small_patch16_18x2_224',\n", + " 'vit_small_patch16_36x1_224',\n", + " 'vit_small_patch16_224',\n", + " 'vit_small_patch16_384',\n", + " 'vit_small_patch32_224',\n", + " 'vit_small_patch32_384',\n", + " 'vit_tiny_patch16_224',\n", + " 'vit_tiny_patch16_384']" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# There are a few ways we can interface with PyTorchSmoothedViT. \n", + "# The most direct way to get setup is by specifying the name of a supported transformer.\n", + "# Behind the scenes we are using the timm library (link: https://github.com/huggingface/pytorch-image-models).\n", + "\n", + "\n", + "# We currently support ViTs generated via: \n", + "# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py\n", + "# Support for other architectures can be added in. Consider raising a feature or pull request to have \n", + "# additional models supported.\n", + "\n", + "# We can see all the models supported by using the .get_models() method:\n", + "PyTorchDeRandomizedSmoothing.get_models()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e8bac618", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Running algorithm: salman2021\n", + "INFO:root:Converting Adam Optimiser\n", + "WARNING:art.estimators.certification.derandomized_smoothing.pytorch: ViT expects input shape of: (3, 224, 224) but (3, 32, 32) specified as the input shape. 
The input will be rescaled to (3, 224, 224)\n", + "INFO:art.estimators.classification.pytorch:Inferred 9 hidden layers on PyTorch classifier.\n", + "INFO:art.estimators.certification.derandomized_smoothing.pytorch:PyTorchViT(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))\n", + " (norm): Identity()\n", + " )\n", + " (pos_drop): Dropout(p=0.0, inplace=False)\n", + " (patch_drop): Identity()\n", + " (norm_pre): Identity()\n", + " (blocks): Sequential(\n", + " (0): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (1): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (2): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (3): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, 
bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (4): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (5): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (6): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, 
inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (7): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (8): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (9): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (10): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): 
Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (11): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " )\n", + " (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (fc_norm): Identity()\n", + " (head_drop): Dropout(p=0.0, inplace=False)\n", + " (head): Linear(in_features=384, out_features=10, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "import timm\n", + "\n", + "# We can setup the PyTorchSmoothedViT if we start with a ViT model directly.\n", + "\n", + "vit_model = timm.create_model('vit_small_patch16_224')\n", + "optimizer = torch.optim.Adam(vit_model.parameters(), lr=1e-4)\n", + "\n", + "art_model = PyTorchDeRandomizedSmoothing(model=vit_model, # Name of the model acitecture to load\n", + " loss=torch.nn.CrossEntropyLoss(), # loss function to use\n", + " optimizer=optimizer, # the optimizer to use: note! this is not initialised here we just supply the class!\n", + " input_shape=(3, 32, 32), # the input shape of the data: Note! that if this is a different shape to what the ViT expects it will be re-scaled\n", + " nb_classes=10,\n", + " ablation_size=4, # Size of the retained column\n", + " replace_last_layer=True, # Replace the last layer with a new set of weights to fine tune on new data\n", + " load_pretrained=True) # if to load pre-trained weights for the ViT" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "353ef5a6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Running algorithm: salman2021\n", + "INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/vit_small_patch16_224.augreg_in21k_ft_in1k)\n", + "INFO:timm.models._hub:[timm/vit_small_patch16_224.augreg_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.\n", + "WARNING:art.estimators.certification.derandomized_smoothing.pytorch: ViT expects input shape of: (3, 224, 224) but (3, 32, 32) specified as the input shape. 
The input will be rescaled to (3, 224, 224)\n", + "INFO:art.estimators.classification.pytorch:Inferred 9 hidden layers on PyTorch classifier.\n", + "INFO:art.estimators.certification.derandomized_smoothing.pytorch:PyTorchViT(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))\n", + " (norm): Identity()\n", + " )\n", + " (pos_drop): Dropout(p=0.0, inplace=False)\n", + " (patch_drop): Identity()\n", + " (norm_pre): Identity()\n", + " (blocks): Sequential(\n", + " (0): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (1): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (2): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (3): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, 
bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (4): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (5): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (6): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, 
inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (7): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (8): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (9): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (10): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): 
Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " (11): Block(\n", + " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", + " (q_norm): Identity()\n", + " (k_norm): Identity()\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=384, out_features=384, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls1): Identity()\n", + " (drop_path1): Identity()\n", + " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (drop1): Dropout(p=0.0, inplace=False)\n", + " (norm): Identity()\n", + " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", + " (drop2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ls2): Identity()\n", + " (drop_path2): Identity()\n", + " )\n", + " )\n", + " (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", + " (fc_norm): Identity()\n", + " (head_drop): Dropout(p=0.0, inplace=False)\n", + " (head): Linear(in_features=384, out_features=10, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "# Or we can just feed in the model name and ART will internally create the ViT.\n", + "\n", + "art_model = PyTorchDeRandomizedSmoothing(model='vit_small_patch16_224', # Name of the model architecture to load\n", + " loss=torch.nn.CrossEntropyLoss(), # loss function to use\n", + " optimizer=torch.optim.SGD, # the optimizer to use. Note: it is not initialised here, we just supply the class!\n", + " optimizer_params={\"lr\": 0.01}, # the parameters used to initialise the optimizer\n", + " input_shape=(3, 32, 32), # the input shape of the data. Note: if this differs from the shape the ViT expects, the input will be re-scaled\n", + " nb_classes=10,\n", + " ablation_size=4, # Size of the retained column\n", + " replace_last_layer=True, # Replace the last layer with a new set of weights to fine-tune on new data\n", + " load_pretrained=True) # whether to load pre-trained weights for the ViT" + ] + }, + { + "cell_type": "markdown", + "id": "c7a4255f", + "metadata": {}, + "source": [ + "Creating a PyTorchDeRandomizedSmoothing instance with the above code follows many of the general ART patterns, with two caveats:\n", + "+ The optimizer would normally be supplied to the estimator already initialised, together with a PyTorch model. However, here we have not yet created the model; we are only supplying the model architecture name. Hence, we pass the optimizer class into PyTorchDeRandomizedSmoothing along with the keyword arguments in optimizer_params that would normally be used to initialise it.\n", + "+ The input shape primarily determines whether the input requires upsampling. A ViT model such as the one loaded here expects images of 224 x 224 resolution, so in our case of using CIFAR data the input will be upsampled." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "44975815", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The shape of the ablated image is (10, 4, 224, 224)\n" + ] + }, + { + "data": { + "text/plain": [ + "
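For contrast with the two caveats in the notebook cell above, here is a minimal sketch of the more conventional ART pattern they refer to: building the timm ViT yourself and handing the estimator an already-initialised optimizer. It is not part of the patch, it assumes timm is installed and that PyTorchDeRandomizedSmoothing also accepts a ready-made torch.nn.Module, and the parameter values simply mirror the cell above.

import timm
import torch

from art.estimators.certification import PyTorchDeRandomizedSmoothing

# Build the ViT up front: timm downloads the pre-trained weights and the
# classification head is re-sized for the 10 CIFAR classes.
vit = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=10)

# Because the model already exists, the optimizer can be initialised directly.
optimizer = torch.optim.SGD(vit.parameters(), lr=0.01)

# Assumed usage: pass the model object (rather than its name) together with the
# initialised optimizer, so no optimizer_params are needed.
art_model = PyTorchDeRandomizedSmoothing(
    model=vit,
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=optimizer,
    input_shape=(3, 32, 32),   # CIFAR input; upsampled internally to 224 x 224
    nb_classes=10,
    ablation_size=4,           # size of the retained column
)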