Skip to content

Commit

Permalink
Merge pull request #62 from TasinIshmam/fix-multiple-object
Browse files Browse the repository at this point in the history
fix #60 - Support for images with multiple objects in Dataset class
  • Loading branch information
alankbi authored Oct 23, 2020
2 parents 480f13f + 1c391c5 commit eaeb889
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 33 deletions.
55 changes: 36 additions & 19 deletions detecto/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, label_data, image_folder=None, transform=None):
the XML label files or a CSV file containing the label data.
If a CSV file, the file should have the following columns in
order: ``filename``, ``width``, ``height``, ``class``, ``xmin``,
``ymin``, ``xmax``, and ``ymax``. See
``ymin``, ``xmax``, ``ymax`` and ``image_id``. See
:func:`detecto.utils.xml_to_csv` to generate CSV files in this
format from XML label files.
:type label_data: str
Expand Down Expand Up @@ -136,7 +136,8 @@ def __init__(self, label_data, image_folder=None, transform=None):

# Returns the length of this dataset
def __len__(self):
return len(self._csv)
# number of entries == number of unique image_ids in csv.
return len(self._csv['image_id'].unique().tolist())

# Is what allows you to index the dataset, e.g. dataset[0]
# dataset[index] returns a tuple containing the image and the targets dict
Expand All @@ -145,22 +146,29 @@ def __getitem__(self, idx):
idx = idx.tolist()

# Read in the image from the file name in the 0th column
img_name = os.path.join(self._root_dir, self._csv.iloc[idx, 0])
object_entries = self._csv.loc[self._csv['image_id'] == idx]

img_name = os.path.join(self._root_dir, object_entries.iloc[0, 0])
image = read_image(img_name)

# Read in xmin, ymin, xmax, and ymax
box = self._csv.iloc[idx, 4:]
box = torch.tensor(box).view(1, 4)
boxes = []
labels = []
for object_idx, row in object_entries.iterrows():
# Read in xmin, ymin, xmax, and ymax
box = self._csv.iloc[object_idx, 4:8]
boxes.append(box)
# Read in the labe
label = self._csv.iloc[object_idx, 3]
labels.append(label)

# Read in the label
label = self._csv.iloc[idx, 3]
boxes = torch.tensor(boxes).view(-1, 4)

targets = {'boxes': box, 'labels': label}
targets = {'boxes': boxes, 'labels': labels}

# Perform transformations
if self.transform:
width = self._csv.loc[idx, 'width']
height = self._csv.loc[idx, 'height']
width = object_entries.iloc[0, 1]
height = object_entries.iloc[0, 2]

# Apply the transforms manually to be able to deal with
# transforms like Resize or RandomHorizontalFlip
Expand Down Expand Up @@ -189,15 +197,20 @@ def __getitem__(self, idx):
if isinstance(t, transforms.RandomHorizontalFlip):
if random.random() < random_flip:
image = transforms.RandomHorizontalFlip(1)(image)
# Flip box's x-coordinates
box[0, 0] = width - box[0, 0]
box[0, 2] = width - box[0, 2]
box[0, 0], box[0, 2] = box[0, (2, 0)]
for idx, box in enumerate(targets['boxes']):
# Flip box's x-coordinates
box[0] = width - box[0]
box[2] = width - box[2]
box[[0,2]] = box[[2,0]]
targets['boxes'][idx] = box
else:
image = t(image)

# Scale down box if necessary
targets['boxes'] = (box / scale_factor).long()
if scale_factor != 1.0:
for idx, box in enumerate(targets['boxes']):
box = (box / scale_factor).long()
targets['boxes'][idx] = box

return image, targets

Expand Down Expand Up @@ -329,6 +342,7 @@ def predict(self, images):

return results[0] if is_single_image else results


def predict_top(self, images):
"""Takes in an image or list of images and returns the top
scoring predictions for each detected label in each image.
Expand Down Expand Up @@ -568,9 +582,12 @@ def load(file, classes):
# Converts all string labels in a list of target dicts to
# their corresponding int mappings
def _convert_to_int_labels(self, targets):
for target in targets:
# Convert string labels to integer mapping
target['labels'] = torch.tensor(self._int_mapping[target['labels']]).view(1)
for idx, target in enumerate(targets):
# get all string labels for objects in a single image
labels_array = target['labels']
# convert string labels into one hot encoding
labels_int_array = [self._int_mapping[class_name] for class_name in labels_array]
target['labels'] = torch.tensor(labels_int_array)

# Sends all images and targets to the same device as the model
def _to_device(self, images, targets):
Expand Down
22 changes: 11 additions & 11 deletions detecto/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
def test_dataset():
# Test the format of the values returned by indexing the dataset
dataset = get_dataset()
assert len(dataset) == 2
assert len(dataset) == 1 # there is only one image in the dataset (label.xml)
assert isinstance(dataset[0][0], torch.Tensor)
assert isinstance(dataset[0][1], dict)
assert dataset[0][0].shape == (3, 1080, 1720)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]
assert dataset[0][1]['boxes'].shape == (1, 4)
assert dataset[0][1]['labels'] == 'start_tick'
assert dataset[0][1]['boxes'].shape == (2, 4)
assert dataset[0][1]['labels'] == ['start_tick', 'start_gate']

transform = transforms.Compose([
transforms.ToPILImage(),
Expand All @@ -29,20 +29,20 @@ def test_dataset():

# Test that the transforms are properly applied
dataset = get_dataset(transform=transform)
assert dataset[1][0].shape == (3, 108, 172)
assert torch.all(dataset[1][1]['boxes'][0] == torch.tensor([6, 41, 171, 107]))
assert dataset[0][0].shape == (3, 108, 172)
assert torch.all(dataset[0][1]['boxes'][1] == torch.tensor([6, 41, 171, 107]))

# Test works when given an XML folder
path = os.path.dirname(__file__)
input_folder = os.path.join(path, 'static')

dataset = Dataset(input_folder, input_folder)
assert len(dataset) == 2
assert len(dataset) == 1
assert dataset[0][0].shape == (3, 1080, 1720)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]

dataset = Dataset(input_folder)
assert len(dataset) == 2
assert len(dataset) == 1
assert dataset[0][0].shape == (3, 1080, 1720)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]

Expand Down Expand Up @@ -75,14 +75,14 @@ def test_dataloader():
iterations += 1

assert isinstance(data, tuple)
assert len(data) == 2
assert len(data) == 2 # data[0] = image tensor, data[1] = targets dictionary
assert isinstance(data[0], list)
assert len(data[0]) == 2
assert len(data[0]) == 1 # only one image in data[0] since label.xml contains one image only.

assert isinstance(data[0][0], torch.Tensor)
assert isinstance(data[0][1], torch.Tensor)
assert 'boxes' in dataset[0][1] and 'labels' in dataset[1][1]
assert 'boxes' in data[1][0] and 'labels' in data[1][0]

assert 'boxes' in dataset[0][1] and 'labels' in dataset[0][1]
assert iterations == 1


Expand Down
9 changes: 6 additions & 3 deletions detecto/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def xml_to_csv(xml_folder, output_file=None):
"""Converts a folder of XML label files into a pandas DataFrame and/or
CSV file, which can then be used to create a :class:`detecto.core.Dataset`
object. Each XML file should correspond to an image and contain the image
name, image size, and the names and bounding boxes of the objects in the
name, image size, image_id and the names and bounding boxes of the objects in the
image, if any. Extraneous data in the XML files will simply be ignored.
See :download:`here <../_static/example.xml>` for an example XML file.
For an image labeling tool that produces XML files in this format,
Expand All @@ -249,6 +249,7 @@ def xml_to_csv(xml_folder, output_file=None):
"""

xml_list = []
image_id = 0
# Loop through every XML file
for xml_file in glob(xml_folder + '/*.xml'):
tree = ET.parse(xml_file)
Expand All @@ -266,11 +267,13 @@ def xml_to_csv(xml_folder, output_file=None):

# Add image file name, image size, label, and box coordinates to CSV file
row = (filename, width, height, label, int(float(box[0].text)),
int(float(box[1].text)), int(float(box[2].text)), int(float(box[3].text)))
int(float(box[1].text)), int(float(box[2].text)), int(float(box[3].text)), image_id)
xml_list.append(row)

image_id += 1

# Save as a CSV file
column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax', 'image_id']
xml_df = pd.DataFrame(xml_list, columns=column_names)

if output_file is not None:
Expand Down

0 comments on commit eaeb889

Please sign in to comment.