From d08877e0678c58abfe0d14e64c6673e68c03fb59 Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Mon, 9 May 2022 15:33:19 +0100
Subject: [PATCH] adding a punct standardization step

---
 frame_semantic_transformer/data/chunk_list.py |  9 ----
 frame_semantic_transformer/data/data_utils.py | 24 +++++++++++
 .../data/shuffle_and_split.py                 | 15 -------
 .../task_samples/FrameClassificationSample.py |  3 +-
 .../TriggerIdentificationSample.py            |  5 ++-
 frame_semantic_transformer/evaluate.py        |  2 +-
 .../test_FrameClassificationSample.py         |  4 +-
 .../test_TriggerIdentificationSample.py       | 14 +++---
 tests/data/test_data_utils.py                 | 43 +++++++++++++++++++
 9 files changed, 81 insertions(+), 38 deletions(-)
 delete mode 100644 frame_semantic_transformer/data/chunk_list.py
 create mode 100644 frame_semantic_transformer/data/data_utils.py
 delete mode 100644 frame_semantic_transformer/data/shuffle_and_split.py
 create mode 100644 tests/data/test_data_utils.py

diff --git a/frame_semantic_transformer/data/chunk_list.py b/frame_semantic_transformer/data/chunk_list.py
deleted file mode 100644
index 849f3ee..0000000
--- a/frame_semantic_transformer/data/chunk_list.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from __future__ import annotations
-from typing import Iterator, Sequence, TypeVar
-
-T = TypeVar("T")
-
-
-def chunk_list(lst: Sequence[T], chunk_size: int) -> Iterator[Sequence[T]]:
-    for i in range(0, len(lst), chunk_size):
-        yield lst[i : i + chunk_size]
diff --git a/frame_semantic_transformer/data/data_utils.py b/frame_semantic_transformer/data/data_utils.py
new file mode 100644
index 0000000..63b9870
--- /dev/null
+++ b/frame_semantic_transformer/data/data_utils.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+import re
+from typing import Iterator, Sequence, TypeVar
+
+T = TypeVar("T")
+
+
+def chunk_list(lst: Sequence[T], chunk_size: int) -> Iterator[Sequence[T]]:
+    for i in range(0, len(lst), chunk_size):
+        yield lst[i : i + chunk_size]
+
+
+def standardize_punct(sent: str) -> str:
+    """
+    Try to standardize things like "He 's a man" -> "He's a man"
+    """
+    # remove space before punct
+    updated_sent = re.sub(r"([a-zA-Z0-9])\s+(\*?[.',:])", r"\1\2", sent)
+    # remove repeated *'s
+    updated_sent = re.sub(r"\*+", "*", updated_sent)
+    # fix spaces in contractions
+    updated_sent = re.sub(r"([a-zA-Z0-9])\s+(\*?n't)", r"\1\2", updated_sent)
+
+    return updated_sent
diff --git a/frame_semantic_transformer/data/shuffle_and_split.py b/frame_semantic_transformer/data/shuffle_and_split.py
deleted file mode 100644
index f7a28dc..0000000
--- a/frame_semantic_transformer/data/shuffle_and_split.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from __future__ import annotations
-from typing import Sequence, TypeVar
-from random import Random
-
-T = TypeVar("T")
-
-
-def shuffle_and_split(
-    data: Sequence[T], train_ratio: float = 0.8, seed: int = 0
-) -> tuple[Sequence[T], Sequence[T]]:
-    random = Random(seed)
-    shuffled = list(data)
-    random.shuffle(shuffled)
-    split_point = int(len(shuffled) * train_ratio)
-    return shuffled[:split_point], shuffled[split_point:]
diff --git a/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py b/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
index 023e43b..08e00bf 100644
--- a/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
+++ b/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass
+from frame_semantic_transformer.data.data_utils import standardize_punct
 
 from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample
 
@@ -40,4 +41,4 @@ def trigger_labeled_text(self) -> str:
         pre_span = self.text[0 : self.trigger_loc[0]]
         post_span = self.text[self.trigger_loc[1] :]
         # TODO: handle these special chars better
-        return f"{pre_span}* {self.trigger} *{post_span}"
+        return standardize_punct(f"{pre_span}*{self.trigger}{post_span}")
diff --git a/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py b/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
index 2f46732..ac51233 100644
--- a/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
+++ b/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass
+from frame_semantic_transformer.data.data_utils import standardize_punct
 
 from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample
 
@@ -25,14 +26,14 @@ def get_target(self) -> str:
             output += self.text[prev_trigger_loc:loc] + "*"
             prev_trigger_loc = loc
         output += self.text[prev_trigger_loc:]
-        return output
+        return standardize_punct(output)
 
     def evaluate_prediction(self, prediction: str) -> tuple[int, int, int]:
         true_pos = 0
         false_pos = 0
         false_neg = 0
 
-        prediction_parts = prediction.split()
+        prediction_parts = standardize_punct(prediction).split()
         target_parts = self.get_target().split()
 
         for i, target_part in enumerate(target_parts):
diff --git a/frame_semantic_transformer/evaluate.py b/frame_semantic_transformer/evaluate.py
index a4a54b9..f0d06f1 100644
--- a/frame_semantic_transformer/evaluate.py
+++ b/frame_semantic_transformer/evaluate.py
@@ -4,7 +4,7 @@
 from tqdm import tqdm
 from transformers import T5ForConditionalGeneration, T5Tokenizer
 
-from frame_semantic_transformer.data.chunk_list import chunk_list
+from frame_semantic_transformer.data.data_utils import chunk_list
 from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample
 from frame_semantic_transformer.predict import batch_predict
 
diff --git a/tests/data/task_samples/test_FrameClassificationSample.py b/tests/data/task_samples/test_FrameClassificationSample.py
index 7b4ad96..f52c47d 100644
--- a/tests/data/task_samples/test_FrameClassificationSample.py
+++ b/tests/data/task_samples/test_FrameClassificationSample.py
@@ -13,9 +13,7 @@
 
 
 def test_get_input() -> None:
-    expected = (
-        "FRAME: Your * contribution * to Goodwill will mean more than you may know ."
-    )
+    expected = "FRAME: Your *contribution to Goodwill will mean more than you may know."
     assert sample.get_input() == expected
 
 
diff --git a/tests/data/task_samples/test_TriggerIdentificationSample.py b/tests/data/task_samples/test_TriggerIdentificationSample.py
index 92a717d..36958f3 100644
--- a/tests/data/task_samples/test_TriggerIdentificationSample.py
+++ b/tests/data/task_samples/test_TriggerIdentificationSample.py
@@ -6,38 +6,38 @@
 
 
 sample = TriggerIdentificationSample(
-    text="Your contribution to Goodwill will mean more than you may know .",
+    text="Your contribution to Goodwill will mean more than you may know.",
     trigger_locs=[5, 18, 35, 40, 58, 54],
 )
 
 
 def test_get_input() -> None:
     expected = (
-        "TRIGGER: Your contribution to Goodwill will mean more than you may know ."
+        "TRIGGER: Your contribution to Goodwill will mean more than you may know."
     )
     assert sample.get_input() == expected
 
 
 def test_get_target() -> None:
-    expected = "Your *contribution *to Goodwill will *mean *more than you *may *know ."
+    expected = "Your *contribution *to Goodwill will *mean *more than you *may *know."
     assert sample.get_target() == expected
 
 
 def test_evaluate_prediction() -> None:
-    pred = "Your contribution *to Goodwill *will *mean *more than you may *know ."
+    pred = "Your contribution *to Goodwill *will *mean *more than you may *know."
     assert sample.evaluate_prediction(pred) == (4, 1, 2)
 
 
 def test_evaluate_prediction_fails_for_elements_whose_content_doesnt_match() -> None:
-    pred = "Your AHAHAHAHA *to BADWILL will *PSYCH *more than you may *know ."
+    pred = "Your AHAHAHAHA *to BADWILL will *PSYCH *more than you may *know."
     assert sample.evaluate_prediction(pred) == (3, 1, 3)
 
 
 def test_evaluate_prediction_treats_missing_words_as_wrong() -> None:
     pred = "Your *contribution *to Goodwill will *mean"
-    assert sample.evaluate_prediction(pred) == (3, 3, 3)
+    assert sample.evaluate_prediction(pred) == (3, 2, 3)
 
 
 def test_evaluate_prediction_treats_excess_words_as_false_positives() -> None:
-    pred = "Your *contribution *to Goodwill will *mean *more than you *may *know . ha ha ha ha!"
+    pred = "Your *contribution *to Goodwill will *mean *more than you *may *know. ha ha ha ha!"
     assert sample.evaluate_prediction(pred) == (6, 4, 0)
diff --git a/tests/data/test_data_utils.py b/tests/data/test_data_utils.py
new file mode 100644
index 0000000..a17a819
--- /dev/null
+++ b/tests/data/test_data_utils.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from frame_semantic_transformer.data.data_utils import standardize_punct
+
+
+def test_standardize_punct_removes_spaces_before_punctuation() -> None:
+    original = "Old customs are still followed : Fate and luck are taken very seriously , and astrologers and fortune-tellers do a steady business ."
+    expected = "Old customs are still followed: Fate and luck are taken very seriously, and astrologers and fortune-tellers do a steady business."
+    assert standardize_punct(original) == expected
+
+
+def test_standardize_punct_leaves_sentences_as_is_if_punct_is_correct() -> None:
+    sent = "Old customs are still followed: Fate and luck are taken very seriously, and astrologers and fortune-tellers do a steady business."
+    assert standardize_punct(sent) == sent
+
+
+def test_standardize_punct_leaves_spaces_before_double_apostrophes() -> None:
+    sent = "I really *like my *job. '' -- Sherry"
+    assert standardize_punct(sent) == sent
+
+
+def test_standardize_punct_keeps_asterix_before_apostrophes() -> None:
+    original = "*Shopping *never *ends - *there *'s *always *another inviting *spot"
+    expected = "*Shopping *never *ends - *there*'s *always *another inviting *spot"
+    assert standardize_punct(original) == expected
+
+
+def test_standardize_punct_removes_repeated_asterixes() -> None:
+    original = "*Shopping **never *ends"
+    expected = "*Shopping *never *ends"
+    assert standardize_punct(original) == expected
+
+
+def test_standardize_punct_undoes_spaces_in_contractions() -> None:
+    original = "She did n't say so"
+    expected = "She didn't say so"
+    assert standardize_punct(original) == expected
+
+
+def test_standardize_punct_allows_asterix_in_contractions() -> None:
+    original = "She did *n't say so"
+    expected = "She did*n't say so"
+    assert standardize_punct(original) == expected
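
Note (not part of the patch itself): a minimal sketch of what the new standardize_punct step does, with inputs adapted from the function's docstring and the tests added in tests/data/test_data_utils.py above:

    from frame_semantic_transformer.data.data_utils import standardize_punct

    # spaces before punctuation and apostrophes are removed
    assert standardize_punct("He 's a man .") == "He's a man."
    # repeated asterisks collapse to one, and an asterisk may sit inside a contraction
    assert standardize_punct("She did **n't say so") == "She did*n't say so"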