Skip to content

Commit

Permalink
update test_lad_head
Browse files Browse the repository at this point in the history
  • Loading branch information
thuyngch committed Nov 16, 2021
1 parent 7732087 commit c70824e
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions tests/test_models/test_dense_heads/test_lad_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def score_samples(self, loss):
allowed_border=-1,
pos_weight=-1,
debug=False))

# since Focal Loss is not supported on CPU
self = LADHead(
num_classes=4,
in_channels=1,
Expand All @@ -55,7 +55,6 @@ def score_samples(self, loss):
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))

teacher_model = LADHead(
num_classes=4,
in_channels=1,
Expand All @@ -65,7 +64,6 @@ def score_samples(self, loss):
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))

feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
Expand Down Expand Up @@ -120,12 +118,25 @@ def score_samples(self, loss):
assert len(results) == n
assert results[0].size() == (h * w * 5, c)
assert self.with_score_voting

self = LADHead(
num_classes=4,
in_channels=1,
train_cfg=train_cfg,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
cls_scores = [torch.ones(2, 4, 5, 5)]
bbox_preds = [torch.ones(2, 4, 5, 5)]
iou_preds = [torch.ones(2, 1, 5, 5)]
mlvl_anchors = [torch.ones(2, 5 * 5, 4)]
img_shape = None
scale_factor = [0.5, 0.5]
cfg = mmcv.Config(
dict(
nms_pre=1000,
Expand All @@ -134,12 +145,12 @@ def score_samples(self, loss):
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
rescale = False
self._get_bboxes(
cls_scores,
bbox_preds,
iou_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=rescale)
self.get_bboxes(
cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale)


# ------------------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------------------
if __name__ == '__main__':
test_lad_head_loss()

0 comments on commit c70824e

Please sign in to comment.