Skip to content

Commit

Permalink
Add ClothPredictor class; simplify reading of the thresholds.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benteng Ma committed Jun 17, 2024
1 parent 12ac8c3 commit 122784f
Showing 1 changed file with 46 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -347,19 +347,19 @@ def __init__(
self._thresholds_mask: list[float] = []
self._thresholds_pred: list[float] = []
for key in sorted(
list(self.categories_and_attributes.merged_categories.keys())
list(self.categories_and_attributes.thresholds_mask.keys())
):
self._thresholds_mask.append(
self.categories_and_attributes.thresholds_mask[key]
)
for attribute in self.categories_and_attributes.attributes:
if attribute not in self.categories_and_attributes.avoided_attributes:
self._thresholds_pred.append(
self.categories_and_attributes.thresholds_pred[attribute]
)
for key in sorted(
list(self.categories_and_attributes.thresholds_pred.keys())
):
self._thresholds_pred.append(
self.categories_and_attributes.thresholds_pred[key]
)

def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
mean_val = np.mean(rgb_image)
image_tensor = (
torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
)
Expand Down Expand Up @@ -391,6 +391,45 @@ def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
return image_obj


class ClothPredictor(Predictor):
    """Predictor whose outputs are keyed by a fixed list of clothing categories.

    Overrides ``predict`` so that the model's masks and class values are
    returned wrapped in an ``ImageOfCloth`` instead of the base class result.
    """

    # Names assigned, by position, to the model's mask channels and class
    # outputs.
    # NOTE(review): presumably this order must match the order of the model's
    # output channels -- confirm against the training configuration.
    _CLOTH_CATEGORIES = (
        "top",
        "down",
        "outwear",
        "dress",
        "short sleeve top",
        "long sleeve top",
        "short sleeve outwear",
        "long sleeve outwear",
        "vest",
        "sling",
        "shorts",
        "trousers",
        "skirt",
        "short sleeve dress",
        "long sleeve dress",
        "vest dress",
        "sling dress",
    )

    def predict(self, rgb_image: np.ndarray) -> ImageWithMasksAndAttributes:
        """Run the model on one image and package the result as ``ImageOfCloth``.

        Args:
            rgb_image: Image array. Its last axis is moved to the front and the
                values are divided by 255.0 before inference (assumes an
                H x W x C layout with 0-255 values -- TODO confirm).

        Returns:
            An ``ImageOfCloth`` built from the input image, a dict mapping
            category name to mask, a dict mapping category name to the model's
            class output, and ``self.categories_and_attributes``.
            NOTE(review): the annotation says ``ImageWithMasksAndAttributes``;
            presumably ``ImageOfCloth`` is a subclass -- confirm.

        Raises:
            ValueError: If the model returns fewer class values than there are
                categories.
            IndexError: If the model returns more masks than there are
                categories.
        """
        categories = self._CLOTH_CATEGORIES
        image_tensor = (
            torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        )
        # The model also returns bounding boxes; they are not used here.
        pred_masks, pred_classes, _pred_bboxes = self.model(image_tensor)

        # Apply binary erosion and dilation to the masks.
        # NOTE(review): masks are thresholded with ``_thresholds_pred`` rather
        # than ``_thresholds_mask`` -- confirm this is intentional.
        pred_masks = binary_erosion_dilation(
            pred_masks,
            thresholds=self._thresholds_pred,
            erosion_iterations=1,
            dilation_iterations=1,
        )
        pred_masks = pred_masks.detach().squeeze(0).numpy().astype(np.uint8)
        pred_classes = pred_classes.detach().squeeze(0).numpy()

        # Masks are named by position; indexing ``categories`` keeps the
        # IndexError if the model emits more masks than known categories.
        mask_dict = {
            categories[i]: pred_masks[i, :, :] for i in range(pred_masks.shape[0])
        }

        class_list = [pred_classes[i].item() for i in range(pred_classes.shape[0])]
        if len(class_list) < len(categories):
            # Fail loudly instead of leaking a bare StopIteration to callers.
            raise ValueError(
                f"Model returned {len(class_list)} class predictions, "
                f"expected at least {len(categories)}."
            )
        # Any class outputs beyond the known categories are ignored.
        attribute_dict = dict(zip(categories, class_list))

        return ImageOfCloth(
            rgb_image, mask_dict, attribute_dict, self.categories_and_attributes
        )


def load_face_classifier_model():
cat_layers = CelebAMaskHQCategoriesAndAttributes.merged_categories.keys().__len__()
segment_model = UNetWithResnetEncoder(num_classes=cat_layers)
Expand Down

0 comments on commit 122784f

Please sign in to comment.