[Feature] Support PGD Head (#964)
* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import (#839)

* [Enhance] refactor iou_neg_piecewise_sampler.py (#842)

* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* refactor iou_neg_piecewise_sampler.py

* add docstring

* modify docstring

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>

* [Feature] Add roipooling cuda ops (#843)

* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* add roipooling cuda ops

* add roi extractor

* add test_roi_extractor unittest

* Modify setup.py to install roipooling ops

* modify docstring

* remove enlarge bbox in roipoint pooling

* add_roipooling_ops

* modify docstring

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>

* [Refactor] Refactor code structure and docstrings (#803)

* refactor points_in_boxes

* Merge same functions of three boxes

* More docstring fixes and unify x/y/z size

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Add "optional" and fix "Default"

* Remove None in function param type

* Fix unittest

* Add comments for NMS functions

* Merge methods of Points

* Add unittest

* Add optional and default value

* Fix box conversion and add unittest

* Fix comments

* Add unit test

* Indent

* Fix CI

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Remove useless \\

* Add unit test for box bev

* More unit tests and refine docstrings in box_np_ops

* Fix comment

* Add deprecation warning

* [Feature] PointXYZWHLRBBoxCoder (#856)

* support PointBasedBoxCoder

* fix unittest bug

* support unittest in gpu

* support unittest in gpu

* modified docstring

* add args

* add args

* [Enhance] Change Groupfree3D config (#855)

* All mods

* PointSample

* PointSample

* [Doc] Add tutorials/data_pipeline Chinese version (#827)

* [Doc] Add tutorials/data_pipeline Chinese version

* refine doc

* Use the absolute link

* Use the absolute link

Co-authored-by: Tai-Wang <tab_wang@outlook.com>

* [Doc] Add Chinese doc for `scannet_det.md` (#836)

* Part

* Complete

* Fix comments

* Fix comments

* [Doc] Add Chinese doc for `waymo_det.md` (#859)

* Add complete translation

* Refinements

* Fix comments

* Fix a minor typo

Co-authored-by: Tai-Wang <tab_wang@outlook.com>

* Remove 2D annotations on Lyft (#867)

* Add header for files (#869)

* Add header for files

* Add header for files

* Add header for files

* Add header for files

* [fix] fix typos (#872)

* Fix 3 unworking configs (#882)

* [Fix] Fix `index.rst` for Chinese docs (#873)

* Fix index.rst for zh docs

* Change switch language

* [Fix] Centerpoint head nested list transpose  (#879)

* FIX Transpose nested lists without Numpy

* Removed unused Numpy import

* [Enhance] Update PointFusion (#791)

* update point fusion

* remove LIDAR hardcode

* move get_proj_mat_by_coord_type to utils

* fix lint

* remove todo

* fix lint

* [Doc] Add nuscenes_det.md Chinese version (#854)

* add nus chinese doc

* add nuScenes Chinese doc

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [Fix] Fix RegNet pretrained weight loading (#889)

* Fix regnet pretrained weight loading

* Remove unused file

* Fix centerpoint tta (#892)

* [Enhance] Add benchmark regression script (#808)

* Initial commit

* [Feature] Support DGCNN (v1.0.0.dev0) (#896)

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* fix typo

* fix typo

* fix typo

* del gf&fa registry (wo reuse pointnet module)

* fix typo

* add benchmark and add copyright header (for DGCNN only)

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* support dgcnn

* Change cam rot_3d_in_axis (#906)

* [Doc] Add coord sys tutorial pic and change links to dev branch (#912)

* Modify link branch and add pic

* Fix pic

* [Feature] add kitti AP40 evaluation metric (v1.0.0.dev0) (#927)

* Add citation (#901)

* [Feature] Add python3.9 in CI (#900)

* Add python3.9 in CI

* Add python3.9 in CI

* Bump to v0.17.0 (#898)

* Update README.md

* Update README_zh-CN.md

* Update version.py

* Update getting_started.md

* Update getting_started.md

* Update changelog.md

* Remove "recent" in the news

* Remove "recent" in the news

* Fix comments

* [Docs] Fix the version of sphinx (#902)

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* Fix sphinx version

* add AP40

* add unit test

* add unit test

* separate AP11 and AP40

* fix some typos

Co-authored-by: dingchang <hudingchang.vendor@sensetime.com>
Co-authored-by: Tai-Wang <tab_wang@outlook.com>

* [Feature] add smoke backbone neck (#939)

* add smoke detector and its backbone and neck

* typo fix

* fix typo

* add docstring

* fix typo

* fix comments

* fix comments

* fix comments

* fix typo

* fix typo

* fix

* fix typo

* fix docstring

* refine feature

* fix typo

* use Basemodule in Neck

* [Refactor] Refactor the transformation from image to camera coordinates (#938)

* Refactor points_img2cam

* Refine docstring

* Support array converter and add unit tests

* [Feature] FCOS3D BBox Coder (#940)

* FCOS3D BBox Coder

* Add unit tests

* Change the value from long to float/double

* Rename bbox_out as bbox

* Add comments to forward returns

* Support PGD BBox Coder

* Refine docstring

* Add uncertain l1 loss and its unit tests

* [Feature] PGD BBox Coder (#948)

* Support PGD BBox Coder

* Refine docstring

* PGD Head initialized

* Refactor init methods, fix legacy variable names

* [Feature] Support Uncertain L1 Loss (#950)

* Add uncertain l1 loss and its unit tests

* Remove mmcv.jit and refine docstrings

* [Fix] Fix visualization in KITTI dataset (#956)

* fix bug to support kitti vis

* fix

* Refine variable names and docstrings

* Add unit tests and fix some minor bugs

* Refine assertion messages

* Fix typo in the docs_zh-CN

* Use Pretrain init and remove unused init_cfg in FCOS3D

* Fix the comments for the input_modality in the dataset config

* Fix minor bugs in pgd_bbox_coder and incorrect setting for uncertain loss, use original init

* Add explanations for code_weights

* Adjust the unit test for pgd bbox coder

* Remove unused codes

* Add mono3d metric into the gather_models and fix bugs

* Involve the value assignment of loss_dict into the computing procedure

* Fix incorrect loss_depth

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: Xi Liu <75658786+xiliu8006@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>
Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com>
Co-authored-by: dingchang <hudingchang.vendor@sensetime.com>
Co-authored-by: 谢恩泽 <Johnny_ez@163.com>
Co-authored-by: Robin Karlsson <34254153+robin-karlsson0@users.noreply.github.com>
Co-authored-by: Danila Rukhovich <danrukh@gmail.com>
Co-authored-by: ChaimZhu <zhuchenming@pjlab.org.cn>
10 people authored Nov 1, 2021
1 parent 1304548 commit 36aa4cb
Showing 11 changed files with 1,580 additions and 45 deletions.
37 changes: 23 additions & 14 deletions .dev_scripts/gather_models.py
@@ -3,6 +3,16 @@
Usage:
python gather_models.py ${root_path} ${out_dir}
Example:
python gather_models.py \
work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d \
work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d
Note that before running the above command, rename the directory with the
config name if you did not use the default directory name, create
a corresponding directory 'pgd' under the above path and put the used config
into it.
"""

import argparse
@@ -36,16 +46,18 @@
RESULTS_LUT = {
'coco': ['bbox_mAP', 'segm_mAP'],
'nus': ['pts_bbox_NuScenes/NDS', 'NDS'],
-'kitti-3d-3class': [
-'KITTI/Overall_3D_moderate',
-'Overall_3D_moderate',
-],
+'kitti-3d-3class': ['KITTI/Overall_3D_moderate', 'Overall_3D_moderate'],
'kitti-3d-car': ['KITTI/Car_3D_moderate_strict', 'Car_3D_moderate_strict'],
'lyft': ['score'],
'scannet_seg': ['miou'],
's3dis_seg': ['miou'],
'scannet': ['mAP_0.50'],
-'sunrgbd': ['mAP_0.50']
+'sunrgbd': ['mAP_0.50'],
+'kitti-mono3d': [
+'img_bbox/KITTI/Car_3D_AP40_moderate_strict',
+'Car_3D_AP40_moderate_strict'
+],
+'nus-mono3d': ['img_bbox_NuScenes/NDS', 'NDS']
}
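For context, a minimal sketch of how a lookup table like RESULTS_LUT is typically consumed: the first key from the list that appears in a parsed log-json entry is taken as the headline metric. The pick_headline_metric helper and the example log entry below are hypothetical, not part of the script.

RESULTS_LUT_SKETCH = {
    'kitti-mono3d': [
        'img_bbox/KITTI/Car_3D_AP40_moderate_strict',
        'Car_3D_AP40_moderate_strict'
    ],
    'nus-mono3d': ['img_bbox_NuScenes/NDS', 'NDS'],
}

def pick_headline_metric(dataset, log_entry):
    """Return (key, value) for the first metric key present, else None."""
    for key in RESULTS_LUT_SKETCH.get(dataset, []):
        if key in log_entry:
            return key, log_entry[key]
    return None

# Hypothetical parsed line from a *.log.json file (placeholder values).
entry = {'img_bbox/KITTI/Car_3D_AP40_moderate_strict': 0.0, 'epoch': 48}
print(pick_headline_metric('kitti-mono3d', entry))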


@@ -145,15 +157,13 @@ def main():
# and parse the best performance
model_infos = []
for used_config in used_configs:
-exp_dir = osp.join(models_root, used_config)
-
# get logs
-log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]
-log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0]
+log_json_path = glob.glob(osp.join(models_root, '*.log.json'))[0]
+log_txt_path = glob.glob(osp.join(models_root, '*.log'))[0]
model_performance = get_best_results(log_json_path)
final_epoch = model_performance['epoch']
final_model = 'epoch_{}.pth'.format(final_epoch)
-model_path = osp.join(exp_dir, final_model)
+model_path = osp.join(models_root, final_model)

# skip if the model is still training
if not osp.exists(model_path):
@@ -182,7 +192,7 @@ def main():
model_name = model['config'].split('/')[-1].rstrip(
'.py') + '_' + model['model_time']
publish_model_path = osp.join(model_publish_dir, model_name)
-trained_model_path = osp.join(models_root, model['config'],
+trained_model_path = osp.join(models_root,
'epoch_{}.pth'.format(model['epochs']))

# convert model
@@ -191,11 +201,10 @@

# copy log
shutil.copy(
-osp.join(models_root, model['config'], model['log_json_path']),
+osp.join(models_root, model['log_json_path']),
osp.join(model_publish_dir, f'{model_name}.log.json'))
shutil.copy(
-osp.join(models_root, model['config'],
-model['log_json_path'].rstrip('.json')),
+osp.join(models_root, model['log_json_path'].rstrip('.json')),
osp.join(model_publish_dir, f'{model_name}.log'))

# copy config to guarantee reproducibility
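Taken together, the hunks above change gather_models.py from looking inside a per-config exp_dir to globbing logs and checkpoints directly under models_root, which matches the new docstring example that points the script at a single work directory. A rough before/after sketch with hypothetical paths:

import glob
import os.path as osp

models_root = 'work_dirs/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d'

# Before: logs lived in a per-config subdirectory of models_root.
#   exp_dir = osp.join(models_root, used_config)
#   log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]

# After: logs are expected directly under models_root.
log_json_candidates = glob.glob(osp.join(models_root, '*.log.json'))
print(log_json_candidates)  # empty list here, since the path is illustrative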
6 changes: 4 additions & 2 deletions configs/_base_/models/fcos3d.py
@@ -1,6 +1,5 @@
model = dict(
type='FCOSMono3D',
-pretrained='open-mmlab://detectron2/resnet101_caffe',
backbone=dict(
type='ResNet',
depth=101,
@@ -9,7 +8,10 @@
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
-style='caffe'),
+style='caffe',
+init_cfg=dict(
+type='Pretrained',
+checkpoint='open-mmlab://detectron2/resnet101_caffe')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
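This fcos3d.py change follows the usual MMCV/MMDetection migration away from the top-level pretrained field: the same checkpoint is now declared on the backbone through init_cfg. A minimal sketch of the two spellings (the checkpoint URL is the one from the diff; the surrounding keys are trimmed for brevity):

# Old style: pretrained weights declared on the detector.
old_model = dict(
    type='FCOSMono3D',
    pretrained='open-mmlab://detectron2/resnet101_caffe',
    backbone=dict(type='ResNet', depth=101, style='caffe'))

# New style: the backbone initializes itself from the checkpoint via init_cfg.
new_model = dict(
    type='FCOSMono3D',
    backbone=dict(
        type='ResNet',
        depth=101,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron2/resnet101_caffe')))

print(new_model['backbone']['init_cfg'])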
55 changes: 55 additions & 0 deletions configs/_base_/models/pgd.py
@@ -0,0 +1,55 @@
_base_ = './fcos3d.py'
# model settings
model = dict(
bbox_head=dict(
_delete_=True,
type='PGDHead',
num_classes=10,
in_channels=256,
stacked_convs=2,
feat_channels=256,
use_direction_classifier=True,
diff_rad_by_sin=True,
pred_attrs=True,
pred_velo=True,
pred_bbox2d=True,
pred_keypoints=False,
dir_offset=0.7854, # pi/4
strides=[8, 16, 32, 64, 128],
group_reg_dims=(2, 1, 3, 1, 2), # offset, depth, size, rot, velo
cls_branch=(256, ),
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
() # velo
),
dir_branch=(256, ),
attr_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_attr=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
norm_on_bbox=True,
centerness_on_reg=True,
center_sampling=True,
conv_bias=True,
dcn_on_last_conv=True,
use_depth_classifier=True,
depth_branch=(256, ),
depth_range=(0, 50),
depth_unit=10,
division='uniform',
depth_bins=6,
bbox_coder=dict(type='PGDBBoxCoder', code_size=9)),
test_cfg=dict(nms_pre=1000, nms_thr=0.8, score_thr=0.01, max_per_img=200))
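Since pgd.py inherits from fcos3d.py via _base_, the _delete_=True flag is what makes the PGDHead settings replace the parent bbox_head (an FCOS3D-style head in the base file) rather than being merged into it. A simplified sketch of that merge rule, assuming the usual mmcv-style config semantics; the merge helper below is a toy stand-in, not mmcv's implementation:

def merge(parent, child):
    """Toy stand-in for mmcv-style config merging with _delete_ support."""
    if child.get('_delete_', False):
        # Replace the parent dict entirely, dropping the marker key.
        return {k: v for k, v in child.items() if k != '_delete_'}
    out = dict(parent)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out

parent_head = dict(type='FCOSMono3DHead', num_classes=10, cls_branch=(256, ))
child_head = dict(_delete_=True, type='PGDHead', num_classes=10)
print(merge(parent_head, child_head))  # only the PGDHead keys survive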
127 changes: 127 additions & 0 deletions configs/pgd/pgd_r101_caffe_fpn_gn-head_3x4_4x_kitti-mono3d.py
@@ -0,0 +1,127 @@
_base_ = [
'../_base_/datasets/kitti-mono3d.py', '../_base_/models/pgd.py',
'../_base_/schedules/mmdet_schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
backbone=dict(frozen_stages=0),
neck=dict(start_level=0, num_outs=4),
bbox_head=dict(
num_classes=3,
bbox_code_size=7,
pred_attrs=False,
pred_velo=False,
pred_bbox2d=True,
use_onlyreg_proj=True,
strides=(4, 8, 16, 32),
regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 1e8)),
group_reg_dims=(2, 1, 3, 1, 16,
4), # offset, depth, size, rot, kpts, bbox2d
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
(256, ), # kpts
(256, ) # bbox2d
),
centerness_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
use_depth_classifier=True,
depth_branch=(256, ),
depth_range=(0, 70),
depth_unit=10,
division='uniform',
depth_bins=8,
pred_keypoints=True,
weight_dim=1,
loss_depth=dict(
type='UncertainSmoothL1Loss', alpha=1.0, beta=3.0,
loss_weight=1.0),
bbox_coder=dict(
type='PGDBBoxCoder',
base_depths=((28.01, 16.32), ),
base_dims=((0.8, 1.73, 0.6), (1.76, 1.73, 0.6), (3.9, 1.56, 1.6)),
code_size=7)),
# set weight 1.0 for base 7 dims (offset, depth, size, rot)
# 0.2 for 16-dim keypoint offsets and 1.0 for 4-dim 2D distance targets
train_cfg=dict(code_weight=[
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 1.0, 1.0, 1.0, 1.0
]),
test_cfg=dict(nms_pre=100, nms_thr=0.05, score_thr=0.001, max_per_img=20))

class_names = ['Pedestrian', 'Cyclist', 'Car']
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='LoadAnnotations3D',
with_bbox=True,
with_label=True,
with_attr_label=False,
with_bbox_3d=True,
with_label_3d=True,
with_bbox_depth=True),
dict(type='Resize', img_scale=(1242, 375), keep_ratio=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d', 'gt_labels_3d',
'centers2d', 'depths'
]),
]
test_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='MultiScaleFlipAug',
scale_factor=1.0,
flip=False,
transforms=[
dict(type='RandomFlip3D'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img']),
])
]
data = dict(
samples_per_gpu=3,
workers_per_gpu=3,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
# optimizer
optimizer = dict(
lr=0.001, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[32, 44])
total_epochs = 48
runner = dict(type='EpochBasedRunner', max_epochs=48)
evaluation = dict(interval=2)
checkpoint_config = dict(interval=8)
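As the comment on code_weight in this config indicates, its 27 entries follow the regression target layout of group_reg_dims: 7 base box dimensions (2 offset + 1 depth + 3 size + 1 rotation), 16 keypoint offsets (2 per projected corner), and 4 distances to the 2D box sides. A quick sanity check of that bookkeeping:

# Layout check for the code_weight list in the KITTI PGD config.
group_reg_dims = (2, 1, 3, 1, 16, 4)  # offset, depth, size, rot, kpts, bbox2d
code_weight = [1.0] * 7 + [0.2] * 16 + [1.0] * 4

assert sum(group_reg_dims) == len(code_weight) == 27
print(len(code_weight), 'weights:', code_weight[:7], '...', code_weight[-4:])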
6 changes: 4 additions & 2 deletions mmdet3d/core/bbox/coders/pgd_bbox_coder.py
@@ -1,4 +1,5 @@
import numpy as np
+import torch
from torch.nn import functional as F

from mmdet.core.bbox.builder import BBOX_CODERS
@@ -45,8 +46,9 @@ def decode_2d(self,
scale_kpts = scale[3]
# 2 dimension of offsets x 8 corners of a 3D bbox
bbox[:, self.bbox_code_size:self.bbox_code_size + 16] = \
-scale_kpts(clone_bbox[
-:, self.bbox_code_size:self.bbox_code_size + 16]).float()
+torch.tanh(scale_kpts(clone_bbox[
+:, self.bbox_code_size:self.bbox_code_size + 16]).float())

if pred_bbox2d:
scale_bbox2d = scale[-1]
# The last four dimensions are offsets to four sides of a 2D bbox
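The functional change in decode_2d above (together with the new torch import) is wrapping the scaled keypoint predictions in torch.tanh, so the 16 corner-offset channels are bounded to (-1, 1) before decoding. A minimal sketch of the effect; the tensor below is random and its shape is only illustrative:

import torch

bbox_code_size = 7
# Fake raw predictions: 2 boxes x (7 box dims + 16 keypoint offsets).
clone_bbox = torch.randn(2, bbox_code_size + 16) * 5.0

raw_kpts = clone_bbox[:, bbox_code_size:bbox_code_size + 16]
bounded_kpts = torch.tanh(raw_kpts)

# Unbounded regressions can be arbitrarily large; the tanh-squashed
# offsets always stay strictly within (-1, 1).
print(raw_kpts.abs().max().item(), bounded_kpts.abs().max().item())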
3 changes: 2 additions & 1 deletion mmdet3d/models/dense_heads/__init__.py
@@ -8,6 +8,7 @@
from .free_anchor3d_head import FreeAnchor3DHead
from .groupfree3d_head import GroupFree3DHead
from .parta2_rpn_head import PartA2RPNHead
+from .pgd_head import PGDHead
from .shape_aware_head import ShapeAwareHead
from .smoke_mono3d_head import SMOKEMono3DHead
from .ssd_3d_head import SSD3DHead
@@ -17,5 +18,5 @@
'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
-'GroupFree3DHead', 'SMOKEMono3DHead'
+'GroupFree3DHead', 'SMOKEMono3DHead', 'PGDHead'
]
35 changes: 27 additions & 8 deletions mmdet3d/models/dense_heads/anchor_free_mono3d_head.py
@@ -176,13 +176,6 @@ def __init__(
self.attr_branch = attr_branch

self._init_layers()
-if init_cfg is None:
-self.init_cfg = dict(
-type='Normal',
-layer='Conv2d',
-std=0.01,
-override=dict(
-type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))

def _init_layers(self):
"""Initialize layers of the head."""
@@ -288,8 +281,34 @@ def _init_predictor(self):
self.conv_attr = nn.Conv2d(self.attr_branch[-1], self.num_attrs, 1)

def init_weights(self):
-super().init_weights()
"""Initialize weights of the head.
We currently still use the customized defined init_weights because the
default init of DCN triggered by the init_cfg will init
conv_offset.weight, which mistakenly affects the training stability.
"""
for modules in [self.cls_convs, self.reg_convs, self.conv_cls_prev]:
for m in modules:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
for conv_reg_prev in self.conv_reg_prevs:
if conv_reg_prev is None:
continue
for m in conv_reg_prev:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
if self.use_direction_classifier:
for m in self.conv_dir_cls_prev:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
if self.pred_attrs:
for m in self.conv_attr_prev:
if isinstance(m.conv, nn.Conv2d):
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
for conv_reg in self.conv_regs:
normal_init(conv_reg, std=0.01)
if self.use_direction_classifier:
normal_init(self.conv_dir_cls, std=0.01, bias=bias_cls)
if self.pred_attrs:
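The rewritten init_weights keeps a hand-rolled normal initialization rather than relying on init_cfg because, as the new docstring explains, the default DCN init triggered through init_cfg would also re-initialize conv_offset.weight and destabilize training. A rough sketch of the underlying pattern, assuming normal_init and bias_init_with_prob from mmcv.cnn and using toy convs in place of the head's real branches:

import torch.nn as nn
from mmcv.cnn import bias_init_with_prob, normal_init

# Toy stand-ins for the head's classification and regression convs.
conv_cls = nn.Conv2d(256, 3, 1)
conv_reg = nn.Conv2d(256, 2, 1)

# Same recipe as the customized init_weights: small-std normal weights,
# plus a prior-probability bias on the classification branch so initial
# focal-loss scores start near 0.01.
bias_cls = bias_init_with_prob(0.01)
normal_init(conv_cls, std=0.01, bias=bias_cls)
normal_init(conv_reg, std=0.01)

print(conv_cls.bias.detach().mean().item(), conv_reg.weight.detach().std().item())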
