Skip to content

Commit

Permalink
Fix new typing
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <beat.buesser@ibm.com>
  • Loading branch information
beat-buesser committed Aug 25, 2024
1 parent f7ad40c commit 433c00c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion art/defences/preprocessor/mp3_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def wav_to_mp3(x, sample_rate):
if x.dtype != object and self.channels_first:
x_mp3 = np.swapaxes(x_mp3, 1, 2)

if x_orig_type is not object and x.dtype is object and x.ndim == 2:
if x_orig_type != object and x.dtype == object and x.ndim == 2:
x_mp3 = x_mp3.astype(x_orig_type)

return x_mp3, y
Expand Down
4 changes: 2 additions & 2 deletions art/defences/preprocessor/mp3_compression_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import logging
from typing import Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING

from art.defences.preprocessor.mp3_compression import Mp3Compression
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
Expand Down Expand Up @@ -106,7 +106,7 @@ def backward(ctx, grad_output):

def forward(
self, x: "torch.Tensor", y: "torch.Tensor" | None = None
) -> Tuple["torch.Tensor", "torch.Tensor" | None]:
) -> tuple["torch.Tensor", "torch.Tensor" | None]:
"""
Apply MP3 compression to sample `x`.
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _set_layer(self, train: bool, layerinfo: list["torch.nn.modules.Module"]) ->
Set all layers that are an instance of `layerinfo` into training or evaluation mode.
:param train: False for evaluation mode.
:param layerinfo: list of module types.
:param layerinfo: List of module types.
"""
import torch

Expand Down
9 changes: 4 additions & 5 deletions art/estimators/speech_recognition/pytorch_deep_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from __future__ import annotations

import logging
from typing import Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -34,7 +34,6 @@
from art.utils import get_file

if TYPE_CHECKING:

import torch
from deepspeech_pytorch.model import DeepSpeech

Expand Down Expand Up @@ -333,7 +332,7 @@ def __init__(
loss_scale=1.0,
)

def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
"""
Perform prediction for a batch of inputs.
Expand Down Expand Up @@ -658,7 +657,7 @@ def _preprocess_transform_model_input(
x: "torch.Tensor",
y: np.ndarray,
real_lengths: np.ndarray,
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", list]:
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", list]:
"""
Apply preprocessing and then transform the user input space into the model input space. This function is used
by the ASR attack to attack into the PyTorchDeepSpeech estimator whose defences are called with the
Expand Down Expand Up @@ -704,7 +703,7 @@ def _transform_model_input(
compute_gradient: bool = False,
tensor_input: bool = False,
real_lengths: np.ndarray | None = None,
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", list]:
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", list]:
"""
Transform the user input space into the model input space.
Expand Down

0 comments on commit 433c00c

Please sign in to comment.