Skip to content

Commit

Permalink
refact (exp, docs) training scales controller and add the correspondi…
Browse files Browse the repository at this point in the history
…ng tutorial (#562)
  • Loading branch information
Joker316701882 committed Aug 23, 2021
1 parent 528205b commit 15e8725
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 77 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ python tools/eval.py -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --fp16 --
<details>
<summary>Tutorials</summary>

* [Training on custom data](docs/train_custom_data.md).
* [Training on custom data](docs/train_custom_data.md)
* [Manipulating training image size](docs/manipulate_training_image_size.md)

</details>

Expand Down
59 changes: 59 additions & 0 deletions docs/manipulate_training_image_size.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Manipulating Your Training Image Size

This tutorial explains how to control your image size when training on your own data.

## 1. Introduction

There are 3 hyperparamters control the training size:

- self.input_size = (640, 640)
- self.multiscale_range = 5
- self.random_size = (14, 26)

There is 1 hyperparameter constrols the testing size:

- self.test_size = (640, 640)

The self.input_size is suggested to set to the same value as self.test_size. By default, it is set to (640, 640) for most models and (416, 416) for yolox-tiny and yolox-nano.

## 2. Multi Scale Training

When training on your custom dataset, you can use multiscale training in 2 ways:

1. **【Default】Only specifying the self.input_size and leaving others unchanged.**

If so, the actual multiscale sizes range from:

[self.input_size[0] - self.multiscale_range\*32, self.input_size[0] + self.multiscale_range\*32]

For example, if you only set:

```python
self.input_size = (640, 640)
```

the actual multiscale range is [640 - 5*32, 640 + 5\*32], i.e., [480, 800].

You can modify self.multiscale_range to change the multiscale range.

2. **Simultaneously specifying the self.input_size and self.random_size**

```python
self.input_size = (416, 416)
self.random_size = (10, 20)
```

In this case, the actual multiscale range is [self.random_size[0]\*32, self.random_size[1]\*32], i.e., [320, 640]

**Note: You must specify the self.input_size because it is used for initializing resize aug in dataset.**

## 3. Single Scale Training

If you want to train in a single scale. You need to specify the self.input_size and self.multiscale_range=0:

```python
self.input_size = (416, 416)
self.multiscale_range = 0
```

**DO NOT** set the self.random_size.
3 changes: 2 additions & 1 deletion exps/default/nano.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ def __init__(self):
super(Exp, self).__init__()
self.depth = 0.33
self.width = 0.25
self.scale = (0.5, 1.5)
self.input_size = (416, 416)
self.random_size = (10, 20)
self.mosaic_scale = (0.5, 1.5)
self.test_size = (416, 416)
self.mosaic_prob = 0.5
self.enable_mixup = False
Expand Down
60 changes: 0 additions & 60 deletions exps/default/yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,63 +33,3 @@ def init_yolo(M):

return self.model

def get_data_loader(self, batch_size, is_distributed, no_aug=False):
import torch.distributed as dist

from yolox.data import (
COCODataset,
DataLoader,
InfiniteSampler,
MosaicDetection,
TrainTransform,
YoloBatchSampler
)

dataset = COCODataset(
data_dir='data/COCO/',
json_file=self.train_ann,
img_size=self.input_size,
preproc=TrainTransform(
rgb_means=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
max_labels=50
),
)

dataset = MosaicDetection(
dataset,
mosaic=not no_aug,
img_size=self.input_size,
preproc=TrainTransform(
rgb_means=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
max_labels=120
),
degrees=self.degrees,
translate=self.translate,
scale=self.scale,
shear=self.shear,
perspective=self.perspective,
)

self.dataset = dataset

if is_distributed:
batch_size = batch_size // dist.get_world_size()
sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
else:
sampler = torch.utils.data.RandomSampler(self.dataset)

batch_sampler = YoloBatchSampler(
sampler=sampler,
batch_size=batch_size,
drop_last=False,
input_dimension=self.input_size,
mosaic=not no_aug
)

dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
dataloader_kwargs["batch_sampler"] = batch_sampler
train_loader = DataLoader(self.dataset, **dataloader_kwargs)

return train_loader
3 changes: 2 additions & 1 deletion exps/default/yolox_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def __init__(self):
super(Exp, self).__init__()
self.depth = 0.33
self.width = 0.375
self.scale = (0.5, 1.5)
self.input_scale = (416, 416)
self.mosaic_scale = (0.5, 1.5)
self.random_size = (10, 20)
self.test_size = (416, 416)
self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
Expand Down
3 changes: 2 additions & 1 deletion exps/example/custom/nano.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def __init__(self):
super(Exp, self).__init__()
self.depth = 0.33
self.width = 0.25
self.scale = (0.5, 1.5)
self.input_size = (416, 416)
self.mosaic_scale = (0.5, 1.5)
self.random_size = (10, 20)
self.test_size = (416, 416)
self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
Expand Down
3 changes: 2 additions & 1 deletion exps/example/yolox_voc/yolox_voc_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=Fa
preproc=TrainTransform(max_labels=120),
degrees=self.degrees,
translate=self.translate,
scale=self.scale,
mosaic_scale=self.mosaic_scale,
mixup_scale=self.mixup_scale,
shear=self.shear,
perspective=self.perspective,
enable_mixup=self.enable_mixup,
Expand Down
2 changes: 1 addition & 1 deletion yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def after_iter(self):
self.meter.clear_meters()

# random resizing
if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
if (self.progress_in_iter + 1) % 10 == 0:
self.input_size = self.exp.random_resize(
self.train_loader, self.epoch, self.rank, self.is_distributed
)
Expand Down
14 changes: 7 additions & 7 deletions yolox/data/datasets/mosaicdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class MosaicDetection(Dataset):

def __init__(
self, dataset, img_size, mosaic=True, preproc=None,
degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
shear=2.0, perspective=0.0, enable_mixup=True,
mosaic_prob=1.0, mixup_prob=1.0, *args
degrees=10.0, translate=0.1, mosaic_scale=(0.5, 1.5),
mixup_scale=(0.5, 1.5), shear=2.0, perspective=0.0,
enable_mixup=True, mosaic_prob=1.0, mixup_prob=1.0, *args
):
"""
Expand All @@ -52,8 +52,8 @@ def __init__(
preproc (func):
degrees (float):
translate (float):
scale (tuple):
mscale (tuple):
mosaic_scale (tuple):
mixup_scale (tuple):
shear (float):
perspective (float):
enable_mixup (bool):
Expand All @@ -64,10 +64,10 @@ def __init__(
self.preproc = preproc
self.degrees = degrees
self.translate = translate
self.scale = scale
self.scale = mosaic_scale
self.shear = shear
self.perspective = perspective
self.mixup_scale = mscale
self.mixup_scale = mixup_scale
self.enable_mosaic = mosaic
self.enable_mixup = enable_mixup
self.mosaic_prob = mosaic_prob
Expand Down
18 changes: 14 additions & 4 deletions yolox/exp/yolox_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def __init__(self):
# set worker to 4 for shorter dataloader init time
self.data_num_workers = 4
self.input_size = (640, 640)
self.random_size = (14, 26)
# Actual multiscale ranges: [640-5*32, 640+5*32].
# To disable multiscale training, set the
# self.multiscale_range to 0.
self.multiscale_range = 5
# You can uncomment this line to specify a multiscale range
# self.random_size = (14, 26)
self.data_dir = None
self.train_ann = "instances_train2017.json"
self.val_ann = "instances_val2017.json"
Expand All @@ -35,8 +40,8 @@ def __init__(self):
self.mixup_prob = 1.0
self.degrees = 10.0
self.translate = 0.1
self.scale = (0.1, 2)
self.mscale = (0.8, 1.6)
self.mosaic_scale = (0.1, 2)
self.mixup_scale = (0.5, 1.5)
self.shear = 2.0
self.perspective = 0.0
self.enable_mixup = True
Expand Down Expand Up @@ -116,7 +121,8 @@ def get_data_loader(
preproc=TrainTransform(max_labels=120),
degrees=self.degrees,
translate=self.translate,
scale=self.scale,
mosaic_scale=self.mosaic_scale,
mixup_scale=self.mixup_scale,
shear=self.shear,
perspective=self.perspective,
enable_mixup=self.enable_mixup,
Expand Down Expand Up @@ -154,6 +160,10 @@ def random_resize(self, data_loader, epoch, rank, is_distributed):

if rank == 0:
size_factor = self.input_size[1] * 1.0 / self.input_size[0]
if not hasattr(self, 'random_size'):
min_size = int(self.input_size[0] / 32) - self.multiscale_range
max_size = int(self.input_size[0] / 32) + self.multiscale_range
self.random_size = (min_size, max_size)
size = random.randint(*self.random_size)
size = (int(32 * size), 32 * int(size * size_factor))
tensor[0] = size[0]
Expand Down

0 comments on commit 15e8725

Please sign in to comment.