adding in task mix balancing
chanind committed May 9, 2022
1 parent 12ccf38 commit ecde0b2
Showing 3 changed files with 84 additions and 6 deletions.
35 changes: 33 additions & 2 deletions frame_semantic_transformer/data/TaskSampleDataset.py
@@ -1,4 +1,6 @@
from __future__ import annotations
from collections import defaultdict
import random
from typing import Any, Sequence
import torch
from torch.utils.data import Dataset
@@ -17,8 +19,17 @@ class TaskSampleDataset(Dataset[Any]):
labels: torch.Tensor
samples: Sequence[TaskSample]

def __init__(self, samples: Sequence[TaskSample], tokenizer: T5Tokenizer):
input_ids, attention_mask, labels = parse_samples(samples, tokenizer)
def __init__(
self,
samples: Sequence[TaskSample],
tokenizer: T5Tokenizer,
balance_tasks: bool = False,
seed: int = 42,
):
samples_to_parse = samples
if balance_tasks:
samples_to_parse = balance_tasks_by_type(samples, seed)
input_ids, attention_mask, labels = parse_samples(samples_to_parse, tokenizer)
self.input_ids = input_ids
self.attention_mask = attention_mask
self.labels = labels
@@ -35,6 +46,26 @@ def __getitem__(self, index: int) -> dict[str, Any]:
}


def balance_tasks_by_type(
samples: Sequence[TaskSample], seed: int
) -> Sequence[TaskSample]:
"""
try to force an approximate balance of task types by repeating tasks of uncommon types
"""
counts_by_type: dict[str, int] = defaultdict(int)
for sample in samples:
counts_by_type[sample.get_task_name()] += 1
max_task_count = max(counts_by_type.values())
balanced_samples: list[TaskSample] = []
for sample in samples:
sample_ratio = int(max_task_count / counts_by_type[sample.get_task_name()])
# duplicate each sample in proportion to how few tasks of this type are in the original mix
for _ in range(sample_ratio):
balanced_samples.append(sample)
random.Random(seed).shuffle(balanced_samples)
return balanced_samples


def parse_samples(
samples: Sequence[TaskSample], tokenizer: T5Tokenizer
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
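
Not part of the commit, just context on the arithmetic: balance_tasks_by_type repeats each sample int(max_task_count / count_for_its_type) times, so the most common task type is left untouched and rarer types are duplicated until every type has roughly the same share. A minimal standalone sketch of that calculation, using made-up counts chosen to match the test fixture added below:

from collections import Counter

# Hypothetical per-type counts, mirroring the test fixture in tests/data/test_TaskSampleDataset.py.
counts = Counter(
    {"args_extraction": 5, "frame_classification": 2, "trigger_identification": 1}
)
max_task_count = max(counts.values())  # 5

# Each sample is repeated int(max_task_count / count) times for its task type.
ratios = {task: max_task_count // count for task, count in counts.items()}
# {"args_extraction": 1, "frame_classification": 2, "trigger_identification": 5}

balanced_totals = {task: count * ratios[task] for task, count in counts.items()}
# {"args_extraction": 5, "frame_classification": 4, "trigger_identification": 5} -> 14 samples

Because the ratio uses floor division, the balance is only approximate (frame_classification lands on 4 rather than 5 here), which is why the docstring promises an "approximate balance".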
13 changes: 10 additions & 3 deletions frame_semantic_transformer/train.py
@@ -175,16 +175,23 @@ def train(
lr: float = 1e-4,
num_workers: int = DEFAULT_NUM_WORKERS,
save_only_last_epoch: bool = False,
balance_tasks: bool = True,
) -> tuple[T5ForConditionalGeneration, T5Tokenizer]:
device = torch.device("cuda" if use_gpu else "cpu")
logging.info("loading base T5 model")
model = T5ForConditionalGeneration.from_pretrained(base_model).to(device)
tokenizer = T5Tokenizer.from_pretrained(base_model)

logging.info("loading train/test/val datasets")
train_dataset = TaskSampleDataset(load_sesame_train_samples(), tokenizer)
val_dataset = TaskSampleDataset(load_sesame_dev_samples(), tokenizer)
test_dataset = TaskSampleDataset(load_sesame_test_samples(), tokenizer)
train_dataset = TaskSampleDataset(
load_sesame_train_samples(), tokenizer, balance_tasks=balance_tasks
)
val_dataset = TaskSampleDataset(
load_sesame_dev_samples(), tokenizer, balance_tasks=balance_tasks
)
test_dataset = TaskSampleDataset(
load_sesame_test_samples(), tokenizer, balance_tasks=balance_tasks
)

data_module = TrainDataModule(
train_dataset=train_dataset,
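
A hedged usage sketch of the new flag in train.py; it assumes the train() parameters not visible in this hunk keep their defaults. Because the diff defaults balance_tasks to True, a caller only needs to pass it explicitly to opt out:

from frame_semantic_transformer.train import train

# Task balancing is on by default after this change (balance_tasks: bool = True);
# pass balance_tasks=False to train on the raw, unbalanced sample mix instead.
# All other arguments are assumed to have defaults and are left unset here.
model, tokenizer = train(balance_tasks=False)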
42 changes: 41 additions & 1 deletion tests/data/test_TaskSampleDataset.py
@@ -1,11 +1,23 @@
from __future__ import annotations
from transformers import T5Tokenizer
from frame_semantic_transformer.data.TaskSampleDataset import TaskSampleDataset
from frame_semantic_transformer.data.TaskSampleDataset import (
TaskSampleDataset,
balance_tasks_by_type,
)
from frame_semantic_transformer.data.framenet import get_fulltext_docs

from frame_semantic_transformer.data.load_framenet_samples import (
parse_samples_from_fulltext_doc,
)
from frame_semantic_transformer.data.task_samples.ArgumentsExtractionSample import (
ArgumentsExtractionSample,
)
from frame_semantic_transformer.data.task_samples.FrameClassificationSample import (
FrameClassificationSample,
)
from frame_semantic_transformer.data.task_samples.TriggerIdentificationSample import (
TriggerIdentificationSample,
)


def test_TaskSampleDataset() -> None:
@@ -21,3 +33,31 @@ def test_TaskSampleDataset() -> None:
assert len(dataset[0]["input_ids"]) == 55
assert len(dataset[0]["attention_mask"]) == 55
assert len(dataset[0]["labels"]) == 30


def test_balance_tasks_by_type() -> None:
tasks = [
ArgumentsExtractionSample("a1", (0, 2), "Greetings", []),
ArgumentsExtractionSample("a2", (0, 2), "Greetings", []),
ArgumentsExtractionSample("a3", (0, 2), "Greetings", []),
ArgumentsExtractionSample("a4", (0, 2), "Greetings", []),
ArgumentsExtractionSample("a5", (0, 2), "Greetings", []),
FrameClassificationSample("f1", (0, 2), "Greetings"),
FrameClassificationSample("f2", (0, 2), "Greetings"),
TriggerIdentificationSample("t1", []),
]
balanced_tasks = balance_tasks_by_type(tasks, 42)
assert len(balanced_tasks) == 14 # 5 arg, 4 frame, 5 trigger
assert (
len([t for t in balanced_tasks if t.get_task_name() == "args_extraction"]) == 5
)
assert (
len([t for t in balanced_tasks if t.get_task_name() == "frame_classification"])
== 4
)
assert (
len(
[t for t in balanced_tasks if t.get_task_name() == "trigger_identification"]
)
== 5
)
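
The test exercises balance_tasks_by_type directly; the same behavior can also be requested through the dataset constructor. A sketch under stated assumptions: the "t5-small" checkpoint is an illustrative choice, and whether these minimal fixture samples tokenize cleanly is not shown in the diff.

from transformers import T5Tokenizer
from frame_semantic_transformer.data.TaskSampleDataset import TaskSampleDataset
from frame_semantic_transformer.data.task_samples.ArgumentsExtractionSample import (
    ArgumentsExtractionSample,
)
from frame_semantic_transformer.data.task_samples.TriggerIdentificationSample import (
    TriggerIdentificationSample,
)

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # illustrative checkpoint choice
samples = [
    ArgumentsExtractionSample("a1", (0, 2), "Greetings", []),
    ArgumentsExtractionSample("a2", (0, 2), "Greetings", []),
    TriggerIdentificationSample("t1", []),
]
# With balance_tasks=True the lone trigger_identification sample is repeated twice
# to match the two args_extraction samples, giving four samples before tokenization.
dataset = TaskSampleDataset(samples, tokenizer, balance_tasks=True, seed=42)
item = dataset[0]  # dict with "input_ids", "attention_mask", and "labels" tensors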
