From ecde0b2005c3f7a30698fa88528a3d41cd34e6cc Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Mon, 9 May 2022 21:35:50 +0100
Subject: [PATCH] adding in task mix balancing

---
 .../data/TaskSampleDataset.py        | 35 +++++++++++++++-
 frame_semantic_transformer/train.py  | 13 ++++--
 tests/data/test_TaskSampleDataset.py | 42 ++++++++++++++++++-
 3 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/frame_semantic_transformer/data/TaskSampleDataset.py b/frame_semantic_transformer/data/TaskSampleDataset.py
index 5bc40fb..c042d64 100644
--- a/frame_semantic_transformer/data/TaskSampleDataset.py
+++ b/frame_semantic_transformer/data/TaskSampleDataset.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+from collections import defaultdict
+import random
 from typing import Any, Sequence
 import torch
 from torch.utils.data import Dataset
@@ -17,8 +19,17 @@ class TaskSampleDataset(Dataset[Any]):
     labels: torch.Tensor
     samples: Sequence[TaskSample]
 
-    def __init__(self, samples: Sequence[TaskSample], tokenizer: T5Tokenizer):
-        input_ids, attention_mask, labels = parse_samples(samples, tokenizer)
+    def __init__(
+        self,
+        samples: Sequence[TaskSample],
+        tokenizer: T5Tokenizer,
+        balance_tasks: bool = False,
+        seed: int = 42,
+    ):
+        samples_to_parse = samples
+        if balance_tasks:
+            samples_to_parse = balance_tasks_by_type(samples, seed)
+        input_ids, attention_mask, labels = parse_samples(samples_to_parse, tokenizer)
         self.input_ids = input_ids
         self.attention_mask = attention_mask
         self.labels = labels
@@ -35,6 +46,26 @@ def __getitem__(self, index: int) -> dict[str, Any]:
         }
 
 
+def balance_tasks_by_type(
+    samples: Sequence[TaskSample], seed: int
+) -> Sequence[TaskSample]:
+    """
+    try to force an approximate balance of task types by repeating tasks of uncommon types
+    """
+    counts_by_type: dict[str, int] = defaultdict(int)
+    for sample in samples:
+        counts_by_type[sample.get_task_name()] += 1
+    max_task_count = max(counts_by_type.values())
+    balanced_samples: list[TaskSample] = []
+    for sample in samples:
+        sample_ratio = int(max_task_count / counts_by_type[sample.get_task_name()])
+        # duplicate each sample in proportion to how few tasks of this type are in the original mix
+        for _ in range(sample_ratio):
+            balanced_samples.append(sample)
+    random.Random(seed).shuffle(balanced_samples)
+    return balanced_samples
+
+
 def parse_samples(
     samples: Sequence[TaskSample], tokenizer: T5Tokenizer
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
diff --git a/frame_semantic_transformer/train.py b/frame_semantic_transformer/train.py
index 76e780c..6d1ee0d 100644
--- a/frame_semantic_transformer/train.py
+++ b/frame_semantic_transformer/train.py
@@ -175,6 +175,7 @@ def train(
     lr: float = 1e-4,
     num_workers: int = DEFAULT_NUM_WORKERS,
     save_only_last_epoch: bool = False,
+    balance_tasks: bool = True,
 ) -> tuple[T5ForConditionalGeneration, T5Tokenizer]:
     device = torch.device("cuda" if use_gpu else "cpu")
     logging.info("loading base T5 model")
@@ -182,9 +183,15 @@ def train(
     tokenizer = T5Tokenizer.from_pretrained(base_model)
 
     logging.info("loading train/test/val datasets")
-    train_dataset = TaskSampleDataset(load_sesame_train_samples(), tokenizer)
-    val_dataset = TaskSampleDataset(load_sesame_dev_samples(), tokenizer)
-    test_dataset = TaskSampleDataset(load_sesame_test_samples(), tokenizer)
+    train_dataset = TaskSampleDataset(
+        load_sesame_train_samples(), tokenizer, balance_tasks=balance_tasks
+    )
+    val_dataset = TaskSampleDataset(
+        load_sesame_dev_samples(), tokenizer, balance_tasks=balance_tasks
+    )
+    test_dataset = TaskSampleDataset(
+        load_sesame_test_samples(), tokenizer, balance_tasks=balance_tasks
+    )
 
     data_module = TrainDataModule(
         train_dataset=train_dataset,
diff --git a/tests/data/test_TaskSampleDataset.py b/tests/data/test_TaskSampleDataset.py
index c49d453..a3eff27 100644
--- a/tests/data/test_TaskSampleDataset.py
+++ b/tests/data/test_TaskSampleDataset.py
@@ -1,11 +1,23 @@
 from __future__ import annotations
 from transformers import T5Tokenizer
 
-from frame_semantic_transformer.data.TaskSampleDataset import TaskSampleDataset
+from frame_semantic_transformer.data.TaskSampleDataset import (
+    TaskSampleDataset,
+    balance_tasks_by_type,
+)
 from frame_semantic_transformer.data.framenet import get_fulltext_docs
 from frame_semantic_transformer.data.load_framenet_samples import (
     parse_samples_from_fulltext_doc,
 )
+from frame_semantic_transformer.data.task_samples.ArgumentsExtractionSample import (
+    ArgumentsExtractionSample,
+)
+from frame_semantic_transformer.data.task_samples.FrameClassificationSample import (
+    FrameClassificationSample,
+)
+from frame_semantic_transformer.data.task_samples.TriggerIdentificationSample import (
+    TriggerIdentificationSample,
+)
 
 
 def test_TaskSampleDataset() -> None:
@@ -21,3 +33,31 @@ def test_TaskSampleDataset() -> None:
     assert len(dataset[0]["input_ids"]) == 55
     assert len(dataset[0]["attention_mask"]) == 55
     assert len(dataset[0]["labels"]) == 30
+
+
+def test_balance_tasks_by_type() -> None:
+    tasks = [
+        ArgumentsExtractionSample("a1", (0, 2), "Greetings", []),
+        ArgumentsExtractionSample("a2", (0, 2), "Greetings", []),
+        ArgumentsExtractionSample("a3", (0, 2), "Greetings", []),
+        ArgumentsExtractionSample("a4", (0, 2), "Greetings", []),
+        ArgumentsExtractionSample("a5", (0, 2), "Greetings", []),
+        FrameClassificationSample("f1", (0, 2), "Greetings"),
+        FrameClassificationSample("f2", (0, 2), "Greetings"),
+        TriggerIdentificationSample("t1", []),
+    ]
+    balanced_tasks = balance_tasks_by_type(tasks, 42)
+    assert len(balanced_tasks) == 14  # 5 arg, 4 frame, 5 trigger
+    assert (
+        len([t for t in balanced_tasks if t.get_task_name() == "args_extraction"]) == 5
+    )
+    assert (
+        len([t for t in balanced_tasks if t.get_task_name() == "frame_classification"])
+        == 4
+    )
+    assert (
+        len(
+            [t for t in balanced_tasks if t.get_task_name() == "trigger_identification"]
+        )
+        == 5
+    )