Error when training ssd512 to detect person(one class). #1300

Closed
LifeIsSoSolong opened this issue Aug 30, 2019 · 3 comments

Comments


LifeIsSoSolong commented Aug 30, 2019

Thanks for your contribution.

I extracted the "person" class data from COCO 2017 to train a person detector.
But when training ssd_512_coco, I encounter the following error:

python tools/train.py configs/zkk/ssd512_coco_person.py --work_dir zkk_workdir/
2019-08-30 10:21:50,948 - INFO - Distributed training: False
2019-08-30 10:21:51,109 - INFO - load model from: open-mmlab://vgg16_caffe
2019-08-30 10:21:51,140 - WARNING - The model and loaded state dict do not match exactly
missing keys in source state_dict: extra.8.bias, extra.1.weight, extra.6.bias, extra.0.bias, extra.3.bias, extra.6.weight, extra.0.weight, extra.8.weight, extra.4.weight, extra.5.bias, extra.2.bias, extra.1.bias, extra.7.weight, l2_norm.weight, extra.4.bias, extra.3.weight, extra.9.weight, extra.2.weight, extra.7.bias, extra.9.bias, extra.5.weight

loading annotations into memory...
Done (t=1.10s)
creating index...
index created!
2019-08-30 10:21:54,372 - INFO - Start running, host: kaikai@kaikai, work_dir: /home/kaikai/anaconda3/envs/open-mmlab/mmdetection/zkk_workdir
2019-08-30 10:21:54,372 - INFO - workflow: [('train', 1)], max: 12 epochs
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [0,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [1,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [2,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [3,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [4,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [5,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [6,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [7,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [8,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCReduceAll.cuh line=327 error=59 : device-side assert triggered
Traceback (most recent call last):
  File "tools/train.py", line 126, in <module>
    main()
  File "tools/train.py", line 122, in main
    logger=logger)
  File "/home/kaikai/anaconda3/envs/open-mmlab/mmdetection/mmdet/apis/train.py", line 60, in train_detector
    _non_dist_train(model, dataset, cfg, validate=validate)
  File "/home/kaikai/anaconda3/envs/open-mmlab/mmdetection/mmdet/apis/train.py", line 221, in _non_dist_train
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/runner.py", line 358, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/runner.py", line 271, in train
    self.call_hook('after_train_iter')
  File "/home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/runner.py", line 229, in call_hook
    getattr(hook, fn_name)(self)
  File "/home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/hooks/optimizer.py", line 17, in after_train_iter
    runner.outputs['loss'].backward()
  File "/home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True) # allow_unreachable flag
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/THCReduceAll.cuh:327
terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: device-side assert triggered (insert_events at /opt/conda/conda-bld/pytorch_1556653114079/work/c10/cuda/CUDACachingAllocator.cpp:564)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7ff618b29dc5 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x14792 (0x7ff615a07792 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::TensorImpl::release_resources() + 0x50 (0x7ff618b19640 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #3: <unknown function> + 0x3067fb (0x7ff6161267fb in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #4: <unknown function> + 0x36fc50 (0x7ff61618fc50 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #5: <unknown function> + 0x3095ea (0x7ff6161295ea in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #6: torch::autograd::deleteFunction(torch::autograd::Function*) + 0xa2 (0x7ff6161296a2 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #7: std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() + 0xa2 (0x7ff647b226f2 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x14024b (0x7ff647b4624b in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x1402b9 (0x7ff647b462b9 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #10: torch::autograd::Variable::Impl::release_resources() + 0x1b (0x7ff616755a6b in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch.so.1)
frame #11: <unknown function> + 0x14019b (0x7ff647b4619b in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0x3bfc84 (0x7ff647dc5c84 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x3bfcd1 (0x7ff647dc5cd1 in /home/kaikai/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)

frame #29: __libc_start_main + 0xf0 (0x7ff6571ac830 in /lib/x86_64-linux-gnu/libc.so.6)

It seems that my coco_person data has a bug. However, when I train faster_rcnn on the same coco_person data, it works normally.
So I think the bug may be in my ssd512_coco_person config file, but I'm sure I changed this file and the faster_rcnn person config file in the same way, as follows:

1. num_classes = 2
2. dataset_type = 'MyDataset' (I created the class MyDataset in datasets/mydataset.py)
3. data_root = 'data/coco_person/' (where my coco_person data is)
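
The device-side assertion in the log (`indexValue >= 0 && indexValue < tensor.sizes[dim]`) fires when an index falls outside the valid range of a tensor dimension. One thing that may be worth ruling out is a ground-truth label that is not smaller than `num_classes`. The sketch below checks a COCO-style annotation dict for that. It assumes labels are assigned as "position in the sorted category list plus one", with 0 reserved for background; that mapping is an assumption about how `MyDataset` behaves, and `find_out_of_range_labels` is a hypothetical helper, not part of mmdetection.

```python
def find_out_of_range_labels(coco, num_classes):
    """Return (annotation id, category id, label) for labels outside [1, num_classes)."""
    # Assumed mapping: i-th category (sorted by id) -> label i + 1; 0 is background.
    cat_ids = sorted(cat["id"] for cat in coco["categories"])
    cat2label = {cat_id: i + 1 for i, cat_id in enumerate(cat_ids)}
    bad = []
    for ann in coco["annotations"]:
        label = cat2label.get(ann["category_id"])
        if label is None or not 1 <= label < num_classes:
            bad.append((ann["id"], ann["category_id"], label))
    return bad


# Toy annotation dict: a "dog" category and box were left in by mistake.
toy = {
    "categories": [{"id": 1, "name": "person"}, {"id": 18, "name": "dog"}],
    "annotations": [
        {"id": 10, "category_id": 1},
        {"id": 11, "category_id": 18},
    ],
}
print(find_out_of_range_labels(toy, num_classes=2))
```

For a real file, load it with `json.load` and pass the resulting dict. An empty result only rules out this one cause under the stated mapping assumption.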

The modified config files for ssd512 and faster rcnn are:
1. ssd512_coco_person.py
```python
input_size = 512
model = dict(
    type='SingleStageDetector',
    pretrained='open-mmlab://vgg16_caffe',
    backbone=dict(
        type='SSDVGG',
        input_size=input_size,
        depth=16,
        with_last_pool=False,
        ceil_mode=True,
        out_indices=(3, 4),
        out_feature_indices=(22, 34),
        l2_norm_scale=20),
    neck=None,
    bbox_head=dict(
        type='SSDHead',
        input_size=input_size,
        in_channels=(512, 1024, 512, 256, 256, 256, 256),
        num_classes=2,
        anchor_strides=(8, 16, 32, 64, 128, 256, 512),
        basesize_ratio_range=(0.1, 0.9),
        anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]),
        target_means=(.0, .0, .0, .0),
        target_stds=(0.1, 0.1, 0.2, 0.2)))
cudnn_benchmark = True
train_cfg = dict(
    assigner=dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.5,
        min_pos_iou=0.,
        ignore_iof_thr=-1,
        gt_max_assign_all=False),
    smoothl1_beta=1.,
    allowed_border=-1,
    pos_weight=-1,
    neg_pos_ratio=3,
    debug=False)
test_cfg = dict(
    nms=dict(type='nms', iou_thr=0.45),
    min_bbox_size=0,
    score_thr=0.02,
    max_per_img=200)

dataset_type = 'MyDataset'
data_root = 'data/coco_person/'

img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(
        type='Expand',
        mean=img_norm_cfg['mean'],
        to_rgb=img_norm_cfg['to_rgb'],
        ratio_range=(1, 4)),
    dict(
        type='MinIoURandomCrop',
        min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
        min_crop_size=0.3),
    dict(type='Resize', img_scale=(512, 512), keep_ratio=False),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=False),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=8,
    workers_per_gpu=3,
    train=dict(
        type='RepeatDataset',
        times=5,
        dataset=dict(
            type=dataset_type,
            ann_file=data_root + 'annotations/instances_train2017.json',
            img_prefix=data_root + 'train2017/',
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))

optimizer = dict(type='SGD', lr=2e-3, momentum=0.9, weight_decay=5e-4)
optimizer_config = dict()

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[16, 22])
checkpoint_config = dict(interval=1)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
    ])

total_epochs = 24
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd512_coco'
load_from = None
resume_from = None
workflow = [('train', 1)]
```

2. faster_rcnn_r50_fpn_1x_person.py
```python
model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=2,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))

train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))

dataset_type = 'MyDataset'
data_root = 'data/coco_person/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
```

Hope for your reply.
Thanks~
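
One difference between the two configs that may be worth a look is the optimizer setup. The SSD config uses `lr=2e-3` with `imgs_per_gpu=8` and an empty `optimizer_config`, while the Faster R-CNN config enables `grad_clip`. The log also shows `Distributed training: False`, so the total batch size is probably 8. The sketch below applies the linear scaling rule to that case. It assumes the reference learning rate was tuned for 8 GPUs with 8 images each (total batch 64), which the thread does not state. `scale_lr` is a hypothetical helper. Whether the learning rate is related to the crash is not established here.

```python
def scale_lr(reference_lr, reference_batch, actual_batch):
    """Linear scaling rule: keep lr proportional to the total batch size."""
    return reference_lr * actual_batch / reference_batch


# Assumption: lr=2e-3 was tuned for 8 GPUs x 8 images (total batch 64).
# The run in this thread appears to use 1 GPU x imgs_per_gpu=8 (total batch 8).
scaled = scale_lr(2e-3, reference_batch=64, actual_batch=8)
print(f"{scaled:.6f}")
```

Under that assumption the scaled value is one eighth of the configured learning rate.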

@Kaeseknacker

I ran into the same error. Can you set interval=1 in log_config?
In my case, the loss explodes after a short time:
```
[...]

2019-08-30 14:14:55,489 - INFO - Epoch [1][32/744] lr: 0.00075, eta: 5:50:59, time: 0.800, data_time: 0.464, memory: 6358, loss_cls: 87.1313, loss_bbox: 15.3167, loss: 102.4480
2019-08-30 14:14:56,277 - INFO - Epoch [1][33/744] lr: 0.00075, eta: 5:47:22, time: 0.782, data_time: 0.457, memory: 6358, loss_cls: 1604.6262, loss_bbox: 184.1971, loss: 1788.8234
2019-08-30 14:14:57,087 - INFO - Epoch [1][34/744] lr: 0.00075, eta: 5:44:13, time: 0.810, data_time: 0.481, memory: 6358, loss_cls: 39269828.0000, loss_bbox: 7429924.5000, loss: 46699752.0000
/opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/THC/THCTensorScatterGather.cu:130: void THCudaTensor_scatterKernel(TensorInfo<Real, IndexType>, TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = 1]: block: [0,0,0], thread: [0,0,0] Assertion indexValue >= 0 && indexValue < tensor.sizes[dim] failed.

[...]
```
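
With per-iteration logging, the point where training diverges can be located automatically. The sketch below scans log lines in the format shown above and reports the first iteration whose total loss is more than a chosen factor larger than the previous one. The helper name `first_blowup`, the factor of 10, and the shortened sample lines are illustrative assumptions.

```python
import re

# Matches the "[iter/total]" counter and the final ", loss: <value>" field.
LOSS_RE = re.compile(r"\[(\d+)/\d+\].*?, loss: ([0-9.]+)")


def first_blowup(log_lines, factor=10.0):
    """Return (iteration, loss) of the first step whose loss exceeds factor x the previous loss."""
    prev = None
    for line in log_lines:
        match = LOSS_RE.search(line)
        if match is None:
            continue  # not a training-progress line
        iteration, loss = int(match.group(1)), float(match.group(2))
        if prev is not None and loss > factor * prev:
            return iteration, loss
        prev = loss
    return None


# Sample lines abbreviated from the log above (timing fields dropped).
log = [
    "Epoch [1][32/744] lr: 0.00075, loss_cls: 87.1313, loss_bbox: 15.3167, loss: 102.4480",
    "Epoch [1][33/744] lr: 0.00075, loss_cls: 1604.6262, loss_bbox: 184.1971, loss: 1788.8234",
    "Epoch [1][34/744] lr: 0.00075, loss_cls: 39269828.0000, loss_bbox: 7429924.5000, loss: 46699752.0000",
]
print(first_blowup(log))
```

On these three lines the jump is flagged at iteration 33, one step before the assertion appears in the log.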

@guoqiang0148666

I solved this problem by modifying the following files:

1. mmdet/datasets/voc.py
   CLASSES = ('person',)
2. mmdet/datasets/xml_style.py
   self.cat2label = {'person': 1}
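
When hard-coding a single class, two Python details are easy to get wrong. Parentheses alone do not create a tuple, and a single quoted string inside braces creates a set rather than a dict. A minimal illustration, using variable names that are only for demonstration:

```python
classes_wrong = ('person')    # parentheses alone: this is a plain str
classes_right = ('person',)   # the trailing comma makes a one-element tuple

print(type(classes_wrong).__name__, len(classes_wrong))
print(type(classes_right).__name__, len(classes_right))

mapping_wrong = {'person: 1'}   # one quoted string inside braces: a set
mapping_right = {'person': 1}   # key and value written separately: a dict

print(type(mapping_wrong).__name__)
print(mapping_right['person'])
```

Iterating over `classes_wrong` yields the six characters of the string, so code that loops over class names would see six "classes" instead of one.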


alaa-shubbak commented Jul 23, 2021

> I solved this problem by modifying the following files:
>
> 1. mmdet/datasets/voc.py
>    CLASSES = ('person',)
> 2. mmdet/datasets/xml_style.py
>    self.cat2label = {'person': 1}

I am facing the same error. I tried updating those files as you mentioned, but I still get this error.

By the way, my dataset is a COCO dataset, so why should I change the VOC dataset? How is the VOC dataset related to RepeatDataset?

 File "/home/alshubbak/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmdet-2.13.0-py3.7.egg/mmdet/datasets/utils.py", line 134, in _check_head
    (f'`CLASSES` in {dataset.__class__.__name__}'
AssertionError: `CLASSES` in RepeatDatasetshould be a tuple of str.Add comma if number of classes is 1 as CLASSES = (Person,)

I tried adding ('Person') to my config, as shown below:
[screenshot: ssd error]

but nothing changed. Any suggestions?
How can I understand what is going on behind RepeatDataset?
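
On the RepeatDataset question: it is a thin wrapper that makes one epoch iterate over the wrapped dataset several times (`times=5` in the SSD config earlier in this thread). The assertion message names the outermost dataset class, which is why it says RepeatDataset even though the class names come from the dataset being wrapped. The following is a simplified sketch of that idea, not mmdetection's actual code. `ToyDataset` and `ToyRepeatDataset` are hypothetical names.

```python
class ToyDataset:
    """Stand-in for a real dataset class (hypothetical)."""
    CLASSES = ('Person',)  # one-element tuple, note the trailing comma

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class ToyRepeatDataset:
    """Repeat wrapper: same samples, `times` passes per epoch."""

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        # The wrapper exposes the wrapped dataset's class names as its own.
        self.CLASSES = dataset.CLASSES
        self._ori_len = len(dataset)

    def __len__(self):
        return self.times * self._ori_len

    def __getitem__(self, idx):
        # Indices past the original length wrap around to the start.
        return self.dataset[idx % self._ori_len]


repeated = ToyRepeatDataset(ToyDataset(['a.jpg', 'b.jpg']), times=5)
print(len(repeated), repeated[7], repeated.CLASSES)
```

Under this reading, a class-name check that fails on the wrapper points back to how the class names were written for the wrapped dataset, for example `('Person')` without the trailing comma.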
