diff --git a/README.md b/README.md index 11cae0891..3b498806f 100644 --- a/README.md +++ b/README.md @@ -186,7 +186,8 @@ python tools/eval.py -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --fp16 --
Tutorials -* [Training on custom data](docs/train_custom_data.md). +* [Training on custom data](docs/train_custom_data.md) +* [Manipulating training image size](docs/manipulate_training_image_size.md)
diff --git a/docs/manipulate_training_image_size.md b/docs/manipulate_training_image_size.md new file mode 100644 index 000000000..a73e7f15d --- /dev/null +++ b/docs/manipulate_training_image_size.md @@ -0,0 +1,59 @@ +# Manipulating Your Training Image Size + +This tutorial explains how to control your image size when training on your own data. + +## 1. Introduction + +There are 3 hyperparameters that control the training size: + +- self.input_size = (640, 640) +- self.multiscale_range = 5 +- self.random_size = (14, 26) + +There is 1 hyperparameter that controls the testing size: + +- self.test_size = (640, 640) + +It is suggested to set self.input_size to the same value as self.test_size. By default, self.input_size is set to (640, 640) for most models and (416, 416) for yolox-tiny and yolox-nano. + +## 2. Multi Scale Training + +When training on your custom dataset, you can use multiscale training in 2 ways: + +1. **【Default】Only specifying the self.input_size and leaving others unchanged.** + + If so, the actual multiscale sizes range over: + + [self.input_size[0] - self.multiscale_range\*32, self.input_size[0] + self.multiscale_range\*32] + + For example, if you only set: + + ```python + self.input_size = (640, 640) + ``` + + the actual multiscale range is [640 - 5\*32, 640 + 5\*32], i.e., [480, 800]. + + You can modify self.multiscale_range to change the multiscale range. + +2. **Simultaneously specifying the self.input_size and self.random_size** + + ```python + self.input_size = (416, 416) + self.random_size = (10, 20) + ``` + + In this case, the actual multiscale range is [self.random_size[0]\*32, self.random_size[1]\*32], i.e., [320, 640]. + + **Note: You must specify the self.input_size because it is used to initialize the resize augmentation in the dataset.** + +## 3. Single Scale Training + +You can also train at a single, fixed scale. 
You need to specify the self.input_size and self.multiscale_range=0: + +```python +self.input_size = (416, 416) +self.multiscale_range = 0 +``` + +**DO NOT** set the self.random_size. diff --git a/exps/default/nano.py b/exps/default/nano.py index 30105c03b..6d8bcecce 100644 --- a/exps/default/nano.py +++ b/exps/default/nano.py @@ -14,8 +14,9 @@ def __init__(self): super(Exp, self).__init__() self.depth = 0.33 self.width = 0.25 - self.scale = (0.5, 1.5) + self.input_size = (416, 416) self.random_size = (10, 20) + self.mosaic_scale = (0.5, 1.5) self.test_size = (416, 416) self.mosaic_prob = 0.5 self.enable_mixup = False diff --git a/exps/default/yolov3.py b/exps/default/yolov3.py index aee81829a..cc5951646 100644 --- a/exps/default/yolov3.py +++ b/exps/default/yolov3.py @@ -33,63 +33,3 @@ def init_yolo(M): return self.model - def get_data_loader(self, batch_size, is_distributed, no_aug=False): - import torch.distributed as dist - - from yolox.data import ( - COCODataset, - DataLoader, - InfiniteSampler, - MosaicDetection, - TrainTransform, - YoloBatchSampler - ) - - dataset = COCODataset( - data_dir='data/COCO/', - json_file=self.train_ann, - img_size=self.input_size, - preproc=TrainTransform( - rgb_means=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - max_labels=50 - ), - ) - - dataset = MosaicDetection( - dataset, - mosaic=not no_aug, - img_size=self.input_size, - preproc=TrainTransform( - rgb_means=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - max_labels=120 - ), - degrees=self.degrees, - translate=self.translate, - scale=self.scale, - shear=self.shear, - perspective=self.perspective, - ) - - self.dataset = dataset - - if is_distributed: - batch_size = batch_size // dist.get_world_size() - sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0) - else: - sampler = torch.utils.data.RandomSampler(self.dataset) - - batch_sampler = YoloBatchSampler( - sampler=sampler, - batch_size=batch_size, - drop_last=False, - 
input_dimension=self.input_size, - mosaic=not no_aug - ) - - dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True} - dataloader_kwargs["batch_sampler"] = batch_sampler - train_loader = DataLoader(self.dataset, **dataloader_kwargs) - - return train_loader diff --git a/exps/default/yolox_tiny.py b/exps/default/yolox_tiny.py index 9ea66048c..14c191315 100644 --- a/exps/default/yolox_tiny.py +++ b/exps/default/yolox_tiny.py @@ -12,7 +12,8 @@ def __init__(self): super(Exp, self).__init__() self.depth = 0.33 self.width = 0.375 - self.scale = (0.5, 1.5) + self.input_size = (416, 416) + self.mosaic_scale = (0.5, 1.5) self.random_size = (10, 20) self.test_size = (416, 416) self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0] diff --git a/exps/example/custom/nano.py b/exps/example/custom/nano.py index 804e9310e..fb10626db 100644 --- a/exps/example/custom/nano.py +++ b/exps/example/custom/nano.py @@ -14,7 +14,8 @@ def __init__(self): super(Exp, self).__init__() self.depth = 0.33 self.width = 0.25 - self.scale = (0.5, 1.5) + self.input_size = (416, 416) + self.mosaic_scale = (0.5, 1.5) self.random_size = (10, 20) self.test_size = (416, 416) self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0] diff --git a/exps/example/yolox_voc/yolox_voc_s.py b/exps/example/yolox_voc/yolox_voc_s.py index 53cdadcf1..5d9485cde 100644 --- a/exps/example/yolox_voc/yolox_voc_s.py +++ b/exps/example/yolox_voc/yolox_voc_s.py @@ -49,7 +49,8 @@ def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=Fa preproc=TrainTransform(max_labels=120), degrees=self.degrees, translate=self.translate, - scale=self.scale, + mosaic_scale=self.mosaic_scale, + mixup_scale=self.mixup_scale, shear=self.shear, perspective=self.perspective, enable_mixup=self.enable_mixup, diff --git a/yolox/core/trainer.py b/yolox/core/trainer.py index 713208d32..7b3b87026 100644 --- a/yolox/core/trainer.py +++ b/yolox/core/trainer.py @@ -248,7 
+248,7 @@ def after_iter(self): self.meter.clear_meters() # random resizing - if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0: + if (self.progress_in_iter + 1) % 10 == 0: self.input_size = self.exp.random_resize( self.train_loader, self.epoch, self.rank, self.is_distributed ) diff --git a/yolox/data/datasets/mosaicdetection.py b/yolox/data/datasets/mosaicdetection.py index b998697cb..66cfcf417 100644 --- a/yolox/data/datasets/mosaicdetection.py +++ b/yolox/data/datasets/mosaicdetection.py @@ -39,9 +39,9 @@ class MosaicDetection(Dataset): def __init__( self, dataset, img_size, mosaic=True, preproc=None, - degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5), - shear=2.0, perspective=0.0, enable_mixup=True, - mosaic_prob=1.0, mixup_prob=1.0, *args + degrees=10.0, translate=0.1, mosaic_scale=(0.5, 1.5), + mixup_scale=(0.5, 1.5), shear=2.0, perspective=0.0, + enable_mixup=True, mosaic_prob=1.0, mixup_prob=1.0, *args ): """ @@ -52,8 +52,8 @@ def __init__( preproc (func): degrees (float): translate (float): - scale (tuple): - mscale (tuple): + mosaic_scale (tuple): + mixup_scale (tuple): shear (float): perspective (float): enable_mixup (bool): @@ -64,10 +64,10 @@ def __init__( self.preproc = preproc self.degrees = degrees self.translate = translate - self.scale = scale + self.scale = mosaic_scale self.shear = shear self.perspective = perspective - self.mixup_scale = mscale + self.mixup_scale = mixup_scale self.enable_mosaic = mosaic self.enable_mixup = enable_mixup self.mosaic_prob = mosaic_prob diff --git a/yolox/exp/yolox_base.py b/yolox/exp/yolox_base.py index 94471161f..df837d375 100644 --- a/yolox/exp/yolox_base.py +++ b/yolox/exp/yolox_base.py @@ -25,7 +25,12 @@ def __init__(self): # set worker to 4 for shorter dataloader init time self.data_num_workers = 4 self.input_size = (640, 640) - self.random_size = (14, 26) + # Actual multiscale ranges: [640-5*32, 640+5*32]. 
+ # To disable multiscale training, set the + # self.multiscale_range to 0. + self.multiscale_range = 5 + # You can uncomment this line to specify a multiscale range + # self.random_size = (14, 26) self.data_dir = None self.train_ann = "instances_train2017.json" self.val_ann = "instances_val2017.json" @@ -35,8 +40,8 @@ def __init__(self): self.mixup_prob = 1.0 self.degrees = 10.0 self.translate = 0.1 - self.scale = (0.1, 2) - self.mscale = (0.8, 1.6) + self.mosaic_scale = (0.1, 2) + self.mixup_scale = (0.5, 1.5) self.shear = 2.0 self.perspective = 0.0 self.enable_mixup = True @@ -116,7 +121,8 @@ def get_data_loader( preproc=TrainTransform(max_labels=120), degrees=self.degrees, translate=self.translate, - scale=self.scale, + mosaic_scale=self.mosaic_scale, + mixup_scale=self.mixup_scale, shear=self.shear, perspective=self.perspective, enable_mixup=self.enable_mixup, @@ -154,6 +160,10 @@ def random_resize(self, data_loader, epoch, rank, is_distributed): if rank == 0: size_factor = self.input_size[1] * 1.0 / self.input_size[0] + if not hasattr(self, 'random_size'): + min_size = int(self.input_size[0] / 32) - self.multiscale_range + max_size = int(self.input_size[0] / 32) + self.multiscale_range + self.random_size = (min_size, max_size) size = random.randint(*self.random_size) size = (int(32 * size), 32 * int(size * size_factor)) tensor[0] = size[0]