
Commit

Merge pull request #2539 from Trusted-AI/development_patch_mask
Fix bug in random sampling of patch locations in masks for adversarial patch attacks
beat-buesser authored Dec 18, 2024
2 parents cf11263 + 20f8e27 commit f89ee1b
Showing 8 changed files with 31 additions and 29 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/ci-pytorch-object-detectors.yml
@@ -41,11 +41,6 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
pip3 install -q -r requirements_test.txt
pip list
- - name: Pre-install torch
-   run: |
-     pip install torch==1.12.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-     pip install torchvision==0.13.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-     pip install torchaudio==0.12.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Run Test Action - test_pytorch_object_detector
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_object_detector.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_faster_rcnn
32 changes: 16 additions & 16 deletions art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py
@@ -381,23 +381,23 @@ def _random_overlay(
else:
mask_2d = mask[i_sample, :, :]

- edge_x_0 = int(im_scale * padded_patch.shape[self.i_w + 1]) // 2
- edge_x_1 = int(im_scale * padded_patch.shape[self.i_w + 1]) - edge_x_0
- edge_y_0 = int(im_scale * padded_patch.shape[self.i_h + 1]) // 2
- edge_y_1 = int(im_scale * padded_patch.shape[self.i_h + 1]) - edge_y_0

- mask_2d[0:edge_x_0, :] = False
- if edge_x_1 > 0:
-     mask_2d[-edge_x_1:, :] = False
- mask_2d[:, 0:edge_y_0] = False
- if edge_y_1 > 0:
-     mask_2d[:, -edge_y_1:] = False

- num_pos = np.argwhere(mask_2d).shape[0]
- pos_id = np.random.choice(num_pos, size=1)
- pos = np.argwhere(mask_2d)[pos_id[0]]
- x_shift = pos[1] - self.image_shape[self.i_w] // 2
+ edge_h_0 = int(im_scale * padded_patch.shape[self.i_h + 1]) // 2
+ edge_h_1 = int(im_scale * padded_patch.shape[self.i_h + 1]) - edge_h_0
+ edge_w_0 = int(im_scale * padded_patch.shape[self.i_w + 1]) // 2
+ edge_w_1 = int(im_scale * padded_patch.shape[self.i_w + 1]) - edge_w_0

+ mask_2d[0:edge_h_0, :] = False
+ if edge_h_1 > 0:
+     mask_2d[-edge_h_1:, :] = False
+ mask_2d[:, 0:edge_w_0] = False
+ if edge_w_1 > 0:
+     mask_2d[:, -edge_w_1:] = False

+ num_pos = np.nonzero(mask_2d.int())
+ pos_id = np.random.choice(num_pos.shape[0], size=1, replace=False) # type: ignore
+ pos = num_pos[pos_id[0]]
y_shift = pos[0] - self.image_shape[self.i_h] // 2
+ x_shift = pos[1] - self.image_shape[self.i_w] // 2

phi_rotate = float(np.random.uniform(-self.rotation_max, self.rotation_max))

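A note for readers skimming the hunk above: in the old code the margins computed from the patch width (edge_x_*) were applied along the mask's row axis and the height margins along the column axis, and np.argwhere was evaluated twice. The fix names the margins after height and width, applies them to the matching axes, and samples one position directly from the precomputed list of valid locations. Below is a minimal NumPy-only sketch of the corrected sampling; the toy image, patch size, and names such as image_h and patch_w are illustrative and not ART's API (the updated code above does the equivalent with np.nonzero on the mask and np.random.choice).

import numpy as np

rng = np.random.default_rng(0)

image_h, image_w = 8, 8
mask_2d = np.zeros((image_h, image_w), dtype=bool)
mask_2d[2:6, 3:7] = True  # region where the patch centre is allowed

# Clear a border of the mask so a patch of size (patch_h, patch_w) stays inside
# the image: height margins on the row axis, width margins on the column axis.
patch_h, patch_w = 2, 4
edge_h_0, edge_h_1 = patch_h // 2, patch_h - patch_h // 2
edge_w_0, edge_w_1 = patch_w // 2, patch_w - patch_w // 2
mask_2d[:edge_h_0, :] = False
mask_2d[-edge_h_1:, :] = False
mask_2d[:, :edge_w_0] = False
mask_2d[:, -edge_w_1:] = False

# Sample one valid (row, col) position and convert it to shifts from the centre.
positions = np.argwhere(mask_2d)  # shape (num_valid, 2); each row is (y, x)
pos = positions[rng.integers(positions.shape[0])]
y_shift = pos[0] - image_h // 2
x_shift = pos[1] - image_w // 2
print(y_shift, x_shift)

Running the sketch prints a (y_shift, x_shift) pair relative to the image centre, which is how the attack positions the patch.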
4 changes: 2 additions & 2 deletions art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py
@@ -567,7 +567,7 @@ class only supports targeted attack.
if decoded_output[local_batch_size_idx] == y[local_batch_size_idx]:
if loss_2nd_stage[local_batch_size_idx] < best_loss_2nd_stage[local_batch_size_idx]:
# Update the best loss at 2nd stage
- best_loss_2nd_stage[local_batch_size_idx] = (
+ best_loss_2nd_stage[local_batch_size_idx] = ( # type: ignore
loss_2nd_stage[local_batch_size_idx].detach().cpu().numpy()
)

@@ -734,7 +734,7 @@ def _compute_masking_threshold(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:

theta_array = np.array(theta)

- return theta_array, original_max_psd
+ return theta_array, original_max_psd # type: ignore

def _psd_transform(self, delta: "torch.Tensor", original_max_psd: np.ndarray) -> "torch.Tensor":
"""
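The two # type: ignore comments above silence mypy where a value derived from a PyTorch tensor is stored in a NumPy-typed variable. A tiny standalone illustration of the detach().cpu().numpy() pattern used when tracking the best second-stage loss follows; the names and values are hypothetical, not the attack's internals.

import numpy as np
import torch

# Hypothetical per-sample losses produced by a forward pass.
loss_2nd_stage = torch.tensor([0.7, 0.3, 0.9], requires_grad=True)
best_loss_2nd_stage = np.full(3, np.inf)

for idx in range(3):
    # Detach from the autograd graph, move to host memory, convert to NumPy.
    candidate = loss_2nd_stage[idx].detach().cpu().numpy()
    if candidate < best_loss_2nd_stage[idx]:
        best_loss_2nd_stage[idx] = candidate

print(best_loss_2nd_stage)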
2 changes: 1 addition & 1 deletion art/attacks/evasion/saliency_map.py
@@ -88,7 +88,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.ndarray:

# Initialize variables
dims = list(x.shape[1:])
- self._nb_features = np.product(dims)
+ self._nb_features = np.prod(dims)
x_adv = np.reshape(x.astype(ART_NUMPY_DTYPE), (-1, self._nb_features))
preds = np.argmax(self.estimator.predict(x, batch_size=self.batch_size), axis=1)

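For context on the one-line change above: np.product was a long-deprecated alias that NumPy removed in the 2.0 release, and np.prod is the drop-in replacement. A minimal check with an illustrative shape:

import numpy as np

dims = [3, 32, 32]  # e.g. a CHW image shape without the batch axis
nb_features = int(np.prod(dims))
print(nb_features)  # 3072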
2 changes: 2 additions & 0 deletions art/estimators/classification/pytorch.py
@@ -855,6 +855,8 @@ def loss_gradient(
else:
loss.backward()

+ grads: torch.Tensor | np.ndarray

if x_grad.grad is not None:
if isinstance(x, torch.Tensor):
grads = x_grad.grad
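The added grads: torch.Tensor | np.ndarray line is a bare variable annotation: declaring the union once lets mypy accept the branches that later bind either a tensor or an array to the same name. A small self-contained sketch of the pattern (not ART's code); the same annotation is added to the object detector and the regression estimator later in this commit.

from __future__ import annotations

import numpy as np
import torch

def double(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    out: torch.Tensor | np.ndarray  # declared once, assigned in either branch
    if isinstance(x, torch.Tensor):
        out = x * 2
    else:
        out = np.asarray(x) * 2
    return out

print(double(np.ones(3)))
print(double(torch.ones(3)))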
7 changes: 5 additions & 2 deletions art/estimators/object_detection/pytorch_object_detector.py
@@ -333,7 +333,7 @@ def _get_losses(

def loss_gradient(
self, x: np.ndarray | "torch.Tensor", y: list[dict[str, np.ndarray | "torch.Tensor"]], **kwargs
- ) -> np.ndarray:
+ ) -> np.ndarray | "torch.Tensor":
"""
Compute the gradient of the loss function w.r.t. `x`.
@@ -365,6 +365,8 @@ def loss_gradient(
# Compute gradients
loss.backward(retain_graph=True) # type: ignore

+ grads: torch.Tensor | np.ndarray

if x_grad.grad is not None:
if isinstance(x, np.ndarray):
grads = x_grad.grad.cpu().numpy()
@@ -382,7 +384,8 @@ def loss_gradient(
if not self.channels_first:
if isinstance(x, np.ndarray):
grads = np.transpose(grads, (0, 2, 3, 1))
- else:
+ elif isinstance(grads, torch.Tensor):
+     # grads_tensor: torch.Tensor = grads
grads = torch.permute(grads, (0, 2, 3, 1))

assert grads.shape == x.shape
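The elif isinstance(grads, torch.Tensor) branch narrows the union before calling torch.permute, mirroring the np.transpose call on the NumPy path. A short sketch of the NCHW to NHWC conversion this code performs; the function name and shapes are illustrative only.

from __future__ import annotations

import numpy as np
import torch

def to_channels_last(grads: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    # Move the channel axis from position 1 (NCHW) to the last position (NHWC).
    if isinstance(grads, np.ndarray):
        return np.transpose(grads, (0, 2, 3, 1))
    return torch.permute(grads, (0, 2, 3, 1))

print(to_channels_last(np.zeros((2, 3, 4, 5))).shape)    # (2, 4, 5, 3)
print(to_channels_last(torch.zeros(2, 3, 4, 5)).shape)   # torch.Size([2, 4, 5, 3])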
2 changes: 2 additions & 0 deletions art/estimators/regression/pytorch.py
@@ -682,6 +682,8 @@ def loss_gradient(
else:
loss.backward()

+ grads: torch.Tensor | np.ndarray

if x_grad.grad is not None:
if isinstance(x, torch.Tensor):
grads = x_grad.grad
6 changes: 3 additions & 3 deletions requirements_test.txt
@@ -31,9 +31,9 @@ mxnet-native==1.8.0.post0

# PyTorch
--find-links https://download.pytorch.org/whl/cpu/torch_stable.html
- torch==2.2.1
- torchaudio==2.2.1
- torchvision==0.17.1+cpu
+ torch==2.5.0
+ torchaudio==2.5.0
+ torchvision==0.20.0

# PyTorch image transformers
timm==0.9.2
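If anyone wants to verify the version bump locally after installing requirements_test.txt, a throwaway check along these lines should do; it is illustrative and not part of the CI.

import torch
import torchaudio
import torchvision

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("torchaudio:", torchaudio.__version__)
print("CUDA available:", torch.cuda.is_available())  # expected False with CPU-only wheels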