Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
svlandeg committed Mar 27, 2024
1 parent ff88ab3 commit a1cde9d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
16 changes: 9 additions & 7 deletions spacy/pipeline/entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,16 @@ def _score_augmented(examples, **kwargs):
self.scorer = _score_augmented

def _augment_examples(self, examples: Iterable[Example]) -> Iterable[Example]:
"""If use_gold_ents is true, set the gold entities to eg.predicted.
"""
"""If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted."""
if not self.use_gold_ents:
return examples

new_examples = []
for eg in examples:
if self.use_gold_ents:
ents, _ = eg.get_aligned_ents_and_ner()
eg.predicted.ents = ents
new_examples.append(eg)
ents, _ = eg.get_aligned_ents_and_ner()
new_eg = eg.copy()
new_eg.predicted.ents = ents
new_examples.append(new_eg)
return new_examples

def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
Expand Down Expand Up @@ -399,7 +401,7 @@ def update(
return losses

def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
""" Here, we assume that get_loss is called with augmented examples if need be"""
"""Here, we assume that get_loss is called with augmented examples if need be"""
validate_examples(examples, "EntityLinker.get_loss")
entity_encodings = []
eidx = 0 # indices in gold entities to keep
Expand Down
8 changes: 6 additions & 2 deletions spacy/tests/pipeline/test_entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,9 @@ def create_kb(vocab):
return mykb

# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": True})
entity_linker = nlp.add_pipe(
"entity_linker", last=True, config={"use_gold_ents": True}
)
assert isinstance(entity_linker, EntityLinker)
entity_linker.set_kb(create_kb)
assert "Q2146908" in entity_linker.vocab.strings
Expand Down Expand Up @@ -849,7 +851,9 @@ def create_kb(vocab):

# Create the NER and EL components and add them to the pipeline
ner = nlp.add_pipe("ner", first=True)
entity_linker = nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": False})
entity_linker = nlp.add_pipe(
"entity_linker", last=True, config={"use_gold_ents": False}
)
entity_linker.set_kb(create_kb)

train_examples = []
Expand Down

0 comments on commit a1cde9d

Please sign in to comment.