Skip to content

Commit

Permalink
updated predictor working
Browse files Browse the repository at this point in the history
  • Loading branch information
Benteng Ma committed Jun 5, 2024
1 parent 183415a commit 7968946
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,6 @@ def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
mask_list = [pred_masks[i, :, :] for i in range(pred_masks.shape[0])]
pred_classes = pred_classes.detach().squeeze(0).numpy()
class_list = [pred_classes[i].item() for i in range(pred_classes.shape[0])]
# print(rgb_image)
print(mean_val)
print(pred_classes)
mask_dict = {}
for i, mask in enumerate(mask_list):
mask_dict[self.categories_and_attributes.mask_categories[i]] = mask
Expand Down Expand Up @@ -479,8 +476,8 @@ def predict_frame(

# results from two dictionaries are currently merged but might got separated again in the future if needed.
result = {
'attributes': rst_person['attributes'] + rst_person['description'],
'description': rst_cloth['attributes'] + rst_cloth['description'],
'attributes': {**rst_person['attributes'], **rst_cloth['attributes']},
'description': rst_person['description'] + rst_cloth['description'],
}

result = json.dumps(result, indent=4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class CelebAMaskHQCategoriesAndAttributes(CategoriesAndAttributes):

class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
mask_categories = [
'top', 'down', 'outwear', 'dress',
'short sleeve top', 'long sleeve top', 'short sleeve outwear',
'long sleeve outwear', 'vest', 'sling', 'shorts',
'trousers', 'skirt', 'short sleeve dress',
Expand All @@ -189,7 +190,7 @@ class DeepFashion2GeneralizedCategoriesAndAttributes(CategoriesAndAttributes):
for key in mask_categories:
if key not in _categories_to_merge:
merged_categories[key] = [key]
mask_labels = []
mask_labels = ['top', 'down', 'outwear', 'dress',]
selective_attributes = {}
plane_attributes = []
avoided_attributes = []
Expand Down

0 comments on commit 7968946

Please sign in to comment.