Skip to content

Commit

Permalink
adding a base FrameSemanticTransformer class to make it easy to parse…
Browse files Browse the repository at this point in the history
… sentences into frames
  • Loading branch information
chanind committed May 17, 2022
1 parent 667f85e commit c0b78cf
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 3 deletions.
168 changes: 168 additions & 0 deletions frame_semantic_transformer/FrameSemanticTransformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from __future__ import annotations
from dataclasses import dataclass
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from frame_semantic_transformer.data.data_utils import chunk_list, marked_string_to_locs
from frame_semantic_transformer.data.tasks.ArgumentsExtractionTask import (
ArgumentsExtractionTask,
)
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
FrameClassificationTask,
)

from frame_semantic_transformer.data.tasks.TriggerIdentificationTask import (
TriggerIdentificationTask,
)
from frame_semantic_transformer.predict import batch_predict


OFFICIAL_RELEASES = ["base", "small", "large"]


@dataclass
class FrameElementResult:
    """
    One frame element extracted for a detected frame.

    Instances are built from the (frame_element, text) tuples returned by
    argument extraction.
    """

    # name of the frame element
    name: str
    # text associated with this frame element
    text: str


@dataclass
class FrameResult:
    """
    A single frame detected in a sentence, along with its trigger and elements.
    """

    # name of the detected frame
    name: str
    # location of the trigger that evoked this frame; presumably a character
    # index into the cleaned sentence, as produced by marked_string_to_locs
    trigger_location: int
    # frame elements extracted for this frame
    frame_elements: list[FrameElementResult]


@dataclass
class DetectFramesResult:
    """
    Full result of running frame detection on one sentence.
    """

    # the sentence with trigger markers removed
    sentence: str
    # locations of every identified trigger, including triggers for which
    # no frame could be classified
    trigger_locations: list[int]
    # one entry per trigger that was successfully classified into a frame
    frames: list[FrameResult]


class FrameSemanticTransformer:
    """
    High-level helper for parsing a sentence into semantic frames with a T5 model.

    detect_frames() runs three steps in order: trigger identification,
    frame classification for each trigger, and argument (frame element)
    extraction for each trigger that was assigned a frame.
    """

    model: T5ForConditionalGeneration
    tokenizer: T5Tokenizer
    device: torch.device
    max_batch_size: int
    predictions_per_sample: int

    def __init__(
        self,
        model_name_or_path: str = "base",
        use_gpu: bool = torch.cuda.is_available(),
        max_batch_size: int = 8,
        predictions_per_sample: int = 5,
    ):
        """
        Load the model and tokenizer.

        Args:
            model_name_or_path: one of OFFICIAL_RELEASES, or any other value,
                which is passed unchanged to `from_pretrained`.
            use_gpu: place the model on "cuda" when True, otherwise on "cpu".
                NOTE: the default is evaluated once, when the class is defined.
            max_batch_size: max number of task inputs sent to the model per call.
            predictions_per_sample: used both as the beam count and as the
                number of sequences returned per input.
        """
        model_path = model_name_or_path
        # official release names are shorthand for the published checkpoints
        if model_name_or_path in OFFICIAL_RELEASES:
            model_path = f"chanind/frame-semantic-transformer-{model_name_or_path}"
        self.device = torch.device("cuda" if use_gpu else "cpu")
        self.model = T5ForConditionalGeneration.from_pretrained(model_path).to(
            self.device
        )
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.max_batch_size = max_batch_size
        self.predictions_per_sample = predictions_per_sample

    def _batch_predict(self, inputs: list[str]) -> list[str]:
        """
        helper to avoid needing to repeatedly pass in the same params every call to predict

        Callers in this class treat the result as a flat list holding
        `predictions_per_sample` consecutive predictions per input.
        """
        return batch_predict(
            self.model,
            self.tokenizer,
            inputs,
            num_beams=self.predictions_per_sample,
            num_return_sequences=self.predictions_per_sample,
        )

    def _identify_triggers(self, sentence: str) -> tuple[str, list[int]]:
        """
        Run trigger identification on the sentence.

        Returns a tuple of (sentence with trigger markers removed, trigger locations),
        obtained by passing the parsed task output through marked_string_to_locs.
        """
        task = TriggerIdentificationTask(text=sentence)
        outputs = self._batch_predict([task.get_input()])
        result = task.parse_output(outputs)
        return marked_string_to_locs(result)

    def _classify_frames(
        self, sentence: str, trigger_locs: list[int]
    ) -> list[str | None]:
        """
        Return a list containing a frame for each trigger_loc passed in.
        If no frame can be found, None is returned for the frame instead.
        """
        frames: list[str | None] = []
        frame_classification_tasks = [
            FrameClassificationTask(text=sentence, trigger_loc=trigger_loc)
            for trigger_loc in trigger_locs
        ]
        for batch in chunk_list(
            frame_classification_tasks, chunk_size=self.max_batch_size
        ):
            batch_results = self._batch_predict([task.get_input() for task in batch])
            # regroup the flat predictions so each task sees only its own candidates
            for preds, frame_task in zip(
                chunk_list(batch_results, self.predictions_per_sample),
                batch,
            ):
                frames.append(frame_task.parse_output(preds))
        return frames

    def _extract_frame_args(
        self, sentence: str, frame_with_trigger_locs: list[tuple[str, int]]
    ) -> list[list[tuple[str, str]]]:
        """
        return a list of tuples of (frame_element, text) for each frame/trigger loc passed in.
        The returned list will have the same length as the frame_with_trigger_locs list,
        with each element corresponding to a frame/loc in the input list
        """
        frame_element_results: list[list[tuple[str, str]]] = []
        arg_extraction_tasks = [
            ArgumentsExtractionTask(
                text=sentence,
                trigger_loc=trigger_loc,
                frame=frame,
            )
            for frame, trigger_loc in frame_with_trigger_locs
        ]
        for args_tasks_batch in chunk_list(
            arg_extraction_tasks, chunk_size=self.max_batch_size
        ):
            batch_results = self._batch_predict(
                [task.get_input() for task in args_tasks_batch],
            )
            # regroup the flat predictions so each task sees only its own candidates
            for preds, args_task in zip(
                chunk_list(batch_results, self.predictions_per_sample), args_tasks_batch
            ):
                frame_element_results.append(args_task.parse_output(preds))
        return frame_element_results

    def detect_frames(self, sentence: str) -> DetectFramesResult:
        """
        Detect the frames in a sentence, together with their triggers and frame elements.

        The result lists every identified trigger location, but only triggers
        that were classified into a frame get a FrameResult.
        """
        # first detect trigger locations
        base_sentence, trigger_locs = self._identify_triggers(sentence)
        # next detect frames for each trigger
        frames = self._classify_frames(base_sentence, trigger_locs)

        # drop triggers for which no frame could be found
        frame_and_locs = [
            (frame, loc) for frame, loc in zip(frames, trigger_locs) if frame
        ]
        frame_elements_lists = self._extract_frame_args(base_sentence, frame_and_locs)
        frame_results: list[FrameResult] = []
        for (frame, loc), frame_element_tuples in zip(
            frame_and_locs, frame_elements_lists
        ):
            # FrameElementResult is the dataclass defined in this module
            frame_elements = [
                FrameElementResult(element, text)
                for element, text in frame_element_tuples
            ]
            frame_results.append(
                FrameResult(
                    name=frame,
                    trigger_location=loc,
                    frame_elements=frame_elements,
                )
            )
        return DetectFramesResult(
            base_sentence,
            trigger_locations=trigger_locs,
            frames=frame_results,
        )
14 changes: 14 additions & 0 deletions frame_semantic_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,15 @@
# Version string of the package.
__version__ = "0.1.0"

# Re-export the high-level API so it can be imported from the package root.
from .FrameSemanticTransformer import (
FrameSemanticTransformer,
DetectFramesResult,
FrameElementResult,
FrameResult,
)

# Names exported by a star-import of this package.
__all__ = (
"FrameSemanticTransformer",
"DetectFramesResult",
"FrameElementResult",
"FrameResult",
)
23 changes: 23 additions & 0 deletions frame_semantic_transformer/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,26 @@ def standardize_punct(sent: str) -> str:
updated_sent = re.sub(r"\*([a-zA-Z0-9])", r"* \1", updated_sent)

return updated_sent.strip()


def marked_string_to_locs(
    text: str, symbol: str = "*", remove_spaces: bool = True
) -> tuple[str, list[int]]:
    """
    Take a string like "He * went to the * store" and return the indices of the tagged words,
    in this case "went" and "store", and remove the tags (in this case the *'s)

    Args:
        text: the marked string.
        symbol: the marker that precedes each tagged word. Must be non-empty.
        remove_spaces: if True, strip whitespace from the text remaining after
            each marker, so the tagged word starts exactly at the marker's position.

    Returns:
        A tuple of (text with the markers removed, character index in that text
        where each tagged word starts).

    Raises:
        ValueError: if symbol is an empty string.
    """
    if not symbol:
        # an empty marker matches at every position and would never terminate
        raise ValueError("symbol must be a non-empty string")

    parts: list[str] = []
    locs: list[int] = []
    # number of characters already emitted into the output
    output_len = 0
    remaining_str = text
    symbol_index = remaining_str.find(symbol)

    while symbol_index != -1:
        locs.append(output_len + symbol_index)
        parts.append(remaining_str[:symbol_index])
        output_len += symbol_index
        remaining_str = remaining_str[symbol_index + len(symbol) :]
        if remove_spaces:
            # NOTE: strip() also removes trailing whitespace at the end of the text
            remaining_str = remaining_str.strip()
        symbol_index = remaining_str.find(symbol)
    parts.append(remaining_str)
    return "".join(parts), locs
4 changes: 2 additions & 2 deletions frame_semantic_transformer/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def predict_on_ids(
)
preds = [
tokenizer.decode(
g,
generated_id,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
for g in generated_ids
for generated_id in generated_ids
]
return preds
26 changes: 25 additions & 1 deletion tests/data/test_data_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from __future__ import annotations

from frame_semantic_transformer.data.data_utils import standardize_punct
import pytest

from frame_semantic_transformer.data.data_utils import (
marked_string_to_locs,
standardize_punct,
)


def test_standardize_punct_removes_spaces_before_punctuation() -> None:
Expand Down Expand Up @@ -65,3 +70,22 @@ def test_standardize_punct_removes_spaces_before_commas() -> None:
"2- * Sheik of Albu'Ubaid ( Salah al-Dhari ), who * slaughtered * thirty sheeps"
)
assert standardize_punct(original) == expected


@pytest.mark.parametrize(
    "marked_text,expected",
    [
        ("Hi * there", ("Hi there", [3])),
        ("Hi there", ("Hi there", [])),
        (
            "Does Iran * intend to * become a Nuclear State?",
            ("Does Iran intend to become a Nuclear State?", [10, 20]),
        ),
        (
            "Does Iran * intend to *become a Nuclear State?",
            ("Does Iran intend to become a Nuclear State?", [10, 20]),
        ),
    ],
)
def test_marked_string_to_locs(
    marked_text: str, expected: tuple[str, list[int]]
) -> None:
    """
    Check that markers are removed and the location of each tagged word is
    reported, with or without a space between the marker and the word.
    """
    # the parameter is named marked_text rather than input to avoid shadowing the builtin
    assert marked_string_to_locs(marked_text) == expected

0 comments on commit c0b78cf

Please sign in to comment.