Skip to content

Commit

Permalink
adding evaluate / predict helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed May 6, 2022
1 parent 081533b commit be08ec1
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
38 changes: 38 additions & 0 deletions frame_semantic_transformer/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import Iterable
from transformers import T5ForConditionalGeneration, T5Tokenizer
from nltk.corpus import framenet as fn

from frame_semantic_transformer.data.SampleSentence import SampleSentence
from frame_semantic_transformer.predict import predict


# Set of every frame name known to FrameNet, used to tell a wrong-but-real
# frame prediction apart from a hallucinated (nonexistent) frame name.
# NOTE(review): presumably requires the NLTK framenet corpus to already be
# downloaded when this module is imported — confirm setup docs cover that.
all_valid_frames = {frame.name for frame in fn.frames()}


def evaluate(
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    samples: Iterable[SampleSentence],
) -> dict[str, list[int]]:
    """Score the model on the frame-classification and frame-args tasks.

    Args:
        model: trained T5 model to evaluate.
        tokenizer: tokenizer matching the model.
        samples: evaluation sentences providing task inputs and gold labels.

    Returns:
        A dict mapping task name to a 3-slot count list:
        ``"frame"`` -> [exact match, wrong-but-valid frame, invalid frame];
        ``"args"``  -> [exact match, mismatch, unused] (fp/fn split is TODO).
    """
    results: dict[str, list[int]] = {"frame": [0, 0, 0], "args": [0, 0, 0]}
    for sample in samples:
        frame_task_input = sample.frame_classification_input
        args_task_input = sample.frame_args_input

        # predict() returns a list of generated sequences (one per
        # num_return_sequences, default 1). Comparing the raw list against a
        # string is always False, which silently zeroed the "correct" buckets,
        # so take the top prediction explicitly.
        frame_prediction = predict(model, tokenizer, frame_task_input)[0]
        args_prediction = predict(model, tokenizer, args_task_input)[0]

        if frame_prediction == sample.frame:
            results["frame"][0] += 1
        elif frame_prediction in all_valid_frames:
            results["frame"][1] += 1
        else:
            results["frame"][2] += 1

        if args_prediction == sample.frame_elements_str:
            results["args"][0] += 1
        # TODO: figure out fp/fn for frame elements
        else:
            results["args"][1] += 1
    return results
43 changes: 43 additions & 0 deletions frame_semantic_transformer/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations
from transformers import T5Tokenizer, T5ForConditionalGeneration


def predict(
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    source_text: str,
    max_length: int = 512,
    num_return_sequences: int = 1,
    num_beams: int = 5,
    top_k: int = 50,
    top_p: float = 0.95,
    repetition_penalty: float = 2.5,
    length_penalty: float = 1.0,
    early_stopping: bool = True,
    skip_special_tokens: bool = True,
    clean_up_tokenization_spaces: bool = True,
) -> list[str]:
    """Run beam-search generation on a single input string.

    Encodes ``source_text``, moves the token ids to the model's device,
    generates ``num_return_sequences`` sequences, and decodes each back
    into text.

    Returns:
        One decoded string per generated sequence.
    """
    encoded = tokenizer.encode(
        source_text, return_tensors="pt", add_special_tokens=True
    ).to(model.device)
    outputs = model.generate(
        input_ids=encoded,
        num_beams=num_beams,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=early_stopping,
        top_p=top_p,
        top_k=top_k,
        num_return_sequences=num_return_sequences,
    )
    return [
        tokenizer.decode(
            sequence,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for sequence in outputs
    ]

0 comments on commit be08ec1

Please sign in to comment.