Skip to content

Commit

Permalink
Correct Keyerror for h-label cls in label_groups for dm_label_categor…
Browse files Browse the repository at this point in the history
…ies using label's id/key (#3932)

Modify label_groups for dm_label_categories with id/key of label
  • Loading branch information
sooahleex authored Sep 5, 2024
1 parent 53a7d9a commit c3749e3
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions src/otx/core/types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,32 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
"""Get label tree edges information. Each edges represent [child, parent]."""
return [[item.name, item.parent] for item in dm_label_items if item.parent != ""]

all_groups = [label_group.labels for label_group in dm_label_categories.label_groups]
def convert_labels_if_needed(
dm_label_categories: LabelCategories,
label_names: list[str],
) -> list[list[str]]:
# Check if the labels need conversion and create name to ID mapping if required
name_to_id_mapping = None
for label_group in dm_label_categories.label_groups:
if label_group.labels and label_group.labels[0] not in label_names:
name_to_id_mapping = {
attr[len("__name__") :]: category.name
for category in dm_label_categories.items
for attr in category.attributes
if attr.startswith("__name__")
}
break

# If mapping exists, update the labels
if name_to_id_mapping:
for label_group in dm_label_categories.label_groups:
label_group.labels = [name_to_id_mapping.get(label, label) for label in label_group.labels]

# Retrieve all label groups after conversion
return [group.labels for group in dm_label_categories.label_groups]

label_names = [item.name for item in dm_label_categories.items]
all_groups = convert_labels_if_needed(dm_label_categories, label_names)

exclusive_group_info = get_exclusive_group_info(all_groups)
single_label_group_info = get_single_label_group_info(all_groups, exclusive_group_info["num_multiclass_heads"])
Expand All @@ -240,7 +265,7 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
)

return HLabelInfo(
label_names=[item.name for item in dm_label_categories.items],
label_names=label_names,
label_groups=all_groups,
num_multiclass_heads=exclusive_group_info["num_multiclass_heads"],
num_multilabel_classes=single_label_group_info["num_multilabel_classes"],
Expand Down

0 comments on commit c3749e3

Please sign in to comment.