From 13de3f46d37deafe8bca94eb545b993d7ee71865 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 21 May 2022 00:17:55 +0100 Subject: [PATCH] adding data augmentation --- .../data/TaskSampleDataset.py | 36 ++++++++++++++++--- .../data/augmentations/DataAugmentation.py | 30 ++++++++++++++++ .../augmentations/LowercaseAugmentation.py | 14 ++++++++ .../RemoveContractionsAugmentation.py | 25 +++++++++++++ .../RemoveEndPunctuationAugmentation.py | 13 +++++++ .../data/augmentations/chain_augmentations.py | 17 +++++++++ setup.cfg | 6 ++++ .../test_LowercaseAugmentation.py | 30 ++++++++++++++++ .../test_RemoveContractionsAugmentation.py | 36 +++++++++++++++++++ .../test_RemoveEndPunctuationAugmentation.py | 30 ++++++++++++++++ 10 files changed, 233 insertions(+), 4 deletions(-) create mode 100644 frame_semantic_transformer/data/augmentations/DataAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py create mode 100644 frame_semantic_transformer/data/augmentations/chain_augmentations.py create mode 100644 tests/data/augmentations/test_LowercaseAugmentation.py create mode 100644 tests/data/augmentations/test_RemoveContractionsAugmentation.py create mode 100644 tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py diff --git a/frame_semantic_transformer/data/TaskSampleDataset.py b/frame_semantic_transformer/data/TaskSampleDataset.py index d3c33d0..6442f4d 100644 --- a/frame_semantic_transformer/data/TaskSampleDataset.py +++ b/frame_semantic_transformer/data/TaskSampleDataset.py @@ -5,6 +5,18 @@ import torch from torch.utils.data import Dataset from transformers import T5Tokenizer +from frame_semantic_transformer.data.augmentations.LowercaseAugmentation import ( + LowercaseAugmentation, +) +from frame_semantic_transformer.data.augmentations.RemoveContractionsAugmentation import ( + RemoveContractionsAugmentation, +) +from frame_semantic_transformer.data.augmentations.RemoveEndPunctuationAugmentation import ( + RemoveEndPunctuationAugmentation, +) +from frame_semantic_transformer.data.augmentations.chain_augmentations import ( + chain_augmentations, +) from frame_semantic_transformer.data.tasks.TaskSample import TaskSample @@ -26,13 +38,16 @@ def __init__( balance_tasks: bool = False, seed: int = 42, max_task_duplication_factor: int = 2, + augment_data: bool = False, ): samples_to_parse = samples if balance_tasks: samples_to_parse = balance_tasks_by_type( samples, seed=seed, max_duplication_factor=max_task_duplication_factor ) - input_ids, attention_mask, labels = parse_samples(samples_to_parse, tokenizer) + input_ids, attention_mask, labels = parse_samples( + samples_to_parse, tokenizer, augment_data + ) self.input_ids = input_ids self.attention_mask = attention_mask self.labels = labels @@ -75,13 +90,26 @@ def balance_tasks_by_type( def parse_samples( - samples: Sequence[TaskSample], tokenizer: T5Tokenizer + samples: Sequence[TaskSample], tokenizer: T5Tokenizer, augment_data: bool ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_sequences: list[str] = [] output_sequences: list[str] = [] + + augmentation = chain_augmentations( + [ + RemoveEndPunctuationAugmentation(0.3), + LowercaseAugmentation(0.2), + RemoveContractionsAugmentation(0.2), + ] + ) + for sample in samples: - input_sequences.append(sample.get_input()) - output_sequences.append(sample.get_target()) + input = sample.get_input() + output = sample.get_target() + if augment_data: + input, output = augmentation(input, output) + input_sequences.append(input) + output_sequences.append(output) input_encoding = tokenizer( input_sequences, diff --git a/frame_semantic_transformer/data/augmentations/DataAugmentation.py b/frame_semantic_transformer/data/augmentations/DataAugmentation.py new file mode 100644 index 0000000..bcf5d0b --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/DataAugmentation.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from random import uniform + + +class DataAugmentation(ABC): + """ + Base class for data augmentations on training data + """ + + probability: float + + def __init__(self, probability: float): + self.probability = probability + + def __call__(self, input: str, output: str) -> tuple[str, str]: + """ + randomly apply this augmentation in proportion to self.probability + """ + rand_val = uniform(0, 1.0) + if rand_val > self.probability: + return (input, output) + return self.apply_augmentation(input, output) + + @abstractmethod + def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: + """ + Main logic for subclasses to implement + """ + pass diff --git a/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py b/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py new file mode 100644 index 0000000..19c67a6 --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py @@ -0,0 +1,14 @@ +from __future__ import annotations +from .DataAugmentation import DataAugmentation + + +class LowercaseAugmentation(DataAugmentation): + def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: + task_def_index = input.find(":") + task_def = input[:task_def_index] + input_contents = input[task_def_index:] + # only lowercase the content, not the task definition + return ( + task_def + input_contents.lower(), + output.lower(), + ) diff --git a/frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py b/frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py new file mode 100644 index 0000000..66dd530 --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from .DataAugmentation import DataAugmentation +import re + + +def remove_contractions(text: str) -> str: + new_text = text.replace("won't", "will not") + new_text = new_text.replace("can't", "cannot") + new_text = re.sub(r"n't(\b)", r" not\1", new_text) + new_text = re.sub(r"'ll(\b)", r" will\1", new_text) + new_text = re.sub(r"'m(\b)", r" am\1", new_text) + new_text = re.sub(r"'re(\b)", r" are\1", new_text) + new_text = re.sub(r"'ve(\b)", r" have\1", new_text) + return new_text + + +class RemoveContractionsAugmentation(DataAugmentation): + def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: + if "*'" in input or "*'" in output or "*n'" in input or "*n'" in output: + return (input, output) + + return ( + remove_contractions(input), + remove_contractions(output), + ) diff --git a/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py b/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py new file mode 100644 index 0000000..5156d88 --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py @@ -0,0 +1,13 @@ +from __future__ import annotations +from .DataAugmentation import DataAugmentation +import re + +REMOVE_END_PUNCT_RE = r"\s*[.?!]\s*$" + + +class RemoveEndPunctuationAugmentation(DataAugmentation): + def apply_augmentation(self, input: str, output: str) -> tuple[str, str]: + return ( + re.sub(REMOVE_END_PUNCT_RE, "", input), + re.sub(REMOVE_END_PUNCT_RE, "", output), + ) diff --git a/frame_semantic_transformer/data/augmentations/chain_augmentations.py b/frame_semantic_transformer/data/augmentations/chain_augmentations.py new file mode 100644 index 0000000..cb46ede --- /dev/null +++ b/frame_semantic_transformer/data/augmentations/chain_augmentations.py @@ -0,0 +1,17 @@ +from __future__ import annotations +from typing import Callable, Sequence + +from .DataAugmentation import DataAugmentation + + +def chain_augmentations( + augmentations: Sequence[DataAugmentation], +) -> Callable[[str, str], tuple[str, str]]: + def chained_augmentation(input: str, output: str) -> tuple[str, str]: + chained_input = input + chained_output = output + for augmentation in augmentations: + chained_input, chained_output = augmentation(chained_input, chained_output) + return chained_input, chained_output + + return chained_augmentation diff --git a/setup.cfg b/setup.cfg index ab84ba1..c8fefe6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,8 +15,14 @@ mypy_path = $MYPY_CONFIG_FILE_DIR/stubs [mypy-tests.*] ignore_missing_imports = True +[mypy-pytest.*] +ignore_missing_imports = True + [mypy-nltk.*] ignore_missing_imports = True +[mypy-flask_cors.*] +ignore_missing_imports = True + [mypy-tqdm.*] ignore_missing_imports = True \ No newline at end of file diff --git a/tests/data/augmentations/test_LowercaseAugmentation.py b/tests/data/augmentations/test_LowercaseAugmentation.py new file mode 100644 index 0000000..466cad5 --- /dev/null +++ b/tests/data/augmentations/test_LowercaseAugmentation.py @@ -0,0 +1,30 @@ +from __future__ import annotations +import pytest + +from frame_semantic_transformer.data.augmentations.LowercaseAugmentation import ( + LowercaseAugmentation, +) + + +@pytest.mark.parametrize( + "input,expected", + [ + ( + ("TASK: I am a banana.", "I am a banana."), + ("TASK: i am a banana.", "i am a banana."), + ), + ( + ("TASK: I AM A BANANA !", "I AM A BANANA !"), + ("TASK: i am a banana !", "i am a banana !"), + ), + ( + ("TASK | Param1 | Param 2 : I AM A BANANA !", "I AM A BANANA !"), + ("TASK | Param1 | Param 2 : i am a banana !", "i am a banana !"), + ), + ], +) +def test_LowercaseAugmentation( + input: tuple[str, str], expected: tuple[str, str] +) -> None: + augmentation = LowercaseAugmentation(1.0) + assert augmentation(*input) == expected diff --git a/tests/data/augmentations/test_RemoveContractionsAugmentation.py b/tests/data/augmentations/test_RemoveContractionsAugmentation.py new file mode 100644 index 0000000..c360815 --- /dev/null +++ b/tests/data/augmentations/test_RemoveContractionsAugmentation.py @@ -0,0 +1,36 @@ +from __future__ import annotations +import pytest + +from frame_semantic_transformer.data.augmentations.RemoveContractionsAugmentation import ( + RemoveContractionsAugmentation, +) + + +@pytest.mark.parametrize( + "input,expected", + [ + ( + ("TASK: I can't go I won't go", "I can't go I won't go"), + ("TASK: I cannot go I will not go", "I cannot go I will not go"), + ), + ( + ( + "TASK: shouldn't couldn't they're we'll they've", + "shouldn't couldn't they're we'll they've", + ), + ( + "TASK: should not could not they are we will they have", + "should not could not they are we will they have", + ), + ), + ( + ("TASK | Param1 | Param 2 : We're didn*'t", "We're didn't"), + ("TASK | Param1 | Param 2 : We're didn*'t", "We're didn't"), + ), + ], +) +def test_RemoveContractionsAugmentation( + input: tuple[str, str], expected: tuple[str, str] +) -> None: + augmentation = RemoveContractionsAugmentation(1.0) + assert augmentation(*input) == expected diff --git a/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py b/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py new file mode 100644 index 0000000..6f16077 --- /dev/null +++ b/tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py @@ -0,0 +1,30 @@ +from __future__ import annotations +import pytest + +from frame_semantic_transformer.data.augmentations.RemoveEndPunctuationAugmentation import ( + RemoveEndPunctuationAugmentation, +) + + +@pytest.mark.parametrize( + "input,expected", + [ + ( + ("TASK: I am a banana.", "I am a banana."), + ("TASK: I am a banana", "I am a banana"), + ), + ( + ("TASK: I am a banana!", "I am a banana!"), + ("TASK: I am a banana", "I am a banana"), + ), + ( + ("TASK: I am a banana .", "I am a banana ."), + ("TASK: I am a banana", "I am a banana"), + ), + ], +) +def test_RemoveEndPunctuationAugmentation( + input: tuple[str, str], expected: tuple[str, str] +) -> None: + augmentation = RemoveEndPunctuationAugmentation(1.0) + assert augmentation(*input) == expected