Skip to content

Commit

Permalink
Replace "kb_ids" by a constant
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Sep 8, 2022
1 parent ac5b1fd commit b828954
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions spacy/pipeline/entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

ActivationsT = Dict[str, Union[List[Ragged], List[str]]]

KNOWLEDGE_BASE_IDS = "kb_ids"

# See #9050
BACKWARD_OVERWRITE = True

Expand Down Expand Up @@ -426,7 +428,7 @@ def predict(self, docs: Iterable[Doc]) -> ActivationsT:
docs_ents: List[Ragged] = []
docs_scores: List[Ragged] = []
if not docs:
return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
if isinstance(docs, Doc):
docs = [docs]
for doc in docs:
Expand Down Expand Up @@ -532,7 +534,7 @@ def predict(self, docs: Iterable[Doc]) -> ActivationsT:
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}

def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
"""Modify a batch of documents, using pre-computed scores.
Expand All @@ -543,7 +545,7 @@ def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> Non
DOCS: https://spacy.io/api/entitylinker#set_annotations
"""
kb_ids = cast(List[str], activations["kb_ids"])
kb_ids = cast(List[str], activations[KNOWLEDGE_BASE_IDS])
count_ents = len([ent for doc in docs for ent in doc.ents])
if count_ents != len(kb_ids):
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
Expand All @@ -553,7 +555,7 @@ def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> Non
if self.save_activations:
doc.activations[self.name] = {}
for act_name, acts in activations.items():
if act_name != "kb_ids":
if act_name != KNOWLEDGE_BASE_IDS:
# We only copy activations that are Ragged.
doc.activations[self.name][act_name] = cast(Ragged, acts[j])

Expand Down

0 comments on commit b828954

Please sign in to comment.