Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raise error when setting overlapping entities as doc.ents to close #2550 #2880

Merged
merged 1 commit into from
Oct 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ class Errors(object):
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
"Span objects, or dicts if set to manual=True.")
E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")

E098 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A token"
" can only be part of one entity, so make sure the entities you're "
"setting don't overlap.")

@add_codes
class TempErrors(object):
Expand Down
11 changes: 11 additions & 0 deletions spacy/tests/doc/test_add_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ...pipeline import EntityRecognizer
from ..util import get_doc
from ...tokens import Span

import pytest

Expand All @@ -22,3 +23,13 @@ def test_doc_add_entities_set_ents_iob(en_vocab):

doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']

def test_add_overlapping_entities(en_vocab):
text = ["Louisiana", "Office", "of", "Conservation"]
doc = get_doc(en_vocab, text)
entity = Span(doc, 0, 4, label=391)
doc.ents = [entity]

new_entity = Span(doc, 0, 1, label=392)
with pytest.raises(ValueError):
doc.ents = list(doc.ents) + [new_entity]
3 changes: 2 additions & 1 deletion spacy/tests/regression/test_issue242.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ def test_issue242(en_tokenizer):
matcher.add('FOOD', None, *patterns)

matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
doc.ents += tuple(matches)
match1, match2 = matches
assert match1[1] == 3
assert match1[2] == 5
assert match2[1] == 4
assert match2[2] == 6

doc.ents += tuple([match2])
36 changes: 27 additions & 9 deletions spacy/tokens/doc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -458,22 +458,29 @@ cdef class Doc:
# prediction
# 3. Test basic data-driven ORTH gazetteer
# 4. Test more nuanced date and currency regex

tokens_in_ents = {}
cdef attr_t entity_type
cdef int ent_start, ent_end
for ent_info in ents:
entity_type, ent_start, ent_end = get_entity_info(ent_info)
for token_index in range(ent_start, ent_end):
if token_index in tokens_in_ents.keys():
raise ValueError(Errors.E098.format(
span1=(tokens_in_ents[token_index][0],
tokens_in_ents[token_index][1],
self.vocab.strings[tokens_in_ents[token_index][2]]),
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)

cdef int i
for i in range(self.length):
self.c[i].ent_type = 0
self.c[i].ent_iob = 0 # Means missing.
cdef attr_t ent_type
cdef int start, end
for ent_info in ents:
if isinstance(ent_info, Span):
ent_id = ent_info.ent_id
ent_type = ent_info.label
start = ent_info.start
end = ent_info.end
elif len(ent_info) == 3:
ent_type, start, end = ent_info
else:
ent_id, ent_type, start, end = ent_info
ent_type, start, end = get_entity_info(ent_info)
if ent_type is None or ent_type < 0:
# Mark as O
for i in range(start, end):
Expand Down Expand Up @@ -1062,3 +1069,14 @@ def fix_attributes(doc, attributes):
attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
if 'ent_type' in attributes:
attributes[ENT_TYPE] = attributes['ent_type']

def get_entity_info(ent_info):
if isinstance(ent_info, Span):
ent_type = ent_info.label
start = ent_info.start
end = ent_info.end
elif len(ent_info) == 3:
ent_type, start, end = ent_info
else:
ent_id, ent_type, start, end = ent_info
return ent_type, start, end