Skip to content

Commit

Permalink
feat: conditioning for palette
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and beniz committed May 22, 2023
1 parent 4673dce commit b9854ee
Show file tree
Hide file tree
Showing 23 changed files with 568 additions and 134 deletions.
1 change: 0 additions & 1 deletion data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(self, opt, phase):
phase (str) -- can be train,test or validation.
"""
self.phase = phase
print("self.phase", self.phase)
self.opt = opt

self.use_domain_B = not "self_supervised" in self.opt.data_dataset_mode
Expand Down
10 changes: 9 additions & 1 deletion data/online_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def crop_image(
bbox_ref_id=-1,
inverted_mask=False,
single_bbox=False,
override_class=-1,
):

margin = context_pixels * 2
Expand Down Expand Up @@ -61,7 +62,12 @@ def crop_image(

for line in f:
if len(line) > 2: # to make sure the current line is a real bbox
if select_cat != -1:
if override_class != -1:
bbox = line.split()
bbox[0] = override_class
line = " ".join(bbox)

elif select_cat != -1:
bbox = line.split()
cat = int(bbox[0])
if cat != select_cat:
Expand Down Expand Up @@ -161,6 +167,7 @@ def crop_image(
mask[ymin:ymax, xmin:xmax] = np.full((ymax - ymin, xmax - xmin), cat)

if i == idx_bbox_ref:
cat_ref = cat
x_min_ref = xmin
x_max_ref = xmax
y_min_ref = ymin
Expand Down Expand Up @@ -341,6 +348,7 @@ def crop_image(

# resize ref_bbox to output_dim + margin
ref_bbox = [
cat_ref,
int(ref_bbox[0] * (output_dim + margin) / crop_size),
int(ref_bbox[1] * (output_dim + margin) / crop_size),
int(ref_bbox[2] * (output_dim + margin) / crop_size),
Expand Down
67 changes: 67 additions & 0 deletions data/self_supervised_labeled_mask_cls_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os.path
from data.unaligned_labeled_mask_cls_dataset import UnalignedLabeledMaskClsDataset
from data.online_creation import fill_mask_with_random, fill_mask_with_color
from PIL import Image
import numpy as np
import torch
import warnings


class SelfSupervisedLabeledMaskClsDataset(UnalignedLabeledMaskClsDataset):
"""
This dataset class can create paired datasets with mask labels from only one domain.
"""

def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
super().__init__(opt)

def get_img(
self,
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path=None,
B_label_mask_path=None,
B_label_cls=None,
index=None,
):
result = super().get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
)

try:

if self.opt.data_online_creation_rand_mask_A:
A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
elif self.opt.data_online_creation_color_mask_A:
A_img = fill_mask_with_color(result["A"], result["A_label_mask"], {})
else:
raise Exception(
"self supervised dataset: no self supervised method specified"
)

result.update(
{
"A": A_img,
"B": result["A"],
"B_img_paths": result["A_img_paths"],
"B_label_mask": result["A_label_mask"].clone(),
"B_label_cls": result["A_label_cls"].clone(),
}
)
except Exception as e:
print(e, "self supervised data loading")
return None

return result
69 changes: 69 additions & 0 deletions data/self_supervised_labeled_mask_cls_online_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os.path
from data.unaligned_labeled_mask_cls_online_dataset import (
UnalignedLabeledMaskClsOnlineDataset,
)
from data.online_creation import fill_mask_with_random, fill_mask_with_color
from PIL import Image
import numpy as np
import torch
import warnings


class SelfSupervisedLabeledMaskClsOnlineDataset(UnalignedLabeledMaskClsOnlineDataset):
"""
This dataset class can create paired datasets with mask labels from only one domain.
"""

def __init__(self, opt, phase):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
super().__init__(opt, phase)

def get_img(
self,
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path=None,
B_label_mask_path=None,
B_label_cls=None,
index=None,
):
result = super().get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
)

try:

if self.opt.data_online_creation_rand_mask_A:
A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
elif self.opt.data_online_creation_color_mask_A:
A_img = fill_mask_with_color(result["A"], result["A_label_mask"], {})
else:
raise Exception(
"self supervised dataset: no self supervised method specified"
)

result.update(
{
"A": A_img,
"B": result["A"],
"B_img_paths": result["A_img_paths"],
"B_label_mask": result["A_label_mask"].clone(),
"B_label_cls": result["A_label_cls"].clone(),
}
)
except Exception as e:
print(e, "self supervised data loading")
return None

return result
1 change: 1 addition & 0 deletions data/self_supervised_labeled_mask_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_img(
B_label_mask_path,
B_label_cls,
index,
clamp_semantics=False,
)

try:
Expand Down
1 change: 1 addition & 0 deletions data/self_supervised_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_img(
B_label_mask_path,
B_label_cls,
index,
clamp_semantics=False,
)

try:
Expand Down
14 changes: 7 additions & 7 deletions data/unaligned_labeled_mask_cls_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ def __init__(self, opt, phase):
assert len(label_split) == 2
self.A_label_cls.append(label_split[0])

for label in self.B_label:
label_split = label.split(" ")
assert len(label_split) == 2
self.B_label_cls.append(label_split[0])

self.A_label_mask_paths
if self.use_domain_B and hasattr(self, "B_label"):
for label in self.B_label:
label_split = label.split(" ")
assert len(label_split) == 2
self.B_label_cls.append(label_split[0])

def get_img(
self,
Expand Down Expand Up @@ -52,6 +51,7 @@ def get_img(

# TODO : check how to deal with float for regression
return_dict["A_label_cls"] = torch.tensor(int(A_label_cls))
return_dict["B_label_cls"] = torch.tensor(int(B_label_cls))
if B_label_cls is not None:
return_dict["B_label_cls"] = torch.tensor(int(B_label_cls))

return return_dict
9 changes: 3 additions & 6 deletions data/unaligned_labeled_mask_cls_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ def get_img(
if return_dict is None:
return None

# To remove
A_label_cls = 1
B_label_cls = 1

return_dict["A_label_cls"] = A_label_cls
return_dict["B_label_cls"] = B_label_cls
return_dict["A_label_cls"] = self.cat_A_ref_bbox
if hasattr(self, "B_label_cls"):
return_dict["B_label_cls"] = self.cat_B_ref_bbox

return return_dict
13 changes: 9 additions & 4 deletions data/unaligned_labeled_mask_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_img(
B_label_mask_path=None,
B_label_cls=None,
index=None,
clamp_semantics=True,
):

# Domain A
Expand All @@ -94,10 +95,9 @@ def get_img(
A_img, A_label_mask
) # A_ref_bbox is the bounding box over the entire image, not the mask

if torch.any(A_label_mask > self.semantic_nclasses - 1):
if clamp_semantics and torch.any(A_label_mask > self.semantic_nclasses - 1):
warnings.warn(
"A label is above number of semantic classes for img %s and label %s"
% (A_img_path, A_label_mask_path)
f"A label is above number of semantic classes for img {A_img_path} and label {A_label_mask_path}, label is clamped to have only {self.semantic_nclasses} classes."
)
A_label_mask = torch.clamp(A_label_mask, max=self.semantic_nclasses - 1)

Expand Down Expand Up @@ -137,13 +137,18 @@ def get_img(
)

B, B_label_mask, B_ref_bbox = self.transform(B_img, B_label_mask)
if torch.any(B_label_mask > self.semantic_nclasses - 1):

if clamp_semantics and torch.any(
B_label_mask > self.semantic_nclasses - 1
):

warnings.warn(
f"A label is above number of semantic classes for img {B_img_path} and label {B_label_mask_path}, label is clamped to have only {self.semantic_nclasses} classes."
)
B_label_mask = torch.clamp(
B_label_mask, max=self.semantic_nclasses - 1
)

else:
B = self.transform_noseg(B_img)
B_label_mask = []
Expand Down
13 changes: 11 additions & 2 deletions data/unaligned_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def get_img(
B_label_mask_path=None,
B_label_cls=None,
index=None,
clamp_semantics=True,
):
# Domain A

Expand All @@ -183,14 +184,16 @@ def get_img(
inverted_mask=self.opt.data_inverted_mask,
single_bbox=self.opt.data_online_single_bbox,
)
self.cat_A_ref_bbox = torch.tensor(A_ref_bbox[0])
A_ref_bbox = A_ref_bbox[1:]

except Exception as e:
print(e, "domain A data loading for ", A_img_path)
return None

A, A_label_mask, A_ref_bbox = self.transform(A_img, A_label_mask, A_ref_bbox)

if torch.any(A_label_mask > self.semantic_nclasses - 1):
if clamp_semantics and torch.any(A_label_mask > self.semantic_nclasses - 1):
warnings.warn(
f"A label is above number of semantic classes for img {A_img_path} and label {A_label_mask_path}, label is clamped to have only {self.semantic_nclasses} classes."
)
Expand Down Expand Up @@ -225,11 +228,17 @@ def get_img(
inverted_mask=self.opt.data_inverted_mask,
single_bbox=self.opt.data_online_single_bbox,
)

self.cat_B_ref_bbox = torch.tensor(B_ref_bbox[0])
B_ref_bbox = B_ref_bbox[1:]

B, B_label_mask, B_ref_bbox = self.transform(
B_img, B_label_mask, B_ref_bbox
)

if torch.any(B_label_mask > self.semantic_nclasses - 1):
if clamp_semantics and torch.any(
B_label_mask > self.semantic_nclasses - 1
):
warnings.warn(
f"A label is above number of semantic classes for img {B_img_path} and label {B_label_mask_path}, label is clamped to have only {self.semantic_nclasses} classes."
)
Expand Down
Loading

0 comments on commit b9854ee

Please sign in to comment.