Skip to content

Commit

Permalink
Changed attribute definitions and adjust thresholds.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benteng Ma committed Jun 17, 2024
1 parent 122784f commit a0b3b15
Showing 1 changed file with 18 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,89 +159,46 @@ class CelebAMaskHQCategoriesAndAttributes(CategoriesAndAttributes):
for key in sorted(merged_categories.keys()):
thresholds_mask[key] = 0.5
for key in attributes + mask_labels:
thresholds_pred[key] = 0.5
if key not in avoided_attributes:
thresholds_pred[key] = 0.5

# set specific thresholds:
thresholds_mask["eye_g"] = 0.25
thresholds_pred["Eyeglasses"] = 0.25
thresholds_mask["eye_g"] = 0.5
thresholds_pred["Eyeglasses"] = 0.5
thresholds_pred["Wearing_Earrings"] = 0.5
thresholds_pred["Wearing_Necklace"] = 0.5
thresholds_pred["Wearing_Necktie"] = 0.5


class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
mask_categories = [
"top",
"down",
"outwear",
"dress",
"short sleeve top",
"long sleeve top",
"short sleeve outwear",
"long sleeve outwear",
"vest",
"sling",
"shorts",
"trousers",
"skirt",
"short sleeve dress",
"long sleeve dress",
"vest dress",
"sling dress",
'short sleeve top', 'long sleeve top', 'short sleeve outwear',
'long sleeve outwear', 'vest', 'sling', 'shorts',
'trousers', 'skirt', 'short sleeve dress',
'long sleeve dress', 'vest dress', 'sling dress'
]
merged_categories = {
"top": [
"short sleeve top",
"long sleeve top",
"vest",
"sling",
],
"down": [
"shorts",
"trousers",
"skirt",
],
"outwear": [
"short sleeve outwear",
"long sleeve outwear",
],
"dress": [
"short sleeve dress",
"long sleeve dress",
"vest dress",
"sling dress",
],
'top': ['short sleeve top', 'long sleeve top', 'vest', 'sling', ],
'down': ['shorts', 'trousers', 'skirt', ],
'outwear': ['short sleeve outwear', 'long sleeve outwear', ],
'dress': ['short sleeve dress', 'long sleeve dress', 'vest dress', 'sling dress', ],
}
mask_labels = ['top', 'down', 'outwear', 'dress', ]
_categories_to_merge = []
for key in sorted(list(merged_categories.keys())):
for cat in merged_categories[key]:
_categories_to_merge.append(cat)
for key in mask_categories:
if key not in _categories_to_merge:
merged_categories[key] = [key]
mask_labels = [
"top",
"down",
"outwear",
"dress",
]
selective_attributes = {}
plane_attributes = []
avoided_attributes = []
attributes = [
"short sleeve top",
"long sleeve top",
"short sleeve outwear",
"long sleeve outwear",
"vest",
"sling",
"shorts",
"trousers",
"skirt",
"short sleeve dress",
"long sleeve dress",
"vest dress",
"sling dress",
'short sleeve top', 'long sleeve top', 'short sleeve outwear',
'long sleeve outwear', 'vest', 'sling', 'shorts',
'trousers', 'skirt', 'short sleeve dress',
'long sleeve dress', 'vest dress', 'sling dress'
]

thresholds_mask: dict[str, float] = {}
Expand All @@ -250,7 +207,6 @@ class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
# set default thresholds:
for key in sorted(merged_categories.keys()):
thresholds_mask[key] = 0.5
for key in sorted(mask_categories):
thresholds_mask[key] = 0.5
for key in attributes + mask_labels:
thresholds_pred[key] = 0.5
pass

0 comments on commit a0b3b15

Please sign in to comment.