From 5209e07f34c905ef9df417143880f097ae08240c Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Mon, 27 Apr 2020 10:53:01 -0700
Subject: [PATCH] Even better token annotation (#27)

* Handle cases where the annotated token goes beyond the last token.

* Adds test for the new behavior
---
 allennlp_models/rc/common/reader_utils.py | 9 +++++++++
 tests/rc/reader_utils_test.py             | 6 ++++++
 2 files changed, 15 insertions(+)

diff --git a/allennlp_models/rc/common/reader_utils.py b/allennlp_models/rc/common/reader_utils.py
index cbca55226..bb10daab4 100644
--- a/allennlp_models/rc/common/reader_utils.py
+++ b/allennlp_models/rc/common/reader_utils.py
@@ -103,6 +103,15 @@ def char_span_to_token_span(
         token_offsets[end_index] is None or token_offsets[end_index][1] < character_span[1]
     ):
         end_index += 1
+    if end_index == len(token_offsets):
+        # We want a character span that goes beyond the last token. Let's see if this is salvageable.
+        # We consider this salvageable if the span we're looking for starts before the last token ends.
+        # In other words, we don't salvage if the whole span comes after the tokens end.
+        if character_span[0] < token_offsets[-1][1]:
+            # We also want to make sure we aren't way off. We need to be within 8 characters to salvage.
+            if character_span[1] - 8 < token_offsets[-1][1]:
+                end_index -= 1
+
     if end_index >= len(token_offsets):
         raise ValueError(f"Character span %r outside the range of the given tokens.")
     if end_index == start_index and token_offsets[end_index][1] > character_span[1]:
diff --git a/tests/rc/reader_utils_test.py b/tests/rc/reader_utils_test.py
index 11b7e9877..c2dc5effd 100644
--- a/tests/rc/reader_utils_test.py
+++ b/tests/rc/reader_utils_test.py
@@ -11,7 +11,13 @@
         ([(0, 3), (4, 4), (5, 8)], (0, 8), ((0, 2), False)),
         ([(0, 3), (4, 4), (5, 8)], (1, 8), ((0, 2), True)),
         ([(0, 3), (4, 4), (5, 8)], (7, 8), ((2, 2), True)),
+        ([(0, 3), (4, 4), (5, 8)], (7, 9), ((2, 2), True)),
     ],
 )
 def test_char_span_to_token_span(token_offsets, character_span, expected_result):
     assert char_span_to_token_span(token_offsets, character_span) == expected_result
+
+
+def test_char_span_to_token_span_throws():
+    with pytest.raises(ValueError):
+        char_span_to_token_span([(0, 3), (4, 4), (5, 8)], (7, 19))
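
For reference, a minimal usage sketch of the behavior this patch introduces, using the same token offsets as the tests above. The import path is inferred from the file touched in this patch and may differ in other layouts; the expected outputs are exactly those asserted by the new test cases.

    from allennlp_models.rc.common.reader_utils import char_span_to_token_span

    token_offsets = [(0, 3), (4, 4), (5, 8)]

    # A span that overshoots the last token by one character is now salvaged:
    # it maps onto the last token, with the second return value (the
    # misalignment flag) set to True.
    print(char_span_to_token_span(token_offsets, (7, 9)))  # ((2, 2), True)

    # A span that ends well beyond the last token (8 or more characters past
    # its end) is not salvaged and still raises, as the new
    # test_char_span_to_token_span_throws checks.
    try:
        char_span_to_token_span(token_offsets, (7, 19))
    except ValueError:
        print("out-of-range span still raises ValueError")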