wip: cls model works #239

Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions .gitignore
@@ -54,3 +54,5 @@ docs/.vitepress/dist

# sscma examples
examples/
work_dir/
data/
Empty file removed configs/_base_/dataset.py
Empty file.
Empty file removed configs/_base_/scheduler.py
Empty file.
18 changes: 18 additions & 0 deletions configs/models/timm_resnet50.py
@@ -0,0 +1,18 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='TimmClassifier',
        model_name='resnet50',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        pretrained=True),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    )
)
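The TimmClassifier backbone presumably delegates to timm's model factory; as a reference sketch only (an assumption about the wrapper, not code in this PR), the underlying timm call that yields the 2048-channel feature map consumed by GlobalAveragePooling and LinearClsHead would look like:

# Reference sketch, not part of the diff: the timm call such a wrapper builds on.
import timm
import torch

backbone = timm.create_model('resnet50', pretrained=True, num_classes=0,
                             global_pool='')  # keep the raw feature map
feats = backbone(torch.rand(1, 3, 224, 224))  # -> (1, 2048, 7, 7)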

4 changes: 4 additions & 0 deletions configs/resnet50_8xb32_in1k.py
@@ -0,0 +1,4 @@
_base_ = [
    'models/timm_resnet50.py', 'datasets/imagenet_bs32.py',
    'schedules/imagenet_bs256.py', '_base_/default_runtime.py'
]
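This top-level config only composes the four _base_ files. A minimal sketch of how it could be sanity-checked (assuming a standard mmengine setup and that the referenced datasets/imagenet_bs32.py and _base_/default_runtime.py exist under configs/):

# Sketch: load the composed config and inspect the merged settings.
from mmengine.config import Config

cfg = Config.fromfile('configs/resnet50_8xb32_in1k.py')
print(cfg.model.backbone.model_name)   # 'resnet50', from models/timm_resnet50.py
print(cfg.optim_wrapper.optimizer.lr)  # 0.1, from schedules/imagenet_bs256.py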
17 changes: 17 additions & 0 deletions configs/schedules/imagenet_bs256.py
@@ -0,0 +1,17 @@
# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=256)
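When LR auto-scaling is enabled at launch time, mmengine applies the linear scaling rule: the configured lr is multiplied by the actual training batch size over base_batch_size. A purely illustrative calculation (the 8xb32 config name implies 8 GPUs with 32 images each):

# Illustrative only: how the linear scaling rule would adjust the LR.
base_lr = 0.1
base_batch_size = 256       # from auto_scale_lr above
actual_batch_size = 8 * 32  # 8 GPUs x 32 images per GPU, per the 8xb32 name
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)            # 0.1 -- unchanged, since 8 * 32 == 256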

1 change: 1 addition & 0 deletions sscma/__init__.py
@@ -5,6 +5,7 @@
from .evaluation import Evaluator
from .version import __version__, version_info


runner.Evaluator = Evaluator

mmengine_minimum_version = '0.3.0'
4 changes: 3 additions & 1 deletion sscma/datasets/__init__.py
@@ -1,2 +1,4 @@
from .imagenet import ImageNet
__all__ = ['ImageNet']
from .data_preprocessor import ClsDataPreprocessor

__all__ = ['ImageNet', 'ClsDataPreprocessor']
271 changes: 271 additions & 0 deletions sscma/datasets/data_preprocessor.py
@@ -0,0 +1,271 @@


# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmengine.registry import MODELS, TRANSFORMS

from sscma.structures import (DataSample, MultiTaskDataSample,
                              batch_label_to_onehot, cat_batch_labels,
                              tensor_split)



class RandomBatchAugment:
"""Randomly choose one batch augmentation to apply.

Args:
augments (Callable | dict | list): configs of batch
augmentations.
probs (float | List[float] | None): The probabilities of each batch
augmentations. If None, choose evenly. Defaults to None.

Example:
>>> import torch
>>> import torch.nn.functional as F
>>> from mmpretrain.models import RandomBatchAugment
>>> augments_cfg = [
... dict(type='CutMix', alpha=1.),
... dict(type='Mixup', alpha=1.)
... ]
>>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
>>> imgs = torch.rand(16, 3, 32, 32)
>>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
>>> imgs, label = batch_augment(imgs, label)

.. note ::

To decide which batch augmentation will be used, it picks one of
``augments`` based on the probabilities. In the example above, the
probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
nothing is 0.2.
"""

def __init__(self, augments: Union[Callable, dict, list], probs=None):
if not isinstance(augments, (tuple, list)):
augments = [augments]

self.augments = []
for aug in augments:
if isinstance(aug, dict):
self.augments.append(TRANSFORMS.build(aug))
else:
self.augments.append(aug)

if isinstance(probs, float):
probs = [probs]

if probs is not None:
assert len(augments) == len(probs), \
'``augments`` and ``probs`` must have same lengths. ' \
f'Got {len(augments)} vs {len(probs)}.'
assert sum(probs) <= 1, \
'The total probability of batch augments exceeds 1.'
self.augments.append(None)
probs.append(1 - sum(probs))

self.probs = probs

def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
"""Randomly apply batch augmentations to the batch inputs and batch
data samples."""
aug_index = np.random.choice(len(self.augments), p=self.probs)
aug = self.augments[aug_index]

if aug is not None:
return aug(batch_input, batch_score)
else:
return batch_input, batch_score.float()



@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
"""Image pre-processor for classification tasks.

Comparing with the :class:`mmengine.model.ImgDataPreprocessor`,

1. It won't do normalization if ``mean`` is not specified.
2. It does normalization and color space conversion after stacking batch.
3. It supports batch augmentations like mixup and cutmix.

It provides the data pre-processing as follows

- Collate and move data to the target device.
- Pad inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``
- Stack inputs to batch_inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
- Do batch augmentations like Mixup and Cutmix during training.

Args:
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
R, G, B channels. Defaults to None.
pad_size_divisor (int): The size of padded image should be
divisible by ``pad_size_divisor``. Defaults to 1.
pad_value (Number): The padded pixel value. Defaults to 0.
to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
to_onehot (bool): Whether to generate one-hot format gt-labels and set
to data samples. Defaults to False.
num_classes (int, optional): The number of classes. Defaults to None.
batch_augments (dict, optional): The batch augmentations settings,
including "augments" and "probs". For more details, see
:class:`mmpretrain.models.RandomBatchAugment`.
"""

def __init__(self,
mean: Sequence[Number] = None,
std: Sequence[Number] = None,
pad_size_divisor: int = 1,
pad_value: Number = 0,
to_rgb: bool = False,
to_onehot: bool = False,
num_classes: Optional[int] = None,
batch_augments: Optional[dict] = None):
super().__init__()
self.pad_size_divisor = pad_size_divisor
self.pad_value = pad_value
self.to_rgb = to_rgb
self.to_onehot = to_onehot
self.num_classes = num_classes

if mean is not None:
assert std is not None, 'To enable the normalization in ' \
'preprocessing, please specify both `mean` and `std`.'
# Enable the normalization in preprocessing.
self._enable_normalize = True
self.register_buffer('mean',
torch.tensor(mean).view(-1, 1, 1), False)
self.register_buffer('std',
torch.tensor(std).view(-1, 1, 1), False)
else:
self._enable_normalize = False

if batch_augments:
self.batch_augments = RandomBatchAugment(**batch_augments)
if not self.to_onehot:
from mmengine.logging import MMLogger
MMLogger.get_current_instance().info(
'Because batch augmentations are enabled, the data '
'preprocessor automatically enables the `to_onehot` '
'option to generate one-hot format labels.')
self.to_onehot = True
else:
self.batch_augments = None

def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization, padding, bgr2rgb conversion and batch
augmentation based on ``BaseDataPreprocessor``.

Args:
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.

Returns:
dict: Data in the same format as the model input.
"""
inputs = self.cast_data(data['inputs'])

if isinstance(inputs, torch.Tensor):
# The branch if use `default_collate` as the collate_fn in the
# dataloader.

# ------ To RGB ------
if self.to_rgb and inputs.size(1) == 3:
inputs = inputs.flip(1)

# -- Normalization ---
inputs = inputs.float()
if self._enable_normalize:
inputs = (inputs - self.mean) / self.std

# ------ Padding -----
if self.pad_size_divisor > 1:
h, w = inputs.shape[-2:]

target_h = math.ceil(
h / self.pad_size_divisor) * self.pad_size_divisor
target_w = math.ceil(
w / self.pad_size_divisor) * self.pad_size_divisor
pad_h = target_h - h
pad_w = target_w - w
inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
self.pad_value)
else:
# The branch if use `pseudo_collate` as the collate_fn in the
# dataloader.

processed_inputs = []
for input_ in inputs:
# ------ To RGB ------
if self.to_rgb and input_.size(0) == 3:
input_ = input_.flip(0)

# -- Normalization ---
input_ = input_.float()
if self._enable_normalize:
input_ = (input_ - self.mean) / self.std

processed_inputs.append(input_)
# Combine padding and stack
inputs = stack_batch(processed_inputs, self.pad_size_divisor,
self.pad_value)

data_samples = data.get('data_samples', None)
sample_item = data_samples[0] if data_samples is not None else None

if isinstance(sample_item, DataSample):
batch_label = None
batch_score = None

if 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(gt_labels)
batch_label = batch_label.to(self.device)
if 'gt_score' in sample_item:
gt_scores = [sample.gt_score for sample in data_samples]
batch_score = torch.stack(gt_scores).to(self.device)
elif self.to_onehot and 'gt_label' in sample_item:
assert batch_label is not None, \
'Cannot generate onehot format labels because no labels.'
num_classes = self.num_classes or sample_item.get(
'num_classes')
assert num_classes is not None, \
'Cannot generate one-hot format labels because not set ' \
'`num_classes` in `data_preprocessor`.'
batch_score = batch_label_to_onehot(
batch_label, label_indices, num_classes).to(self.device)

# ----- Batch Augmentations ----
if (training and self.batch_augments is not None
and batch_score is not None):
inputs, batch_score = self.batch_augments(inputs, batch_score)

# ----- scatter labels and scores to data samples ---
if batch_label is not None:
for sample, label in zip(
data_samples, tensor_split(batch_label,
label_indices)):
sample.set_gt_label(label)
if batch_score is not None:
for sample, score in zip(data_samples, batch_score):
sample.set_gt_score(score)
elif isinstance(sample_item, MultiTaskDataSample):
data_samples = self.cast_data(data_samples)

return {'inputs': inputs, 'data_samples': data_samples}
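
A minimal usage sketch of the preprocessor added above (illustrative only: it assumes sscma.structures.DataSample supports set_gt_label chaining as in mmpretrain, and the mean/std values are just the common ImageNet statistics):

# Sketch: run the new preprocessor on a fake pseudo_collate-style batch.
import torch
from sscma.datasets import ClsDataPreprocessor
from sscma.structures import DataSample

preprocessor = ClsDataPreprocessor(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True)

data = {
    'inputs': [torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
               for _ in range(4)],
    'data_samples': [DataSample().set_gt_label(i) for i in range(4)],
}
out = preprocessor(data, training=False)
print(out['inputs'].shape)  # torch.Size([4, 3, 224, 224]), normalized floats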
30 changes: 0 additions & 30 deletions sscma/datasets/transforms.py

This file was deleted.

8 changes: 8 additions & 0 deletions sscma/datasets/transforms/__init__.py
@@ -0,0 +1,8 @@
from .processing import RandomResizedCrop, ResizeEdge, RandomFlip, CenterCrop
from .loading import LoadImageFromFile, LoadAnnotations
from .formatting import (PackInputs, PackMultiTaskInputs, Transpose,
                         NumpyToPIL, PILToNumpy, Collect)


__all__ = [
    'RandomResizedCrop', 'ResizeEdge', 'RandomFlip', 'CenterCrop',
    'LoadImageFromFile', 'LoadAnnotations',
    'PackInputs', 'PackMultiTaskInputs', 'Transpose', 'NumpyToPIL',
    'PILToNumpy', 'Collect'
]