Skip to content

Commit

Permalink
Add spancat_singlelabel to debug data CLI (explosion#12749)
Browse files (browse the repository at this point in the history)
  • Loading branch information
adrianeboyd authored Jun 26, 2023
1 parent cb4fdc8 commit e166421
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
6 changes: 3 additions & 3 deletions spacy/cli/debug_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def debug_data(
else:
msg.info("No word vectors present in the package")

- if "spancat" in factory_names:
+ if "spancat" in factory_names or "spancat_singlelabel" in factory_names:
model_labels_spancat = _get_labels_from_spancat(nlp)
has_low_data_warning = False
has_no_neg_warning = False
Expand Down Expand Up @@ -848,7 +848,7 @@ def _compile_gold(
data["boundary_cross_ents"] += 1
elif label == "-":
data["ner"]["-"] += 1
- if "spancat" in factory_names:
+ if "spancat" in factory_names or "spancat_singlelabel" in factory_names:
for spans_key in list(eg.reference.spans.keys()):
# Obtain the span frequency
if spans_key not in data["spancat"]:
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]:
pipe_names = [
pipe_name
for pipe_name in nlp.pipe_names
- if nlp.get_pipe_meta(pipe_name).factory == "spancat"
+ if nlp.get_pipe_meta(pipe_name).factory in ("spancat", "spancat_singlelabel")
]
labels: Dict[str, Set[str]] = {}
for pipe_name in pipe_names:
Expand Down
5 changes: 3 additions & 2 deletions spacy/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,8 @@ def test_debug_data_compile_gold():
assert data["boundary_cross_ents"] == 1


- def test_debug_data_compile_gold_for_spans():
+ @pytest.mark.parametrize("component_name", ["spancat", "spancat_singlelabel"])
+ def test_debug_data_compile_gold_for_spans(component_name):
nlp = English()
spans_key = "sc"

Expand All @@ -870,7 +871,7 @@ def test_debug_data_compile_gold_for_spans():
ref.spans[spans_key] = [Span(ref, 3, 6, "ORG"), Span(ref, 5, 6, "GPE")]
eg = Example(pred, ref)

- data = _compile_gold([eg], ["spancat"], nlp, True)
+ data = _compile_gold([eg], [component_name], nlp, True)

assert data["spancat"][spans_key] == Counter({"ORG": 1, "GPE": 1})
assert data["spans_length"][spans_key] == {"ORG": [3], "GPE": [1]}
Expand Down

0 comments on commit e166421

Please sign in to comment.