Commit
* setting up propbank loaders
* skip light verbs for now
* trying a different approach to avoid lv verbs
* fixing typo
* just ignore non-existent frames
* use similar lu norm to framenet
* adding option to resume from checkpoint
* switching to propbank 3.4 instead of 3.1
* fixing propbank nltk paths
* removing debugging prints
* fixing test
* adding optional LR decay
Showing 14 changed files with 481 additions and 4 deletions.
frame_semantic_transformer/data/loaders/propbank34/Propbank34InferenceLoader.py (54 additions, 0 deletions)
@@ -0,0 +1,54 @@
from __future__ import annotations
import re

from nltk.stem import PorterStemmer

from .load_propbank_frames import load_propbank_frames
from .ensure_propbank_downloaded import ensure_propbank_downloaded


from frame_semantic_transformer.data.frame_types import Frame
from ..loader import InferenceLoader


base_stemmer = PorterStemmer()

LOW_PRIORITY_LONGER_LUS = {"back", "down", "make", "take", "have", "into", "come"}

class Propbank34InferenceLoader(InferenceLoader):
    """
    Inference loader for PropBank 3.4 data
    """

    def setup(self) -> None:
        ensure_propbank_downloaded()

    def strict_frame_elements(self) -> bool:
        """
        Propbank only lists core roles, not all roles, so we can't enforce strict frame elements
        """
        return False

    def load_frames(self) -> list[Frame]:
        """
        Load the full list of frames to be used during inference
        """
        return load_propbank_frames()

    def normalize_lexical_unit_text(self, lu: str) -> str:
        """
        Normalize a lexical unit like "takes.v" to "take".
        """
        normalized_lu = lu.lower().replace("_", " ")
        normalized_lu = re.sub(r"\.[a-zA-Z]+$", "", normalized_lu)
        normalized_lu = re.sub(r"[^a-z0-9 ]", "", normalized_lu)
        return base_stemmer.stem(normalized_lu.strip())

    def prioritize_lexical_unit(self, lu: str) -> bool:
        """
        Check if the lexical unit is relatively rare, so that it should be considered "high information"
        """
        norm_lu = self.normalize_lexical_unit_text(lu)
        return len(norm_lu) >= 4 and norm_lu not in LOW_PRIORITY_LONGER_LUS
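A minimal usage sketch (editor's addition, not part of the commit) of the normalization logic above. It assumes the base InferenceLoader can be constructed with no arguments, since the subclass defines no __init__; normalization itself doesn't touch the PropBank download, so setup() is skipped here.

from frame_semantic_transformer.data.loaders.propbank34 import (
    Propbank34InferenceLoader,
)

loader = Propbank34InferenceLoader()
print(loader.normalize_lexical_unit_text("takes.v"))  # "take"
print(loader.normalize_lexical_unit_text("Running.v"))  # "run"
print(loader.prioritize_lexical_unit("take.v"))  # False: "take" is in LOW_PRIORITY_LONGER_LUS
print(loader.prioritize_lexical_unit("negotiate.v"))  # True: stem "negoti" is long and not low priority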
frame_semantic_transformer/data/loaders/propbank34/Propbank34TrainingLoader.py (161 additions, 0 deletions)
@@ -0,0 +1,161 @@
from __future__ import annotations
from collections import defaultdict

from os import path
from glob import glob
import re

from nltk.corpus.reader.conll import ConllCorpusReader

from frame_semantic_transformer.data.augmentations import (
    LowercaseAugmentation,
    RemoveContractionsAugmentation,
    RemoveEndPunctuationAugmentation,
)
from frame_semantic_transformer.data.augmentations.DataAugmentation import (
    DataAugmentation,
)

from frame_semantic_transformer.data.frame_types import (
    FrameAnnotatedSentence,
    FrameAnnotation,
    FrameElementAnnotation,
)
from ..loader import TrainingLoader
from .load_propbank_frames import load_propbank_frames


SPLITS = {
    "train": [
        "docs/evaluation/ewt.train.txt",
        "docs/evaluation/ontonotes-train-list.txt",
    ],
    "val": ["docs/evaluation/ewt.dev.txt", "docs/evaluation/ontonotes-dev-list.txt"],
    "test": ["docs/evaluation/ewt.test.txt", "docs/evaluation/ontonotes-test-list.txt"],
}
EWT_GLOB = "data/google/ewt/**/*.gold_conll"
ONTONOTES_GLOB = "data/ontonotes/**/*.gold_conll"


def load_docs_set(base_path: str, docs_list_paths: list[str]) -> list[str]:
    docs_lookup = set()
    for docs_list_path in docs_list_paths:
        with open(path.join(base_path, docs_list_path)) as f:
            raw_docs = f.read().splitlines()
            # weirdly the ewt dev files end in .conllu but nothing else does
            docs_lookup.update([doc.replace(".conllu", "") for doc in raw_docs])

    docs = []
    ewt_docs = glob(path.join(base_path, EWT_GLOB), recursive=True)
    for doc in ewt_docs:
        doc_base = re.sub(r".*/data/google/ewt/", "", doc).replace(".gold_conll", "")
        if doc_base in docs_lookup:
            docs.append(doc)

    ontonotes_docs = glob(path.join(base_path, ONTONOTES_GLOB), recursive=True)
    for doc in ontonotes_docs:
        # for some reason the ontonotes list has 'ontonotes' in the path, but ewt doesn't have 'google/ewt'
        doc_base = re.sub(r".*/data/ontonotes/", "ontonotes/", doc).replace(
            ".gold_conll", ""
        )
        if doc_base in docs_lookup:
            docs.append(doc)
    return docs


def conll_word_index_to_locs(words: list[str], word_index: int) -> tuple[int, int]:
    """
    Take a list of words and an index of a word and return the start and end char indices of the word in the sentence
    """
    start_loc = 0
    for i, word in enumerate(words):
        if i == word_index:
            return start_loc, start_loc + len(word)
        start_loc += len(word) + 1
    raise ValueError("word index out of range")
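
# Worked example (editor's note, not part of the original file):
#   words = ["The", "cat", "sat"]  ->  sentence = "The cat sat"
#   conll_word_index_to_locs(words, 1) == (4, 7)
# "cat" starts at char 4 ("The" plus a joining space) and spans chars 4-6;
# the end index 7 is start + len(word), i.e. exclusive.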

def load_propbank_samples(
    docs_list: list[str], valid_frames: set[str]
) -> list[FrameAnnotatedSentence]:
    """
    Parse each of the propbank ontonotes and ewt gold conll files and return a list of FrameAnnotatedSentence objects
    """
    annotated_sentences = []
    for doc in docs_list:
        conll_reader = ConllCorpusReader(
            path.dirname(doc),
            path.basename(doc),
            ("ignore", "ignore", "ignore", "words", "pos", "tree", "srl"),
        )
        sents_map = defaultdict(list)
        for srl_instance in conll_reader.srl_instances():
            words = [word[0] for word in srl_instance.words]
            sentence = " ".join(words)
            frame_name = srl_instance.verb_stem
            if frame_name.lower() not in valid_frames:
                continue
            trigger_locs = [
                conll_word_index_to_locs(words, index)[0] for index in srl_instance.verb
            ]

            frame_elements = []
            for argument in srl_instance.arguments:
                words_range, frame_element_name = argument
                element_start_loc = conll_word_index_to_locs(words, words_range[0])[0]
                element_end_loc = conll_word_index_to_locs(words, words_range[1] - 1)[1]
                frame_elements.append(
                    FrameElementAnnotation(
                        frame_element_name, element_start_loc, element_end_loc
                    )
                )
            sents_map[sentence].append(
                FrameAnnotation(frame_name, trigger_locs, frame_elements)
            )

        for sentence, frame_annotations in sents_map.items():
            annotated_sentences.append(
                FrameAnnotatedSentence(sentence, frame_annotations)
            )
    return annotated_sentences


class Propbank34TrainingLoader(TrainingLoader):
    """
    This loader uses the OntoNotes and EWT data from PropBank 3.4 to train a model.
    You must clone https://github.com/propbank/propbank-release and set propbank_release_dir to the path of the cloned repo.
    You must also download the LDC data for OntoNotes and EWT, and run map_all_to_conll.py as described in the propbank-release repo.
    Sadly, this data isn't free, so you'll need to obtain it yourself before working with this loader.
    """

    propbank_release_dir: str
    train_docs: list[str] = []
    val_docs: list[str] = []
    test_docs: list[str] = []
    valid_frames: set[str] = set()

    def __init__(self, propbank_release_dir: str) -> None:
        super().__init__()
        self.propbank_release_dir = propbank_release_dir

    def setup(self) -> None:
        self.valid_frames = {frame.name.lower() for frame in load_propbank_frames()}
        self.train_docs = load_docs_set(self.propbank_release_dir, SPLITS["train"])
        self.val_docs = load_docs_set(self.propbank_release_dir, SPLITS["val"])
        self.test_docs = load_docs_set(self.propbank_release_dir, SPLITS["test"])

    def get_augmentations(self) -> list[DataAugmentation]:
        return [
            RemoveEndPunctuationAugmentation(0.3),
            LowercaseAugmentation(0.2),
            RemoveContractionsAugmentation(0.2),
        ]

    def load_training_data(self) -> list[FrameAnnotatedSentence]:
        return load_propbank_samples(self.train_docs, self.valid_frames)

    def load_test_data(self) -> list[FrameAnnotatedSentence]:
        return load_propbank_samples(self.test_docs, self.valid_frames)

    def load_validation_data(self) -> list[FrameAnnotatedSentence]:
        return load_propbank_samples(self.val_docs, self.valid_frames)
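A minimal usage sketch (editor's addition, not part of the commit) for the training loader. The propbank-release path below is hypothetical, and it assumes the LDC data has already been mapped to *.gold_conll files as the docstring describes; note that setup() calls load_propbank_frames(), so the 3.4 frames must be downloaded first.

from frame_semantic_transformer.data.loaders.propbank34 import (
    Propbank34TrainingLoader,
)
from frame_semantic_transformer.data.loaders.propbank34.ensure_propbank_downloaded import (
    ensure_propbank_downloaded,
)

ensure_propbank_downloaded()  # setup() below needs the 3.4 frames locally
loader = Propbank34TrainingLoader(
    propbank_release_dir="/path/to/propbank-release"  # hypothetical local clone
)
loader.setup()  # resolves valid frames and the train/val/test doc lists
train_sentences = loader.load_training_data()
val_sentences = loader.load_validation_data()
print(len(train_sentences), len(val_sentences))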
frame_semantic_transformer/data/loaders/propbank34/__init__.py (8 additions, 0 deletions)
@@ -0,0 +1,8 @@
from .Propbank34InferenceLoader import Propbank34InferenceLoader
from .Propbank34TrainingLoader import Propbank34TrainingLoader


__all__ = [
    "Propbank34InferenceLoader",
    "Propbank34TrainingLoader",
]
frame_semantic_transformer/data/loaders/propbank34/ensure_propbank_downloaded.py (21 additions, 0 deletions)
@@ -0,0 +1,21 @@
import nltk
from nltk.downloader import Package


# NLTK only ships v1.0 of PropBank, so hackily create an NLTK package and download v3.4
propbank34 = Package(
    id="propbank-frames-3.4.0",
    url="https://github.com/propbank/propbank-frames/archive/refs/tags/v3.4.0.zip",
    name="Proposition Bank Corpus 3.4",
    checksum="e563f8c9912d53ed7e709455746875e5",
    subdir="corpora",
    size=9484561,
    unzipped_size=29870379,
)


def ensure_propbank_downloaded() -> None:
    try:
        nltk.data.find("corpora/propbank-frames-3.4.0.zip")
    except LookupError:
        nltk.download(propbank34)
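A short sketch (editor's addition) showing the helper in use; the import path mirrors the files in this commit. The call is idempotent: it only downloads when the zip isn't already in the NLTK data directory.

import nltk

from frame_semantic_transformer.data.loaders.propbank34.ensure_propbank_downloaded import (
    ensure_propbank_downloaded,
)

ensure_propbank_downloaded()  # no-op if corpora/propbank-frames-3.4.0.zip exists
print(nltk.data.find("corpora/propbank-frames-3.4.0").path)  # unzipped frames dir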
frame_semantic_transformer/data/loaders/propbank34/load_propbank_frames.py (37 additions, 0 deletions)
@@ -0,0 +1,37 @@
from __future__ import annotations

import nltk
from glob import glob
from os import path
from xml.etree import ElementTree


from frame_semantic_transformer.data.frame_types import Frame


def load_propbank_frames() -> list[Frame]:
    """
    Load the full list of frames to be used during inference
    """
    dataset_path = nltk.data.find("corpora/propbank-frames-3.4.0").path
    frames_paths = glob(path.join(dataset_path, "frames", "*.xml"))
    frames = []
    for frame_path in frames_paths:
        with open(frame_path, "r") as frame_file:
            etree = ElementTree.parse(frame_file).getroot()
            raw_frames = etree.findall("predicate/roleset")
            for raw_frame in raw_frames:
                frame = Frame(
                    name=raw_frame.attrib["id"],
                    core_elements=[
                        f"ARG{role.attrib['n']}-{role.attrib['f']}"
                        for role in raw_frame.findall("roles/role")
                    ],
                    non_core_elements=[],
                    lexical_units=[
                        f"{alias.text}.{alias.attrib['pos']}"
                        for alias in raw_frame.findall("aliases/alias")
                    ],
                )
                frames.append(frame)
    return frames
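To make the parsing above concrete, here is a self-contained sketch (editor's addition) that runs the same ElementTree queries against a made-up roleset; the XML below is illustrative only, not copied from a real PropBank frame file.

from xml.etree import ElementTree

FRAME_XML = """
<frameset>
  <predicate lemma="greet">
    <roleset id="greet.01" name="say hello to">
      <aliases>
        <alias pos="v">greet</alias>
      </aliases>
      <roles>
        <role n="0" f="PAG" descr="greeter"/>
        <role n="1" f="PPT" descr="person greeted"/>
      </roles>
    </roleset>
  </predicate>
</frameset>
"""

root = ElementTree.fromstring(FRAME_XML)
for raw_frame in root.findall("predicate/roleset"):
    print(raw_frame.attrib["id"])  # greet.01
    print([
        f"ARG{role.attrib['n']}-{role.attrib['f']}"
        for role in raw_frame.findall("roles/role")
    ])  # ['ARG0-PAG', 'ARG1-PPT']
    print([
        f"{alias.text}.{alias.attrib['pos']}"
        for alias in raw_frame.findall("aliases/alias")
    ])  # ['greet.v']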