feat: Propbank (#16)
* setting up propbank loaders

* skip light verbs for now

* trying a different approach to avoid lv verbs

* fixing typo

* just ignore non-existent frames

* use similar lu norm to framenet

* adding option to resume from checkpoint

* switching to propbank 3.4 instead of 3.1

* fixing propbank nltk paths

* removing debugging prints

* fixing test

* adding optional LR decay
chanind authored Mar 9, 2023
1 parent 40983ff commit 4c53887
Showing 14 changed files with 481 additions and 4 deletions.
5 changes: 4 additions & 1 deletion frame_semantic_transformer/data/LoaderDataCache.py
@@ -58,7 +58,10 @@ def standardize_element_name(self, name: str) -> str | None:
         """
         Standardize a frame element name
         """
-        return self.get_frame_element_name_loopkup().get(normalize_name(name))
+        norm_name = normalize_name(name)
+        if norm_name not in self.get_frame_element_name_loopkup():
+            return None if self.loader.strict_frame_elements() else name
+        return self.get_frame_element_name_loopkup()[norm_name]

     @lru_cache(1)
     def get_lexical_unit_bigram_to_frame_lookup_map(self) -> dict[str, list[str]]:
6 changes: 6 additions & 0 deletions frame_semantic_transformer/data/loaders/loader.py
@@ -19,6 +19,12 @@ def name(self) -> str:
         """
         return self.__class__.__name__

+    def strict_frame_elements(self) -> bool:
+        """
+        Return whether the loader strips out frame elements not in the frame definition.
+        """
+        return True
+
     def setup(self) -> None:
         """
         Perform any setup required, e.g. downloading needed data
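Taken together with the LoaderDataCache change above, this lets a loader opt out of strict frame-element filtering: unknown element names are now passed through unchanged when strict_frame_elements() returns False, instead of always being dropped to None. A pure-logic sketch of the new behaviour (illustrative only, not the project's API; normalize_name is replaced with a simple lowercase here):

from __future__ import annotations

def standardize(name: str, lookup: dict[str, str], strict: bool) -> str | None:
    # Stand-in for LoaderDataCache.standardize_element_name
    key = name.lower()  # stand-in for normalize_name()
    if key not in lookup:
        return None if strict else name  # non-strict loaders keep unknown names
    return lookup[key]

lookup = {"agent": "Agent"}
print(standardize("AGENT", lookup, strict=True))      # "Agent"
print(standardize("ARG2-MNR", lookup, strict=True))   # None: strict loaders drop unknown elements
print(standardize("ARG2-MNR", lookup, strict=False))  # "ARG2-MNR": PropBank-style loaders keep them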
Propbank34InferenceLoader.py (new file)
@@ -0,0 +1,54 @@
from __future__ import annotations
import re

from nltk.stem import PorterStemmer

from .load_propbank_frames import load_propbank_frames

from .ensure_propbank_downloaded import ensure_propbank_downloaded


from frame_semantic_transformer.data.frame_types import Frame
from ..loader import InferenceLoader


base_stemmer = PorterStemmer()

LOW_PRIORITY_LONGER_LUS = {"back", "down", "make", "take", "have", "into", "come"}


class Propbank34InferenceLoader(InferenceLoader):
    """
    Inference loader for Propbank 3.4 data
    """

    def setup(self) -> None:
        ensure_propbank_downloaded()

    def strict_frame_elements(self) -> bool:
        """
        Propbank only lists core roles, not all roles, so we can't enforce strict frame elements
        """
        return False

    def load_frames(self) -> list[Frame]:
        """
        Load the full list of frames to be used during inference
        """
        return load_propbank_frames()

    def normalize_lexical_unit_text(self, lu: str) -> str:
        """
        Normalize a lexical unit like "takes.v" to "take".
        """
        normalized_lu = lu.lower().replace("_", " ")
        normalized_lu = re.sub(r"\.[a-zA-Z]+$", "", normalized_lu)
        normalized_lu = re.sub(r"[^a-z0-9 ]", "", normalized_lu)
        return base_stemmer.stem(normalized_lu.strip())

    def prioritize_lexical_unit(self, lu: str) -> bool:
        """
        Check if the lexical unit is relatively rare, so that it should be considered "high information"
        """
        norm_lu = self.normalize_lexical_unit_text(lu)
        return len(norm_lu) >= 4 and norm_lu not in LOW_PRIORITY_LONGER_LUS
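For a quick sense of what this lexical-unit normalization produces, here is a self-contained sketch that mirrors the method body above (requires nltk; the printed values are what PorterStemmer yields):

import re
from nltk.stem import PorterStemmer

stemmer = PorterStemmer()

def normalize_lu(lu: str) -> str:
    # Mirrors Propbank34InferenceLoader.normalize_lexical_unit_text:
    # lowercase, drop the trailing POS tag (".v", ".n", ...), strip punctuation, then stem.
    norm = lu.lower().replace("_", " ")
    norm = re.sub(r"\.[a-zA-Z]+$", "", norm)
    norm = re.sub(r"[^a-z0-9 ]", "", norm)
    return stemmer.stem(norm.strip())

print(normalize_lu("takes.v"))    # "take"
print(normalize_lu("Running.v"))  # "run"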
Propbank34TrainingLoader.py (new file)
@@ -0,0 +1,161 @@
from __future__ import annotations
from collections import defaultdict

from os import path
from glob import glob
import re

from nltk.corpus.reader.conll import ConllCorpusReader

from frame_semantic_transformer.data.augmentations import (
    LowercaseAugmentation,
    RemoveContractionsAugmentation,
    RemoveEndPunctuationAugmentation,
)
from frame_semantic_transformer.data.augmentations.DataAugmentation import (
    DataAugmentation,
)

from frame_semantic_transformer.data.frame_types import (
    FrameAnnotatedSentence,
    FrameAnnotation,
    FrameElementAnnotation,
)
from ..loader import TrainingLoader
from .load_propbank_frames import load_propbank_frames


SPLITS = {
    "train": [
        "docs/evaluation/ewt.dev.txt",
        "docs/evaluation/ontonotes-train-list.txt",
    ],
    "val": ["docs/evaluation/ewt.dev.txt", "docs/evaluation/ontonotes-dev-list.txt"],
    "test": ["docs/evaluation/ewt.test.txt", "docs/evaluation/ontonotes-test-list.txt"],
}
EWT_GLOB = "data/google/ewt/**/*.gold_conll"
ONTONOTES_GLOB = "data/ontonotes/**/*.gold_conll"


def load_docs_set(base_path: str, docs_list_paths: list[str]) -> list[str]:
    docs_lookup = set()
    for docs_list_path in docs_list_paths:
        with open(path.join(base_path, docs_list_path)) as f:
            raw_docs = f.read().splitlines()
            # weirdly the ewt dev files end in .conllu but nothing else does
            docs_lookup.update([doc.replace(".conllu", "") for doc in raw_docs])

    docs = []
    ewt_docs = glob(path.join(base_path, EWT_GLOB), recursive=True)
    for doc in ewt_docs:
        doc_base = re.sub(r".*/data/google/ewt/", "", doc).replace(".gold_conll", "")
        if doc_base in docs_lookup:
            docs.append(doc)

    ontonotes_docs = glob(path.join(base_path, ONTONOTES_GLOB), recursive=True)
    for doc in ontonotes_docs:
        # for some reason the ontonotes list has 'ontonotes' in the path, but ewt doesn't have 'google/ewt'
        doc_base = re.sub(r".*/data/ontonotes/", "ontonotes/", doc).replace(
            ".gold_conll", ""
        )
        if doc_base in docs_lookup:
            docs.append(doc)
    return docs


def conll_word_index_to_locs(words: list[str], word_index: int) -> tuple[int, int]:
    """
    Take a list of words and an index of a word and return the start and end char indices of the word in the sentence
    """
    start_loc = 0
    for i, word in enumerate(words):
        if i == word_index:
            return start_loc, start_loc + len(word)
        start_loc += len(word) + 1
    raise ValueError("word index out of range")


def load_propbank_samples(
    docs_list: list[str], valid_frames: set[str]
) -> list[FrameAnnotatedSentence]:
    """
    Parse each of the propbank ontonotes and ewt gold conll files and return a list of FrameAnnotatedSentence objects
    """
    annotated_sentences = []
    for doc in docs_list:
        conll_reader = ConllCorpusReader(
            path.dirname(doc),
            path.basename(doc),
            ("ignore", "ignore", "ignore", "words", "pos", "tree", "srl"),
        )
        sents_map = defaultdict(list)
        for srl_instance in conll_reader.srl_instances():
            words = [word[0] for word in srl_instance.words]
            sentence = " ".join(words)
            frame_name = srl_instance.verb_stem
            if frame_name.lower() not in valid_frames:
                continue
            trigger_locs = [
                conll_word_index_to_locs(words, index)[0] for index in srl_instance.verb
            ]

            frame_elements = []
            for argument in srl_instance.arguments:
                words_range, frame_element_name = argument
                element_start_loc = conll_word_index_to_locs(words, words_range[0])[0]
                element_end_loc = conll_word_index_to_locs(words, words_range[1] - 1)[1]
                frame_elements.append(
                    FrameElementAnnotation(
                        frame_element_name, element_start_loc, element_end_loc
                    )
                )
            sents_map[sentence].append(
                FrameAnnotation(frame_name, trigger_locs, frame_elements)
            )

        for sentence, frame_annotations in sents_map.items():
            annotated_sentences.append(
                FrameAnnotatedSentence(sentence, frame_annotations)
            )
    return annotated_sentences


class Propbank34TrainingLoader(TrainingLoader):
    """
    This loader uses the OntoNotes and EWT data from PropBank 3.4 to train a model.
    You must clone https://github.com/propbank/propbank-release and set propbank_release_dir to the path of the cloned repo.
    You must also obtain the LDC data for OntoNotes and EWT and run map_all_to_conll.py as described in the propbank-release repo.
    Sadly, this data isn't free, so you'll need to get it yourself before working with this loader.
    """

    propbank_release_dir: str
    train_docs: list[str] = []
    val_docs: list[str] = []
    test_docs: list[str] = []
    valid_frames: set[str] = set()

    def __init__(self, propbank_release_dir: str) -> None:
        super().__init__()
        self.propbank_release_dir = propbank_release_dir

    def setup(self) -> None:
        self.valid_frames = {frame.name.lower() for frame in load_propbank_frames()}
        self.train_docs = load_docs_set(self.propbank_release_dir, SPLITS["train"])
        self.val_docs = load_docs_set(self.propbank_release_dir, SPLITS["val"])
        self.test_docs = load_docs_set(self.propbank_release_dir, SPLITS["test"])

    def get_augmentations(self) -> list[DataAugmentation]:
        return [
            RemoveEndPunctuationAugmentation(0.3),
            LowercaseAugmentation(0.2),
            RemoveContractionsAugmentation(0.2),
        ]

    def load_training_data(self) -> list[FrameAnnotatedSentence]:
        return load_propbank_samples(self.train_docs, self.valid_frames)

    def load_test_data(self) -> list[FrameAnnotatedSentence]:
        return load_propbank_samples(self.test_docs, self.valid_frames)

    def load_validation_data(self) -> list[FrameAnnotatedSentence]:
        return load_propbank_samples(self.val_docs, self.valid_frames)
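Because the frame-element offsets above must line up with the space-joined sentence, a small worked example of conll_word_index_to_locs may help (the helper is restated here rather than imported, since the new module's package path isn't shown in this diff):

from __future__ import annotations

def conll_word_index_to_locs(words: list[str], word_index: int) -> tuple[int, int]:
    # Same logic as the loader's helper: each word adds len(word) + 1 (for the space).
    start_loc = 0
    for i, word in enumerate(words):
        if i == word_index:
            return start_loc, start_loc + len(word)
        start_loc += len(word) + 1
    raise ValueError("word index out of range")

words = ["John", "gave", "Mary", "a", "book"]
sentence = " ".join(words)               # "John gave Mary a book"
start, end = conll_word_index_to_locs(words, 1)
print((start, end))                      # (5, 9)
print(sentence[start:end])               # "gave"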
__init__.py (new file)
@@ -0,0 +1,8 @@
from .Propbank34InferenceLoader import Propbank34InferenceLoader
from .Propbank34TrainingLoader import Propbank34TrainingLoader


__all__ = [
"Propbank34InferenceLoader",
"Propbank34TrainingLoader",
]
ensure_propbank_downloaded.py (new file)
@@ -0,0 +1,21 @@
import nltk
from nltk.downloader import Package


# NLTK only has v1.0 of PropBank, so hackily create an NLTK package and download v3.4
propbank34 = Package(
    id="propbank-frames-3.4.0",
    url="https://github.com/propbank/propbank-frames/archive/refs/tags/v3.4.0.zip",
    name="Proposition Bank Corpus 3.4",
    checksum="e563f8c9912d53ed7e709455746875e5",
    subdir="corpora",
    size=9484561,
    unzipped_size=29870379,
)


def ensure_propbank_downloaded() -> None:
    try:
        nltk.data.find("corpora/propbank-frames-3.4.0.zip")
    except LookupError:
        nltk.download(propbank34)
load_propbank_frames.py (new file)
@@ -0,0 +1,37 @@
from __future__ import annotations

import nltk
from glob import glob
from os import path
from xml.etree import ElementTree


from frame_semantic_transformer.data.frame_types import Frame


def load_propbank_frames() -> list[Frame]:
    """
    Load the full list of frames to be used during inference
    """
    dataset_path = nltk.data.find("corpora/propbank-frames-3.4.0").path
    frames_paths = glob(path.join(dataset_path, "frames", "*.xml"))
    frames = []
    for frame_path in frames_paths:
        with open(frame_path, "r") as frame_file:
            etree = ElementTree.parse(frame_file).getroot()
            raw_frames = etree.findall("predicate/roleset")
            for raw_frame in raw_frames:
                frame = Frame(
                    name=raw_frame.attrib["id"],
                    core_elements=[
                        f"ARG{role.attrib['n']}-{role.attrib['f']}"
                        for role in raw_frame.findall("roles/role")
                    ],
                    non_core_elements=[],
                    lexical_units=[
                        f"{alias.text}.{alias.attrib['pos']}"
                        for alias in raw_frame.findall("aliases/alias")
                    ],
                )
                frames.append(frame)
    return frames
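To make the XML traversal concrete, here is a sketch run against a hand-written roleset in the propbank-frames layout (the lemma, descriptions, and function tags below are illustrative, not copied from a real frame file):

from xml.etree import ElementTree

FRAME_XML = """
<frameset>
  <predicate lemma="take">
    <roleset id="take.01" name="take, acquire">
      <aliases>
        <alias pos="v">take</alias>
      </aliases>
      <roles>
        <role descr="taker" f="PAG" n="0"/>
        <role descr="thing taken" f="PPT" n="1"/>
      </roles>
    </roleset>
  </predicate>
</frameset>
"""

root = ElementTree.fromstring(FRAME_XML)
for roleset in root.findall("predicate/roleset"):
    name = roleset.attrib["id"]
    core = [
        f"ARG{role.attrib['n']}-{role.attrib['f']}"
        for role in roleset.findall("roles/role")
    ]
    lexical_units = [
        f"{alias.text}.{alias.attrib['pos']}"
        for alias in roleset.findall("aliases/alias")
    ]
    print(name, core, lexical_units)  # take.01 ['ARG0-PAG', 'ARG1-PPT'] ['take.v']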
10 changes: 8 additions & 2 deletions frame_semantic_transformer/training/TrainingModelWrapper.py
@@ -7,6 +7,7 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from torch.optim.lr_scheduler import ExponentialLR
 from transformers import AdamW, T5ForConditionalGeneration, T5TokenizerFast

 from frame_semantic_transformer.data.LoaderDataCache import LoaderDataCache
@@ -28,6 +29,7 @@ class TrainingModelWrapper(pl.LightningModule):
     skip_initial_epochs_validation: int
     loader_cache: LoaderDataCache
     val_metrics: dict[str, float] | None
+    lr_gamma: float
     log_eval_failures: bool

     def __init__(
@@ -39,6 +41,7 @@ def __init__(
         output_dir: str = "outputs",
         save_only_last_epoch: bool = False,
         skip_initial_epochs_validation: int = 0,
+        lr_gamma: float = 1.0,
         log_eval_failures: bool = False,
     ):
         super().__init__()
@@ -50,6 +53,7 @@
         self.save_only_last_epoch = save_only_last_epoch
         self.skip_initial_epochs_validation = skip_initial_epochs_validation
         self.val_metrics = None
+        self.lr_gamma = lr_gamma
         self.log_eval_failures = log_eval_failures

     def forward(
@@ -114,8 +118,10 @@ def test_step(self, batch: Any, _batch_idx: int) -> Any:  # type: ignore
         )
         return {"loss": loss, "metrics": metrics}

-    def configure_optimizers(self) -> AdamW:
-        return AdamW(self.parameters(), lr=self.lr)
+    def configure_optimizers(self) -> tuple[list[AdamW], list[ExponentialLR]]:
+        optimizer = AdamW(self.parameters(), lr=self.lr)
+        scheduler = ExponentialLR(optimizer, gamma=self.lr_gamma, verbose=True)
+        return [optimizer], [scheduler]

     def training_epoch_end(self, training_step_outputs: list[Any]) -> None:
         """save tokenizer and model on epoch end"""
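The new scheduler multiplies the learning rate by lr_gamma after every epoch, so the default of 1.0 keeps the old constant-LR behaviour. A minimal sketch of the decay, independent of the Lightning wrapper:

import torch
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(3):
    # ...optimizer.step() calls for the epoch would happen here...
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])  # ~9e-05, ~8.1e-05, ~7.29e-05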
7 changes: 6 additions & 1 deletion frame_semantic_transformer/training/train.py
@@ -43,6 +43,7 @@ def train(
     early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping feature
     precision: Union[Literal[64, 32, 16], Literal["64", "32", "16", "bf16"]] = 32,
     lr: float = 1e-4,
+    lr_gamma: float = 1.0,
     num_workers: int = DEFAULT_NUM_WORKERS,
     save_only_last_epoch: bool = False,
     balance_tasks: bool = True,
@@ -52,6 +53,7 @@
     training_loader: Optional[TrainingLoader] = None,
     pl_callbacks: Optional[list[Callback]] = None,
     pl_loggers: Optional[list[Logger]] = None,
+    resume_from_checkpoint: Optional[str] = None,
 ) -> tuple[T5ForConditionalGeneration, T5TokenizerFast]:
     device = torch.device("cuda" if use_gpu else "cpu")
     logger.info("loading base T5 model")
@@ -106,6 +108,7 @@
         model,
         tokenizer,
         lr=lr,
+        lr_gamma=lr_gamma,
         output_dir=output_dir,
         save_only_last_epoch=save_only_last_epoch,
         skip_initial_epochs_validation=skip_initial_epochs_validation,
@@ -140,10 +143,12 @@ def train(
     trainer = pl.Trainer(
         callbacks=callbacks,
         max_epochs=max_epochs,
-        gpus=1 if use_gpu else 0,
+        accelerator="gpu" if use_gpu else "cpu",
+        devices=1 if use_gpu else "auto",
         precision=precision,
         log_every_n_steps=1,
         logger=pl_loggers or True,
+        resume_from_checkpoint=resume_from_checkpoint,
     )

     logger.info("beginning training")
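A hedged usage sketch of the new training options. The training_loader, lr, lr_gamma, and resume_from_checkpoint parameters appear in the signature above; use_gpu is referenced in the body and assumed to be a parameter; the import path of the PropBank loaders package is a guess based on the new files in this commit:

from frame_semantic_transformer.training.train import train
from frame_semantic_transformer.data.loaders.propbank import (  # package path is an assumption
    Propbank34TrainingLoader,
)

model, tokenizer = train(
    training_loader=Propbank34TrainingLoader(
        propbank_release_dir="/path/to/propbank-release"
    ),
    lr=1e-4,
    lr_gamma=0.95,                 # decay the LR by 5% after each epoch
    resume_from_checkpoint=None,   # or a path to a Lightning .ckpt file to resume from
    use_gpu=True,
)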
