diff --git a/art/estimators/certification/derandomized_smoothing/vision_transformers/smooth_vit.py b/art/estimators/certification/derandomized_smoothing/vision_transformers/smooth_vit.py index 2a0f5bc564..1578f7cd7d 100644 --- a/art/estimators/certification/derandomized_smoothing/vision_transformers/smooth_vit.py +++ b/art/estimators/certification/derandomized_smoothing/vision_transformers/smooth_vit.py @@ -25,6 +25,7 @@ """ from typing import Optional, Tuple +import random import torch @@ -101,6 +102,9 @@ def ablate(self, x: torch.Tensor, column_pos: int) -> torch.Tensor: :return: The ablated input with 0s where the ablation occurred """ k = self.ablation_size + if column_pos is None: + column_pos = random.randint(0, x.shape[3]) + if column_pos + k > x.shape[-1]: x[:, :, :, (column_pos + k) % x.shape[-1] : column_pos] = 0.0 else: