diff --git a/spacy_experimental/coref/coref_component.py b/spacy_experimental/coref/coref_component.py
index 1e6788d..38ee7c1 100644
--- a/spacy_experimental/coref/coref_component.py
+++ b/spacy_experimental/coref/coref_component.py
@@ -145,6 +145,11 @@ def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """
         out = []
         for doc in docs:
+            if len(doc) < 2:
+                # no coref in docs with 0 or 1 token
+                out.append([])
+                continue
+
             scores, idxs = self.model.predict([doc])
             # idxs is a list of mentions (start / end idxs)
             # each item in scores includes scores and a mapping from scores to mentions
@@ -232,6 +237,9 @@ def update(
                     predicted docs in coref training.
                     """
                 )
+            if len(eg.predicted) < 2:
+                # no prediction possible for docs of length 0 or 1
+                continue
             preds, backprop = self.model.begin_update([eg.predicted])
             score_matrix, mention_idx = preds
             loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
diff --git a/spacy_experimental/coref/tests/test_coref.py b/spacy_experimental/coref/tests/test_coref.py
index 6ff07d4..31caf9a 100644
--- a/spacy_experimental/coref/tests/test_coref.py
+++ b/spacy_experimental/coref/tests/test_coref.py
@@ -37,6 +37,11 @@ def generate_train_data(prefix=DEFAULT_CLUSTER_PREFIX):
                 }
             },
         ),
+        (
+            # example short doc
+            "ok",
+            {"spans": {}}
+        )
     ]
     # fmt: on
     return data
@@ -83,11 +88,12 @@ def test_initialized(nlp):


 def test_initialized_short(nlp):
+    # test that short or empty docs don't fail
     nlp.add_pipe("experimental_coref")
     nlp.initialize()
     assert nlp.pipe_names == ["experimental_coref"]
-    text = "Hi there"
-    doc = nlp(text)
+    doc = nlp("Hi")
+    doc = nlp("")


 def test_coref_serialization(nlp):
@@ -148,7 +154,8 @@ def test_overfitting_IO(nlp, train_data):

 def test_tokenization_mismatch(nlp, train_data):
     train_examples = []
-    for text, annot in train_data:
+    # this is testing a specific test example, so just get the first doc
+    for text, annot in train_data[0:1]:
         eg = Example.from_dict(nlp.make_doc(text), annot)
         ref = eg.reference
         char_spans = {}
diff --git a/spacy_experimental/coref/tests/test_span_resolver.py b/spacy_experimental/coref/tests/test_span_resolver.py
index d2d829d..e0238ee 100644
--- a/spacy_experimental/coref/tests/test_span_resolver.py
+++ b/spacy_experimental/coref/tests/test_span_resolver.py
@@ -79,6 +79,13 @@ def test_not_initialized(nlp):
     with pytest.raises(ValueError, match="E109"):
         nlp(text)

+def test_initialized_short(nlp):
+    # docs with one or no tokens should not fail
+    nlp.add_pipe("experimental_span_resolver")
+    nlp.initialize()
+    assert nlp.pipe_names == ["experimental_span_resolver"]
+    nlp("hi")
+    nlp("")

 def test_span_resolver_serialization(nlp):
     # Test that the span resolver component can be serialized