Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache audio triggers #2053

Merged
merged 9 commits into from
Mar 14, 2023
194 changes: 84 additions & 110 deletions art/attacks/poisoning/perturbations/audio_perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,136 +16,110 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Adversarial perturbations designed to work for images.
Adversarial perturbations designed to work for audio.
Uses classes, rather than pure functions as in image_perturbations.py,
because loading the audio trigger from disk (librosa.load()) is very slow
and should be done only once.
"""
import numpy as np
import librosa


def insert_tone_trigger(
x: np.ndarray,
sampling_rate: int = 16000,
frequency: int = 440,
duration: float = 0.1,
random: bool = False,
shift: int = 0,
scale: float = 0.1,
) -> np.ndarray:
class CacheTrigger:
"""
Adds a 'tone' with a given frequency to audio example. Works for a single example or a batch of examples.
Adds an audio backdoor trigger to a set of audio examples. Works for a single example or a batch of examples.

:param x: N x L matrix or length L array, where N is number of examples, L is the length in number of samples.
X is in range [-1,1].
:param sampling_rate: Positive integer denoting the sampling rate for x.
:param frequency: Frequency of the tone to be added.
:param duration: Duration of the tone to be added.
:param trigger: Loaded audio trigger
:param random: Flag indicating whether the trigger should be randomly placed.
:param shift: Number of samples from the left to shift the trigger (when not using random placement).
:param scale: Scaling factor for mixing the trigger.
:return: Backdoored audio.
swsuggs marked this conversation as resolved.
Show resolved Hide resolved
"""
n_dim = len(x.shape)
if n_dim > 2:
raise ValueError("Invalid array shape " + str(x.shape))

if n_dim == 2:
return np.array(
[
insert_tone_trigger(single_audio, sampling_rate, frequency, duration, random, shift, scale)
for single_audio in x
]
)

original_dtype = x.dtype
audio = np.copy(x)
length = audio.shape[0]

tone_trigger = librosa.tone(frequency, sr=sampling_rate, duration=duration)

bd_length = tone_trigger.shape[0]
if bd_length > length:
print("audio shape:", audio.shape)
print("trigger shape:", tone_trigger.shape)
raise ValueError("Backdoor audio does not fit inside the original audio.")

if random:
shift = np.random.randint(length - bd_length)

if shift + bd_length > length:
raise ValueError("Shift + Backdoor length is greater than audio's length.")

trigger_shifted = np.zeros_like(audio)
trigger_shifted[shift : shift + bd_length] = np.copy(tone_trigger)

audio += scale * trigger_shifted
audio = np.clip(audio, -1.0, 1.0)

return audio.astype(original_dtype)


def insert_audio_trigger(
x: np.ndarray,
sampling_rate: int = 16000,
backdoor_path: str = "../../../utils/data/backdoors/cough_trigger.wav",
duration: float = 1.0,
random: bool = False,
shift: int = 0,
scale: float = 0.1,
) -> np.ndarray:
def __init__(
self,
trigger: np.ndarray,
random: bool = False,
shift: int = 0,
scale: float = 0.1,
):
self.trigger = trigger
self.scaled_trigger = self.trigger * scale
self.random = random
self.shift = shift
self.scale = scale

def insert(self, x: np.ndarray) -> np.ndarray:
"""
:param x: N x L matrix or length L array, where N is number of examples, L is the length in number of samples.
X is in range [-1,1].
:return: Backdoored audio.
"""
n_dim = len(x.shape)
if n_dim == 2:
return np.array([self.insert(single_audio) for single_audio in x])
if n_dim != 1:
raise ValueError("Invalid array shape " + str(x.shape))
original_dtype = x.dtype
audio = np.copy(x)
length = audio.shape[0]
bd_length = self.trigger.shape[0]
if bd_length > length:
raise ValueError("Backdoor audio does not fit inside the original audio.")
if self.random:
shift = np.random.randint(length - bd_length)
else:
shift = self.shift
if shift + bd_length > length:
raise ValueError("Shift + Backdoor length is greater than audio's length.")

audio[shift : shift + bd_length] += self.scaled_trigger
audio = np.clip(audio, -1.0, 1.0)
return audio.astype(original_dtype)


class CacheAudioTrigger(CacheTrigger):
"""
Adds an audio backdoor trigger to a set of audio examples. Works for a single example or a batch of examples.

:param x: N x L matrix or length L array, where N is number of examples, L is the length in number of samples.
X is in range [-1,1].
:param sampling_rate: Positive integer denoting the sampling rate for x.
:param backdoor_path: The path to the audio to insert as a trigger.
:param duration: Duration of the trigger in seconds. Default `None` if full trigger is to be used.
:param random: Flag indicating whether the trigger should be randomly placed.
:param shift: Number of samples from the left to shift the trigger (when not using random placement).
:param scale: Scaling factor for mixing the trigger.
:return: Backdoored audio.
"""
n_dim = len(x.shape)
if n_dim > 2:
raise ValueError("Invalid array shape " + str(x.shape))

if n_dim == 2:
return np.array(
[
insert_audio_trigger(single_audio, sampling_rate, backdoor_path, duration, random, shift, scale)
for single_audio in x
]
)

original_dtype = x.dtype
audio = np.copy(x)

length = audio.shape[0]

trigger, bd_sampling_rate = librosa.load(backdoor_path, mono=True, sr=None, duration=duration)

if sampling_rate != bd_sampling_rate:
print(
"Backdoor sampling rate does not match with the sampling rate provided. "
"Resampling the backdoor to match the sampling rate."
)
trigger, _ = librosa.load(backdoor_path, mono=True, sr=sampling_rate, duration=duration)

bd_length = trigger.shape[0]

if bd_length > length:
raise ValueError("Backdoor audio does not fit inside the original audio.")

if random:
shift = np.random.randint(length - bd_length)

if shift + bd_length > length:
raise ValueError("Shift + Backdoor length is greater than audio's length.")

trigger_shifted = np.zeros_like(audio)
trigger_shifted[shift : shift + bd_length] = np.copy(trigger)
def __init__(
self,
sampling_rate: int = 16000,
backdoor_path: str = "../../../utils/data/backdoors/cough_trigger.wav",
duration: float = None,
**kwargs,
):
trigger, bd_sampling_rate = librosa.load(backdoor_path, mono=True, sr=None, duration=duration)

if sampling_rate != bd_sampling_rate:
print(
f"Backdoor sampling rate {bd_sampling_rate} does not match with the sampling rate provided."
"Resampling the backdoor to match the sampling rate."
)
trigger, _ = librosa.load(backdoor_path, mono=True, sr=sampling_rate, duration=duration)
super().__init__(trigger, **kwargs)


class CacheToneTrigger(CacheTrigger):
"""
Adds an audio backdoor trigger to a set of audio examples. Works for a single example or a batch of examples.
swsuggs marked this conversation as resolved.
Show resolved Hide resolved

audio += scale * trigger_shifted
audio = np.clip(audio, -1.0, 1.0)
:param sampling_rate: Positive integer denoting the sampling rate for x.
:param frequency: Frequency of the tone to be added.
:param duration: Duration of the tone to be added.
"""

return audio.astype(original_dtype)
def __init__(
self,
sampling_rate: int = 16000,
frequency: int = 440,
duration: float = 0.1,
**kwargs,
):
trigger = librosa.tone(frequency, sr=sampling_rate, duration=duration)
super().__init__(trigger, **kwargs)
51 changes: 29 additions & 22 deletions tests/attacks/poison/test_audio_perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pytest
import os

from art.attacks.poisoning.perturbations.audio_perturbations import insert_tone_trigger, insert_audio_trigger
from art.attacks.poisoning.perturbations.audio_perturbations import CacheToneTrigger, CacheAudioTrigger

from tests.utils import ARTTestException

Expand All @@ -33,39 +33,45 @@
def test_insert_tone_trigger(art_warning):
try:
# test single example
audio = insert_tone_trigger(x=np.zeros(3200), sampling_rate=16000)
trigger = CacheToneTrigger(sampling_rate=16000)
audio = trigger.insert(x=np.zeros(3200))
assert audio.shape == (3200,)
assert np.max(audio) != 0
assert np.max(np.abs(audio)) <= 1.0

# test single example with differet duration, frequency, and scale
audio = insert_tone_trigger(x=np.zeros(3200), sampling_rate=16000, frequency=16000, duration=0.2, scale=0.5)
trigger = CacheToneTrigger(sampling_rate=16000, frequency=16000, duration=0.2, scale=0.5)
audio = trigger.insert(x=np.zeros(3200))
assert audio.shape == (3200,)
assert np.max(audio) != 0

# test a batch of examples
audio = insert_tone_trigger(x=np.zeros((10, 3200)), sampling_rate=16000)
audio = trigger.insert(x=np.zeros((10, 3200)))
assert audio.shape == (10, 3200)
assert np.max(audio) != 0

# test single example with shift
audio = insert_tone_trigger(x=np.zeros(3200), sampling_rate=16000, shift=10)
trigger = CacheToneTrigger(sampling_rate=16000, shift=10)
audio = trigger.insert(x=np.zeros(3200))
assert audio.shape == (3200,)
assert np.max(audio) != 0
assert np.sum(audio[:10]) == 0

# test a batch of examples with random shift
audio = insert_tone_trigger(x=np.zeros((10, 3200)), sampling_rate=16000, random=True)
trigger = CacheToneTrigger(sampling_rate=16000, random=True)
audio = trigger.insert(x=np.zeros((10, 3200)))
assert audio.shape == (10, 3200)
assert np.max(audio) != 0

# test when length of backdoor is larger than that of audio signal
with pytest.raises(ValueError):
_ = insert_tone_trigger(x=np.zeros(3200), sampling_rate=16000, duration=0.3)
trigger = CacheToneTrigger(sampling_rate=16000, duration=0.3)
_ = trigger.insert(x=np.zeros(3200))

# test when shift + backdoor is larger than that of audio signal
with pytest.raises(ValueError):
_ = insert_tone_trigger(x=np.zeros(3200), sampling_rate=16000, duration=0.2, shift=5)
trigger = CacheToneTrigger(sampling_rate=16000, duration=0.2, shift=5)
_ = trigger.insert(x=np.zeros(3200))

except ARTTestException as e:
art_warning(e)
Expand All @@ -76,56 +82,57 @@ def test_insert_audio_trigger(art_warning):
file_path = os.path.join(os.getcwd(), "utils/data/backdoors/cough_trigger.wav")
try:
# test single example
audio = insert_audio_trigger(x=np.zeros(32000), sampling_rate=16000, backdoor_path=file_path)
trigger = CacheAudioTrigger(sampling_rate=16000, backdoor_path=file_path)
audio = trigger.insert(x=np.zeros(32000))
assert audio.shape == (32000,)
assert np.max(audio) != 0
assert np.max(np.abs(audio)) <= 1.0

# test single example with differet duration and scale
audio = insert_audio_trigger(
x=np.zeros(32000),
trigger = CacheAudioTrigger(
sampling_rate=16000,
backdoor_path=file_path,
duration=0.8,
scale=0.5,
)
audio = trigger.insert(x=np.zeros(32000))
assert audio.shape == (32000,)
assert np.max(audio) != 0

# test a batch of examples
audio = insert_audio_trigger(x=np.zeros((10, 16000)), sampling_rate=16000, backdoor_path=file_path)
trigger = CacheAudioTrigger(sampling_rate=16000, backdoor_path=file_path)
audio = trigger.insert(x=np.zeros((10, 16000)))

assert audio.shape == (10, 16000)
assert np.max(audio) != 0

# test single example with shift
audio = insert_audio_trigger(x=np.zeros(32000), sampling_rate=16000, backdoor_path=file_path, shift=10)
trigger = CacheAudioTrigger(sampling_rate=16000, backdoor_path=file_path, shift=10)
audio = trigger.insert(x=np.zeros(32000))
assert audio.shape == (32000,)
assert np.max(audio) != 0
assert np.sum(audio[:10]) == 0

# test a batch of examples with random shift
audio = insert_audio_trigger(
x=np.zeros((10, 32000)),
sampling_rate=16000,
backdoor_path=file_path,
random=True,
)
trigger = CacheAudioTrigger(sampling_rate=16000, backdoor_path=file_path, random=True)
audio = trigger.insert(x=np.zeros((10, 32000)))
assert audio.shape == (10, 32000)
assert np.max(audio) != 0

# test when length of backdoor is larger than that of audio signal
with pytest.raises(ValueError):
_ = insert_audio_trigger(x=np.zeros(15000), sampling_rate=16000, backdoor_path=file_path)
trigger = CacheAudioTrigger(sampling_rate=16000, backdoor_path=file_path)
_ = trigger.insert(x=np.zeros(15000))

# test when shift + backdoor is larger than that of audio signal
with pytest.raises(ValueError):
_ = insert_audio_trigger(
x=np.zeros(16000),
trigger = CacheAudioTrigger(
sampling_rate=16000,
backdoor_path=file_path,
duration=1,
shift=5,
)
_ = trigger.insert(x=np.zeros(16000))

except ARTTestException as e:
art_warning(e)