From d066b9602741f84fa5b56acf11db1d8b32987b0e Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Mon, 24 Oct 2022 16:14:20 +0900
Subject: [PATCH 1/3] Fix issue with resolving final token in SpanResolver

The SpanResolver seems unable to include the final token of a Doc in its
output spans; it will even produce empty spans instead of doing so.

This change treats span end indices as inclusive within the model and
converts them back to exclusive when annotating docs.

This has been tested to work, though an automated test should be added.
---
 spacy_experimental/coref/span_resolver_component.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/spacy_experimental/coref/span_resolver_component.py b/spacy_experimental/coref/span_resolver_component.py
index e3f2336..4091182 100644
--- a/spacy_experimental/coref/span_resolver_component.py
+++ b/spacy_experimental/coref/span_resolver_component.py
@@ -176,7 +176,7 @@ def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
         """
         for doc, clusters in zip(docs, clusters_by_doc):
             for ii, cluster in enumerate(clusters, 1):
-                spans = [doc[int(mm[0]) : int(mm[1])] for mm in cluster]
+                spans = [doc[int(mm[0]) : int(mm[1]) + 1] for mm in cluster]
                 doc.spans[f"{self.output_prefix}_{ii}"] = spans
 
     def update(
@@ -296,7 +296,7 @@ def get_loss(
                         )
                         continue
                     starts.append(span.start)
-                    ends.append(span.end)
+                    ends.append(span.end - 1)
                     keeps.append(sidx - 1)
 
         starts_xp = self.model.ops.xp.asarray(starts)

From 361ba62c8239330865d19d8918eef712c2f29a63 Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Mon, 24 Oct 2022 18:54:44 +0900
Subject: [PATCH 2/3] Modify tests so last token is in a mention

Running the modified tests without the changes from the previous
commit, they fail. This demonstrates and clarifies the bug.
---
 spacy_experimental/coref/tests/test_span_resolver.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spacy_experimental/coref/tests/test_span_resolver.py b/spacy_experimental/coref/tests/test_span_resolver.py
index d2d829d..02b19b2 100644
--- a/spacy_experimental/coref/tests/test_span_resolver.py
+++ b/spacy_experimental/coref/tests/test_span_resolver.py
@@ -22,7 +22,7 @@ def generate_train_data(
     # fmt: off
     data = [
         (
-            "John Smith picked up the red ball and he threw it away.",
+            "John Smith picked up the red ball and he threw it",
             {
                 "spans": {
                     f"{output_prefix}_1": [

From 0fa4590677d5dbf8e958ee2855a2670ab3959681 Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Mon, 24 Oct 2022 19:05:43 +0900
Subject: [PATCH 3/3] Add / rearrange comments

---
 spacy_experimental/coref/pytorch_span_resolver_model.py | 1 +
 spacy_experimental/coref/span_resolver_component.py     | 9 +++++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/spacy_experimental/coref/pytorch_span_resolver_model.py b/spacy_experimental/coref/pytorch_span_resolver_model.py
index 5c308e6..84b4b2d 100644
--- a/spacy_experimental/coref/pytorch_span_resolver_model.py
+++ b/spacy_experimental/coref/pytorch_span_resolver_model.py
@@ -53,6 +53,7 @@ def forward(
         Returns:
             torch.Tensor: span start/end scores, (n_heads x n_words x 2)
         """
+        # If we don't receive heads, return empty
         device = heads_ids.device
 
         if heads_ids.nelement() == 0:

diff --git a/spacy_experimental/coref/span_resolver_component.py b/spacy_experimental/coref/span_resolver_component.py
index 4091182..9b2bc23 100644
--- a/spacy_experimental/coref/span_resolver_component.py
+++ b/spacy_experimental/coref/span_resolver_component.py
@@ -176,6 +176,7 @@ def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
         """
         for doc, clusters in zip(docs, clusters_by_doc):
             for ii, cluster in enumerate(clusters, 1):
+                # Note the +1, since model end indices are inclusive
                 spans = [doc[int(mm[0]) : int(mm[1]) + 1] for mm in cluster]
                 doc.spans[f"{self.output_prefix}_{ii}"] = spans
 
@@ -274,9 +275,12 @@ def get_loss(
 
         # NOTE This is doing fake batching, and should always get a list of one example
         assert len(list(examples)) == 1, "Only fake batching is supported."
-        # starts and ends are gold starts and ends (Ints1d)
-        # span_scores is a Floats3d. What are the axes? mention x token x start/end
+
+        # NOTE Within this component, end token indices are *inclusive*. This
+        # is different than normal Python/spaCy representations, but has the
+        # advantage that the set of possible start and end indices is the same.
         for eg in examples:
+            # starts and ends are gold starts and ends (Ints1d)
            starts = []
            ends = []
            keeps = []
@@ -301,6 +305,7 @@ def get_loss(
 
         starts_xp = self.model.ops.xp.asarray(starts)
         ends_xp = self.model.ops.xp.asarray(ends)
+        # span_scores is a Floats3d. Axes: mention x token x start/end
         start_scores = span_scores[:, :, 0][keeps]
         end_scores = span_scores[:, :, 1][keeps]
 
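
For context, the sketch below (not part of the patches) illustrates the index
convention the fix relies on: the model treats span end indices as inclusive,
while spaCy Doc slices use exclusive ends, so set_annotations adds 1 when
building output spans and get_loss subtracts 1 when building training targets.
The sentence and the indices used here are illustrative assumptions, not
values taken from the patches.

# Illustrative sketch of the inclusive/exclusive end-index conversion.
import spacy

nlp = spacy.blank("en")
doc = nlp("John Smith picked up the red ball and he threw it")

# Suppose the model predicts the span for "it": start=10, end=10
# (inclusive, i.e. the last token of the Doc).
start, end = 10, 10

# Building a spaCy span needs an exclusive end, hence the +1 in
# set_annotations; without it, doc[10:10] would be an empty span.
span = doc[start : end + 1]
assert span.text == "it"

# Going the other way (gold spans -> model targets), get_loss subtracts 1
# from span.end so the target uses the inclusive convention.
assert (span.start, span.end - 1) == (start, end)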