feat: bbox prompt for sam
beniz committed Apr 28, 2023
1 parent a39c5bd commit 1fa9cae
Showing 7 changed files with 193 additions and 55 deletions.
135 changes: 107 additions & 28 deletions data/base_dataset.py
@@ -7,11 +7,20 @@
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
import torchvision.transforms.functional as F

import torch
import torchvision
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms

if torch.__version__[0] == "2":
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F2

import torchvision.transforms.functional as F

from abc import ABC, abstractmethod
import imgaug as ia
import imgaug.augmenters as iaa
import os
@@ -407,7 +416,9 @@ def get_transform_seg(
transform_list.append(RandomHorizontalFlipMask())

if not opt.dataaug_no_rotate:
transform_list.append(RandomRotationMask(degrees=0))
transform_list.append(
RandomRotationMask(degrees=0)
) # XXX: degrees is a required placeholder, unused

if opt.dataaug_affine:
raff = RandomAffineMask(degrees=0)
@@ -443,10 +454,19 @@ class ComposeMask(transforms.Compose):
>>> ])
"""

def __call__(self, img, mask):
def __call__(self, img, mask, bbox=None):
if bbox is None:
w, h = img.size
bbox = np.array([0, 0, w, h]) # sets bbox to full image size
if torch.__version__[0] == "2":
tbbox = datapoints.BoundingBox(
bbox, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=img.size
)
else:
tbbox = bbox # placeholder
for t in self.transforms:
img, mask = t(img, mask)
return img, mask
img, mask, tbbox = t(img, mask, tbbox)
return img, mask, tbbox


class GrayscaleMask(transforms.Grayscale):
@@ -465,15 +485,19 @@ class GrayscaleMask(transforms.Grayscale):
def __init__(self, num_output_channels=1):
self.num_output_channels = num_output_channels

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
"""
Args:
img (PIL Image): Image to be converted to grayscale.
Returns:
PIL Image: Randomly grayscaled image.
"""
return F.to_grayscale(img, num_output_channels=self.num_output_channels), mask
return (
F.to_grayscale(img, num_output_channels=self.num_output_channels),
mask,
bbox,
)

def __repr__(self):
return self.__class__.__name__ + "(num_output_channels={0})".format(
@@ -494,17 +518,27 @@ class ResizeMask(transforms.Resize):
``PIL.Image.BILINEAR``
"""

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
"""
Args:
img (PIL Image): Image to be scaled.
Returns:
PIL Image: Rescaled image.
"""
return F.resize(img, self.size, interpolation=self.interpolation), F.resize(
mask, self.size, interpolation=InterpolationMode.NEAREST
)

if torch.__version__[0] == "2":
return (
F.resize(img, self.size, interpolation=self.interpolation),
F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST),
F2.resize(bbox, self.size),
)
else:
return (
F.resize(img, self.size, interpolation=self.interpolation),
F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST),
[],
)


class RandomCropMask(transforms.RandomCrop):
@@ -543,7 +577,7 @@ class RandomCropMask(transforms.RandomCrop):
"""

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
"""
Args:
img (PIL Image): Image to be cropped.
@@ -567,7 +601,18 @@ def __call__(self, img, mask):

i, j, h, w = self.get_params(img, self.size)

return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
if torch.__version__[0] == "2":
return (
F.crop(img, i, j, h, w),
F.crop(mask, i, j, h, w),
F2.crop(bbox, i, j, h, w),
)
else:
return (
F.crop(img, i, j, h, w),
F.crop(mask, i, j, h, w),
[],
)


class RandomHorizontalFlipMask(transforms.RandomHorizontalFlip):
@@ -577,7 +622,7 @@ class RandomHorizontalFlipMask(transforms.RandomHorizontalFlip):
p (float): probability of the image being flipped. Default value is 0.5
"""

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
"""
Args:
img (PIL Image): Image to be flipped.
@@ -586,8 +631,11 @@ def __call__(self, img, mask):
PIL Image: Randomly flipped image.
"""
if random.random() < self.p:
return F.hflip(img), F.hflip(mask)
return img, mask
if torch.__version__[0] == "2":
return F.hflip(img), F.hflip(mask), F2.hflip(bbox)
else:
return F.hflip(img), F.hflip(mask), []
return img, mask, bbox


class ToTensorMask(transforms.ToTensor):
@@ -601,17 +649,22 @@ class ToTensorMask(transforms.ToTensor):
In the other cases, tensors are returned without scaling.
"""

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if torch.__version__[0] == "2":
bbdata = bbox.data
else:
bbdata = bbox
return (
F.to_tensor(img),
torch.from_numpy(np.array(mask, dtype=np.int64)).unsqueeze(0),
bbdata,
)


@@ -650,7 +703,7 @@ def get_params(degrees):

return angle

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
"""
Args:
img (PIL Image): Image to be rotated.
@@ -660,7 +713,18 @@ def __call__(self, img, mask):
"""
angle = random.choice([0, 90, 180, 270])

return F.rotate(img, angle), F.rotate(mask, angle, fill=(0,))
if torch.__version__[0] == "2":
return (
F.rotate(img, angle),
F.rotate(mask, angle, fill=(0,)),
F2.rotate(bbox, angle),
)
else:
return (
F.rotate(img, angle),
F.rotate(mask, angle, fill=(0,)),
[],
)


class NormalizeMask(transforms.Normalize):
@@ -679,15 +743,19 @@ class NormalizeMask(transforms.Normalize):
"""

def __call__(self, tensor_img, tensor_mask):
def __call__(self, tensor_img, tensor_mask, tensor_bbox):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor_img, self.mean, self.std, self.inplace), tensor_mask
return (
F.normalize(tensor_img, self.mean, self.std, self.inplace),
tensor_mask,
tensor_bbox,
)

def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
@@ -705,7 +773,7 @@ def set_params(self, p, translate, scale_min, scale_max, shear):
self.scale_max = scale_max
self.shear = shear

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):

if random.random() > 1.0 - self.p:
affine_params = self.get_params(
@@ -715,9 +783,20 @@ def __call__(self, img, mask):
(-self.shear, self.shear),
img.size,
)
return F.affine(img, *affine_params), F.affine(mask, *affine_params)
if torch.__version__[0] == "2":
return (
F.affine(img, *affine_params),
F.affine(mask, *affine_params),
F2.affine(bbox, *affine_params),
)
else:
return (
F.affine(img, *affine_params),
F.affine(mask, *affine_params),
[],
)
else:
return img, mask
return img, mask, bbox


def sometimes(aug):
@@ -801,11 +880,11 @@ def __init__(self, with_mask=True):
random_order=True,
)

def __call__(self, img, mask):
def __call__(self, img, mask, bbox):
tarr = self.seq(image=np.array(img))
nimg = Image.fromarray(tarr)
if self.with_mask:
return nimg, mask
return nimg, mask, bbox
else:
return nimg

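Taken together, the base_dataset.py changes turn every *Mask transform into a three-way transform over (img, mask, bbox): ComposeMask seeds a full-image XYXY box when none is given, wraps it in a torchvision v2 BoundingBox datapoint under torch 2.x so geometric transforms keep it in sync, and ToTensorMask unwraps the raw box data at the end (torch 1.x falls back to an empty list). Below is a minimal usage sketch, assuming torchvision's v2 beta API (the torchvision 0.15 era imported at the top of this file) and that the classes are importable from data.base_dataset. One subtlety: datapoints.BoundingBox expects spatial_size as (height, width) while PIL's img.size is (width, height), so the square image below sidesteps that distinction.

    import numpy as np
    from PIL import Image

    from data.base_dataset import (
        ComposeMask,
        RandomHorizontalFlipMask,
        ResizeMask,
        ToTensorMask,
    )

    # Dummy square 256x256 RGB image and single-channel mask.
    img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
    mask = Image.fromarray(np.zeros((256, 256), dtype=np.uint8))

    transform = ComposeMask(
        [
            ResizeMask(128),                  # bbox is rescaled alongside img/mask
            RandomHorizontalFlipMask(p=0.5),  # bbox is flipped whenever img/mask are
            ToTensorMask(),                   # returns the raw bbox data
        ]
    )

    # bbox=None defaults to the full image, per ComposeMask.__call__ above.
    img_t, mask_t, bbox = transform(img, mask)
    print(img_t.shape, mask_t.shape, bbox)
    # (3, 128, 128), (1, 128, 128), and under torch 2.x: tensor([0, 0, 128, 128])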
6 changes: 3 additions & 3 deletions data/online_creation.py
@@ -252,7 +252,7 @@ def crop_image(
x_min_ref += x_padding
y_max_ref += y_padding
y_min_ref += y_padding

# Let's compute crop position
# The final crop coordinates will be [x_crop:x_crop+crop_size+margin,y_crop:y_crop+crop_size+margin)

@@ -329,7 +329,7 @@ def crop_image(
y_min_ref -= y_crop

ref_bbox = [x_min_ref, y_min_ref, x_max_ref, y_max_ref]

# invert mask if required
if inverted_mask:
mask[mask > 0] = 2
@@ -346,7 +346,7 @@ def crop_image(
int(ref_bbox[2] * (output_dim + margin) / crop_size),
int(ref_bbox[3] * (output_dim + margin) / crop_size),
]

return img, mask, ref_bbox


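The online_creation.py hunks render here as whitespace-only, but they sit around the ref_bbox rescaling that the datasets below rely on: box coordinates found in crop space are mapped to the network's output resolution with int(c * (output_dim + margin) / crop_size). A tiny standalone check with illustrative values:

    # Illustrative values only; mirrors the rescaling shown in crop_image above.
    crop_size, output_dim, margin = 512, 256, 0
    ref_bbox = [100, 150, 300, 400]  # x_min, y_min, x_max, y_max in crop space
    scaled = [int(c * (output_dim + margin) / crop_size) for c in ref_bbox]
    print(scaled)  # [50, 75, 150, 200]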
14 changes: 11 additions & 3 deletions data/unaligned_labeled_mask_dataset.py
@@ -90,7 +90,9 @@ def get_img(
print(e)
return None

A, A_label_mask = self.transform(A_img, A_label_mask)
A, A_label_mask, A_ref_bbox = self.transform(
A_img, A_label_mask
) # A_ref_bbox is the bounding box over the entire image, not the mask

if torch.any(A_label_mask > self.semantic_nclasses - 1):
warnings.warn(
@@ -107,7 +109,12 @@ def get_img(
A_label_mask[A_label_mask == 0] = 1
A_label_mask[A_label_mask == 2] = 0

result = {"A": A, "A_img_paths": A_img_path, "A_label_mask": A_label_mask}
result = {
"A": A,
"A_img_paths": A_img_path,
"A_label_mask": A_label_mask,
"A_ref_bbox": A_ref_bbox,
}

# Domain B
if B_img_path is not None:
@@ -129,7 +136,7 @@ def get_img(
% (B_label_mask_path, N_img_path)
)

B, B_label_mask = self.transform(B_img, B_label_mask)
B, B_label_mask, B_ref_bbox = self.transform(B_img, B_label_mask)
if torch.any(B_label_mask > self.semantic_nclasses - 1):
warnings.warn(
f"A label is above number of semantic classes for img {B_img_path} and label {B_label_mask_path}, label is clamped to have only {self.semantic_nclasses} classes."
@@ -150,6 +157,7 @@ def get_img(
{
"B": B,
"B_img_paths": B_img_path,
"B_ref_bbox": B_ref_bbox,
}
)
if B_label_mask_path is not None:
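With A_ref_bbox and B_ref_bbox now in the result dict, each sample carries a box prompt alongside its image and mask. The commit title points at SAM, but the SAM-facing side is not shown in the rendered diffs above, so what follows is only a hedged sketch of how such a box could feed the segment-anything predictor; the package, checkpoint path, and wiring are assumptions, not code from this repository:

    import numpy as np
    from segment_anything import SamPredictor, sam_model_registry

    # Hypothetical checkpoint path; any official SAM checkpoint would do.
    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
    predictor = SamPredictor(sam)

    def sam_mask_from_bbox(image_np, ref_bbox):
        """image_np: HxWx3 uint8 RGB array; ref_bbox: [x_min, y_min, x_max, y_max]."""
        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            box=np.array(ref_bbox, dtype=np.float32),  # XYXY box prompt
            multimask_output=False,
        )
        return masks[0]  # boolean HxW mask for the single returned prediction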
13 changes: 10 additions & 3 deletions data/unaligned_labeled_mask_online_dataset.py
@@ -279,7 +279,7 @@ def get_img(
print(e, "domain A data loading for ", A_img_path)
return None

A, A_label_mask = self.transform(A_img, A_label_mask)
A, A_label_mask, A_ref_bbox = self.transform(A_img, A_label_mask, A_ref_bbox)

if torch.any(A_label_mask > self.semantic_nclasses - 1):
warnings.warn(
@@ -290,7 +290,12 @@ def get_img(
if self.opt.f_s_all_classes_as_one:
A_label_mask = (A_label_mask >= 1) * 1

result = {"A": A, "A_img_paths": A_img_path, "A_label_mask": A_label_mask, "A_ref_bbox": A_ref_bbox}
result = {
"A": A,
"A_img_paths": A_img_path,
"A_label_mask": A_label_mask,
"A_ref_bbox": A_ref_bbox,
}

# Domain B
if B_img_path is not None:
@@ -311,7 +316,9 @@ def get_img(
inverted_mask=self.opt.data_inverted_mask,
single_bbox=self.opt.data_online_single_bbox,
)
B, B_label_mask = self.transform(B_img, B_label_mask)
B, B_label_mask, B_ref_bbox = self.transform(
B_img, B_label_mask, B_ref_bbox
)

if torch.any(B_label_mask > self.semantic_nclasses - 1):
warnings.warn(
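Both dataset variants now emit the box through their result dicts, and because it reaches __getitem__ as a plain tensor under torch 2.x, default DataLoader collation batches it with no extra code (under torch 1.x the transforms return an empty list instead, so real batches may differ). A self-contained toy check; the dataset class and shapes are illustrative, not from this repository:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyMaskDataset(Dataset):
        """Stand-in for the datasets above, emitting the same dict keys."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {
                "A": torch.zeros(3, 64, 64),
                "A_label_mask": torch.zeros(1, 64, 64, dtype=torch.int64),
                "A_ref_bbox": torch.tensor([0, 0, 64, 64]),  # XYXY, full image
            }

    batch = next(iter(DataLoader(ToyMaskDataset(), batch_size=4)))
    print(batch["A_ref_bbox"].shape)  # torch.Size([4, 4]): one box per sample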