Skip to content

Commit

Permalink
adding data augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed May 20, 2022
1 parent bda0ca8 commit 13de3f4
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 4 deletions.
36 changes: 32 additions & 4 deletions frame_semantic_transformer/data/TaskSampleDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
import torch
from torch.utils.data import Dataset
from transformers import T5Tokenizer
from frame_semantic_transformer.data.augmentations.LowercaseAugmentation import (
LowercaseAugmentation,
)
from frame_semantic_transformer.data.augmentations.RemoveContractionsAugmentation import (
RemoveContractionsAugmentation,
)
from frame_semantic_transformer.data.augmentations.RemoveEndPunctuationAugmentation import (
RemoveEndPunctuationAugmentation,
)
from frame_semantic_transformer.data.augmentations.chain_augmentations import (
chain_augmentations,
)

from frame_semantic_transformer.data.tasks.TaskSample import TaskSample

Expand All @@ -26,13 +38,16 @@ def __init__(
balance_tasks: bool = False,
seed: int = 42,
max_task_duplication_factor: int = 2,
augment_data: bool = False,
):
samples_to_parse = samples
if balance_tasks:
samples_to_parse = balance_tasks_by_type(
samples, seed=seed, max_duplication_factor=max_task_duplication_factor
)
input_ids, attention_mask, labels = parse_samples(samples_to_parse, tokenizer)
input_ids, attention_mask, labels = parse_samples(
samples_to_parse, tokenizer, augment_data
)
self.input_ids = input_ids
self.attention_mask = attention_mask
self.labels = labels
Expand Down Expand Up @@ -75,13 +90,26 @@ def balance_tasks_by_type(


def parse_samples(
samples: Sequence[TaskSample], tokenizer: T5Tokenizer
samples: Sequence[TaskSample], tokenizer: T5Tokenizer, augment_data: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input_sequences: list[str] = []
output_sequences: list[str] = []

augmentation = chain_augmentations(
[
RemoveEndPunctuationAugmentation(0.3),
LowercaseAugmentation(0.2),
RemoveContractionsAugmentation(0.2),
]
)

for sample in samples:
input_sequences.append(sample.get_input())
output_sequences.append(sample.get_target())
input = sample.get_input()
output = sample.get_target()
if augment_data:
input, output = augmentation(input, output)
input_sequences.append(input)
output_sequences.append(output)

input_encoding = tokenizer(
input_sequences,
Expand Down
30 changes: 30 additions & 0 deletions frame_semantic_transformer/data/augmentations/DataAugmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from random import uniform


class DataAugmentation(ABC):
"""
Base class for data augmentations on training data
"""

probability: float

def __init__(self, probability: float):
self.probability = probability

def __call__(self, input: str, output: str) -> tuple[str, str]:
"""
randomly apply this augmentation in proportion to self.probability
"""
rand_val = uniform(0, 1.0)
if rand_val > self.probability:
return (input, output)
return self.apply_augmentation(input, output)

@abstractmethod
def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
"""
Main logic for subclasses to implement
"""
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from __future__ import annotations
from .DataAugmentation import DataAugmentation


class LowercaseAugmentation(DataAugmentation):
def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
task_def_index = input.find(":")
task_def = input[:task_def_index]
input_contents = input[task_def_index:]
# only lowercase the content, not the task definition
return (
task_def + input_contents.lower(),
output.lower(),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations
from .DataAugmentation import DataAugmentation
import re


def remove_contractions(text: str) -> str:
new_text = text.replace("won't", "will not")
new_text = new_text.replace("can't", "cannot")
new_text = re.sub(r"n't(\b)", r" not\1", new_text)
new_text = re.sub(r"'ll(\b)", r" will\1", new_text)
new_text = re.sub(r"'m(\b)", r" am\1", new_text)
new_text = re.sub(r"'re(\b)", r" are\1", new_text)
new_text = re.sub(r"'ve(\b)", r" have\1", new_text)
return new_text


class RemoveContractionsAugmentation(DataAugmentation):
def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
if "*'" in input or "*'" in output or "*n'" in input or "*n'" in output:
return (input, output)

return (
remove_contractions(input),
remove_contractions(output),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations
from .DataAugmentation import DataAugmentation
import re

REMOVE_END_PUNCT_RE = r"\s*[.?!]\s*$"


class RemoveEndPunctuationAugmentation(DataAugmentation):
def apply_augmentation(self, input: str, output: str) -> tuple[str, str]:
return (
re.sub(REMOVE_END_PUNCT_RE, "", input),
re.sub(REMOVE_END_PUNCT_RE, "", output),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations
from typing import Callable, Sequence

from .DataAugmentation import DataAugmentation


def chain_augmentations(
augmentations: Sequence[DataAugmentation],
) -> Callable[[str, str], tuple[str, str]]:
def chained_augmentation(input: str, output: str) -> tuple[str, str]:
chained_input = input
chained_output = output
for augmentation in augmentations:
chained_input, chained_output = augmentation(chained_input, chained_output)
return chained_input, chained_output

return chained_augmentation
6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ mypy_path = $MYPY_CONFIG_FILE_DIR/stubs
[mypy-tests.*]
ignore_missing_imports = True

[mypy-pytest.*]
ignore_missing_imports = True

[mypy-nltk.*]
ignore_missing_imports = True

[mypy-flask_cors.*]
ignore_missing_imports = True

[mypy-tqdm.*]
ignore_missing_imports = True
30 changes: 30 additions & 0 deletions tests/data/augmentations/test_LowercaseAugmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations
import pytest

from frame_semantic_transformer.data.augmentations.LowercaseAugmentation import (
LowercaseAugmentation,
)


@pytest.mark.parametrize(
"input,expected",
[
(
("TASK: I am a banana.", "I am a banana."),
("TASK: i am a banana.", "i am a banana."),
),
(
("TASK: I AM A BANANA !", "I AM A BANANA !"),
("TASK: i am a banana !", "i am a banana !"),
),
(
("TASK | Param1 | Param 2 : I AM A BANANA !", "I AM A BANANA !"),
("TASK | Param1 | Param 2 : i am a banana !", "i am a banana !"),
),
],
)
def test_LowercaseAugmentation(
input: tuple[str, str], expected: tuple[str, str]
) -> None:
augmentation = LowercaseAugmentation(1.0)
assert augmentation(*input) == expected
36 changes: 36 additions & 0 deletions tests/data/augmentations/test_RemoveContractionsAugmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations
import pytest

from frame_semantic_transformer.data.augmentations.RemoveContractionsAugmentation import (
RemoveContractionsAugmentation,
)


@pytest.mark.parametrize(
"input,expected",
[
(
("TASK: I can't go I won't go", "I can't go I won't go"),
("TASK: I cannot go I will not go", "I cannot go I will not go"),
),
(
(
"TASK: shouldn't couldn't they're we'll they've",
"shouldn't couldn't they're we'll they've",
),
(
"TASK: should not could not they are we will they have",
"should not could not they are we will they have",
),
),
(
("TASK | Param1 | Param 2 : We're didn*'t", "We're didn't"),
("TASK | Param1 | Param 2 : We're didn*'t", "We're didn't"),
),
],
)
def test_RemoveContractionsAugmentation(
input: tuple[str, str], expected: tuple[str, str]
) -> None:
augmentation = RemoveContractionsAugmentation(1.0)
assert augmentation(*input) == expected
30 changes: 30 additions & 0 deletions tests/data/augmentations/test_RemoveEndPunctuationAugmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations
import pytest

from frame_semantic_transformer.data.augmentations.RemoveEndPunctuationAugmentation import (
RemoveEndPunctuationAugmentation,
)


@pytest.mark.parametrize(
"input,expected",
[
(
("TASK: I am a banana.", "I am a banana."),
("TASK: I am a banana", "I am a banana"),
),
(
("TASK: I am a banana!", "I am a banana!"),
("TASK: I am a banana", "I am a banana"),
),
(
("TASK: I am a banana .", "I am a banana ."),
("TASK: I am a banana", "I am a banana"),
),
],
)
def test_RemoveEndPunctuationAugmentation(
input: tuple[str, str], expected: tuple[str, str]
) -> None:
augmentation = RemoveEndPunctuationAugmentation(1.0)
assert augmentation(*input) == expected

0 comments on commit 13de3f4

Please sign in to comment.