try adding in all possible frame elements into task intro for argument extraction
chanind committed May 14, 2022
1 parent 39376e5 commit 94e3c89
Showing 8 changed files with 81 additions and 14 deletions.
12 changes: 11 additions & 1 deletion frame_semantic_transformer/data/framenet.py
@@ -1,11 +1,15 @@
 from __future__ import annotations
 from functools import lru_cache
-from typing import Any, Sequence, Mapping
+from typing import Any, Iterable, Sequence, Mapping
 import nltk
 
 from nltk.corpus import framenet as fn
 
 
+class InvalidFrameError(Exception):
+    pass
+
+
 def ensure_framenet_downloaded() -> None:
     nltk.download("framenet_v17")
 

Expand All @@ -14,6 +18,12 @@ def is_valid_frame(frame: str) -> bool:
return frame in get_all_valid_frame_names()


def get_frame_element_names(frame: str) -> Iterable[str]:
if not is_valid_frame(frame):
raise InvalidFrameError(frame)
return fn.frame(frame).FE.keys()


@lru_cache(1)
def get_all_valid_frame_names() -> set[str]:
return {frame.name for frame in fn.frames()}
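For reference, a quick sketch of how the new helper behaves. This is not code from the commit — the frame and element names are taken from the updated test_get_input test further down, and the exact ordering depends on the NLTK FrameNet 1.7 data:

    from frame_semantic_transformer.data.framenet import (
        InvalidFrameError,
        get_frame_element_names,
    )

    # Known frame: returns the FE names defined in FrameNet, e.g. for "Giving":
    # Donor, Recipient, Theme, Place, Explanation, Time, Purpose, Means, Manner, ...
    elements = list(get_frame_element_names("Giving"))

    # Unknown frame: the new guard raises InvalidFrameError instead of failing inside nltk.
    try:
        get_frame_element_names("NotARealFrame")
    except InvalidFrameError:
        print("not a valid FrameNet frame")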
3 changes: 0 additions & 3 deletions frame_semantic_transformer/data/load_framenet_samples.py
@@ -63,9 +63,6 @@ def parse_frame_samples_from_annotation_set(
                 frame=annotation["frame"]["name"],
             )
         )
-        if annotation["FE"][1] != {}:
-            # I don't understand what the second part of this tuple is, just ignore it for now
-            continue
         sample_sentences.append(
             ArgumentsExtractionSample(
                 text=annotation["text"],
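A note on the deleted guard, since the old comment flagged the tuple as unclear: my reading of the NLTK FrameNet corpus reader (not stated anywhere in this commit) is that annotation["FE"] is a two-element tuple — the first entry lists the overtly realized frame element spans, and the second is a dict of null-instantiated frame elements (e.g. DNI/INI/CNI) that have no text span. A hypothetical sketch of that reading:

    # Assumed shape, based on the NLTK framenet corpus reader; not code from this repo.
    # annotation["FE"] == ([(start, end, "Donor"), ...], {"Theme": "DNI", ...})
    overt_fes, null_instantiated_fes = annotation["FE"]

With the continue removed, sentences whose annotations include null-instantiated FEs now also yield argument-extraction samples, which is consistent with the larger sample counts asserted in test_load_framenet_samples.py below.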
4 changes: 3 additions & 1 deletion frame_semantic_transformer/data/task_samples/ArgumentsExtractionSample.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
 from typing import Sequence
+from frame_semantic_transformer.data.framenet import get_frame_element_names
 
 from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample
 
@@ -18,7 +19,8 @@ def get_task_name(self) -> str:
         return "args_extraction"
 
     def get_input(self) -> str:
-        return f"ARGS {self.frame}: {self.trigger_labeled_text}"
+        elements = get_frame_element_names(self.frame)
+        return f"ARGS {self.frame} | {' '.join(elements)} : {self.trigger_labeled_text}"
 
     def get_target(self) -> str:
         return " | ".join(
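To make the format change concrete, the model input for an argument-extraction sample now looks like this (copied from the updated test_get_input below; the element list is FrameNet's inventory for the Giving frame, and the trigger word is still marked with asterisks):

    ARGS Giving | Donor Recipient Theme Place Explanation Time Purpose Means Manner Circumstances Imposed_purpose Depictive Period_of_iterations : Your * contribution * to Goodwill will mean more than you may know .

Presumably the point is to show the model which element labels are valid for the frame rather than making it memorize them per frame, at the cost of noticeably longer inputs (see the 55 -> 105 token-length change in test_TaskSampleDataset below).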
24 changes: 21 additions & 3 deletions frame_semantic_transformer/train.py
@@ -123,7 +123,13 @@ def training_step(self, batch: Any, _batch_idx: int) -> Any:  # type: ignore
         output = self._step(batch)
         loss = output.loss
         self.log(
-            "train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
+            "train_loss",
+            loss,
+            prog_bar=True,
+            logger=True,
+            on_epoch=True,
+            on_step=True,
+            batch_size=len(batch["input_ids"]),
         )
         return loss
 
@@ -132,15 +138,27 @@ def validation_step(self, batch: Any, _batch_idx: int) -> Any:  # type: ignore
         loss = output.loss
         metrics = evaluate_batch(self.model, self.tokenizer, batch)
         self.log(
-            "val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
+            "val_loss",
+            loss,
+            prog_bar=True,
+            logger=True,
+            on_epoch=True,
+            on_step=True,
+            batch_size=len(batch["input_ids"]),
         )
         return {"loss": loss, "metrics": metrics}
 
     def test_step(self, batch: Any, _batch_idx: int) -> Any:  # type: ignore
         output = self._step(batch)
         loss = output.loss
         metrics = evaluate_batch(self.model, self.tokenizer, batch)
-        self.log("test_loss", loss, prog_bar=True, logger=True)
+        self.log(
+            "test_loss",
+            loss,
+            prog_bar=True,
+            logger=True,
+            batch_size=len(batch["input_ids"]),
+        )
         return {"loss": loss, "metrics": metrics}
 
     def configure_optimizers(self) -> AdamW:
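A side note on the new batch_size argument, since the diff doesn't explain it: when a batch is a dict (as here), PyTorch Lightning cannot reliably infer the batch size that self.log should use for epoch-level averaging and emits a warning, so passing it explicitly makes the weighting deterministic. A minimal sketch of the same pattern outside this repo (TinyModule and its linear layer are made up for illustration; assumes pytorch_lightning >= 1.5, where self.log accepts batch_size):

    import torch
    import pytorch_lightning as pl

    class TinyModule(pl.LightningModule):
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            # batch is a dict, so Lightning can't infer the batch size on its own;
            # pass it explicitly so the epoch-level average is weighted correctly.
            loss = self.layer(batch["input_ids"].float()).mean()
            self.log("train_loss", loss, on_step=True, on_epoch=True,
                     batch_size=len(batch["input_ids"]))
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)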
39 changes: 39 additions & 0 deletions tests/data/__snapshots__/test_load_framenet_samples.ambr

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion tests/data/task_samples/test_ArgumentsExtractionSample.py
@@ -14,7 +14,8 @@
 
 
 def test_get_input() -> None:
-    expected = "ARGS Giving: Your * contribution * to Goodwill will mean more than you may know ."
+    elements = "Donor Recipient Theme Place Explanation Time Purpose Means Manner Circumstances Imposed_purpose Depictive Period_of_iterations"
+    expected = f"ARGS Giving | {elements} : Your * contribution * to Goodwill will mean more than you may know ."
     assert sample.get_input() == expected
 
 
4 changes: 2 additions & 2 deletions tests/data/test_TaskSampleDataset.py
@@ -30,8 +30,8 @@ def test_TaskSampleDataset() -> None:
     dataset = TaskSampleDataset(samples, tokenizer)
 
     assert len(dataset) == 8
-    assert len(dataset[0]["input_ids"]) == 55
-    assert len(dataset[0]["attention_mask"]) == 55
+    assert len(dataset[0]["input_ids"]) == 105
+    assert len(dataset[0]["attention_mask"]) == 105
     assert len(dataset[0]["labels"]) == 30
 
 
6 changes: 3 additions & 3 deletions tests/data/test_load_framenet_samples.py
@@ -12,12 +12,12 @@
 
 def test_load_sesame_test_samples() -> None:
     samples = load_sesame_test_samples()
-    assert len(samples) == 13510
+    assert len(samples) == 15126
 
 
 def test_load_sesame_dev_samples() -> None:
     samples = load_sesame_dev_samples()
-    assert len(samples) == 4613
+    assert len(samples) == 5166
 
 
 def test_load_sesame_train_samples() -> None:
@@ -32,7 +32,7 @@ def test_load_sesame_train_samples() -> None:
     ]
     assert len(trigger_id_samples) == 3425
     assert len(frame_id_samples) == 20597
-    assert len(samples) == 40233
+    assert len(samples) == 44619
 
 
 def test_parse_samples_from_fulltext_doc(snapshot: SnapshotAssertion) -> None:
