Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with resolving final token in SpanResolver #27

Merged
merged 3 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spacy_experimental/coref/pytorch_span_resolver_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def forward(
Returns:
torch.Tensor: span start/end scores, (n_heads x n_words x 2)
"""

# If we don't receive heads, return empty
device = heads_ids.device
if heads_ids.nelement() == 0:
Expand Down
13 changes: 9 additions & 4 deletions spacy_experimental/coref/span_resolver_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
"""
for doc, clusters in zip(docs, clusters_by_doc):
for ii, cluster in enumerate(clusters, 1):
spans = [doc[int(mm[0]) : int(mm[1])] for mm in cluster]
# Note the +1, since model end indices are inclusive
spans = [doc[int(mm[0]) : int(mm[1]) + 1] for mm in cluster]
doc.spans[f"{self.output_prefix}_{ii}"] = spans

def update(
Expand Down Expand Up @@ -274,9 +275,12 @@ def get_loss(

# NOTE This is doing fake batching, and should always get a list of one example
assert len(list(examples)) == 1, "Only fake batching is supported."
# starts and ends are gold starts and ends (Ints1d)
# span_scores is a Floats3d. What are the axes? mention x token x start/end

# NOTE Within this component, end token indices are *inclusive*. This
# is different than normal Python/spaCy representations, but has the
# advantage that the set of possible start and end indices is the same.
for eg in examples:
# starts and ends are gold starts and ends (Ints1d)
starts = []
ends = []
keeps = []
Expand All @@ -296,11 +300,12 @@ def get_loss(
)
continue
starts.append(span.start)
ends.append(span.end)
ends.append(span.end - 1)
keeps.append(sidx - 1)

starts_xp = self.model.ops.xp.asarray(starts)
ends_xp = self.model.ops.xp.asarray(ends)
# span_scores is a Floats3d. Axes: mention x token x start/end
start_scores = span_scores[:, :, 0][keeps]
end_scores = span_scores[:, :, 1][keeps]

Expand Down
2 changes: 1 addition & 1 deletion spacy_experimental/coref/tests/test_span_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def generate_train_data(
# fmt: off
data = [
(
"John Smith picked up the red ball and he threw it away.",
"John Smith picked up the red ball and he threw it",
{
"spans": {
f"{output_prefix}_1": [
Expand Down