diff --git a/spacy_experimental/coref/pytorch_span_resolver_model.py b/spacy_experimental/coref/pytorch_span_resolver_model.py index 5c308e6..84b4b2d 100644 --- a/spacy_experimental/coref/pytorch_span_resolver_model.py +++ b/spacy_experimental/coref/pytorch_span_resolver_model.py @@ -53,6 +53,7 @@ def forward( Returns: torch.Tensor: span start/end scores, (n_heads x n_words x 2) """ + # If we don't receive heads, return empty device = heads_ids.device if heads_ids.nelement() == 0: diff --git a/spacy_experimental/coref/span_resolver_component.py b/spacy_experimental/coref/span_resolver_component.py index e3f2336..9b2bc23 100644 --- a/spacy_experimental/coref/span_resolver_component.py +++ b/spacy_experimental/coref/span_resolver_component.py @@ -176,7 +176,8 @@ def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None: """ for doc, clusters in zip(docs, clusters_by_doc): for ii, cluster in enumerate(clusters, 1): - spans = [doc[int(mm[0]) : int(mm[1])] for mm in cluster] + # Note the +1, since model end indices are inclusive + spans = [doc[int(mm[0]) : int(mm[1]) + 1] for mm in cluster] doc.spans[f"{self.output_prefix}_{ii}"] = spans def update( @@ -274,9 +275,12 @@ def get_loss( # NOTE This is doing fake batching, and should always get a list of one example assert len(list(examples)) == 1, "Only fake batching is supported." - # starts and ends are gold starts and ends (Ints1d) - # span_scores is a Floats3d. What are the axes? mention x token x start/end + + # NOTE Within this component, end token indices are *inclusive*. This + # is different than normal Python/spaCy representations, but has the + # advantage that the set of possible start and end indices is the same. for eg in examples: + # starts and ends are gold starts and ends (Ints1d) starts = [] ends = [] keeps = [] @@ -296,11 +300,12 @@ def get_loss( ) continue starts.append(span.start) - ends.append(span.end) + ends.append(span.end - 1) keeps.append(sidx - 1) starts_xp = self.model.ops.xp.asarray(starts) ends_xp = self.model.ops.xp.asarray(ends) + # span_scores is a Floats3d. Axes: mention x token x start/end start_scores = span_scores[:, :, 0][keeps] end_scores = span_scores[:, :, 1][keeps] diff --git a/spacy_experimental/coref/tests/test_span_resolver.py b/spacy_experimental/coref/tests/test_span_resolver.py index d2d829d..02b19b2 100644 --- a/spacy_experimental/coref/tests/test_span_resolver.py +++ b/spacy_experimental/coref/tests/test_span_resolver.py @@ -22,7 +22,7 @@ def generate_train_data( # fmt: off data = [ ( - "John Smith picked up the red ball and he threw it away.", + "John Smith picked up the red ball and he threw it", { "spans": { f"{output_prefix}_1": [