Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Support passing arguments to loss from head. #523

Merged
merged 1 commit into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions mmcls/models/heads/cls_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ def __init__(self,
self.compute_accuracy = Accuracy(topk=self.topk)
self.cal_acc = cal_acc

def loss(self, cls_score, gt_label):
def loss(self, cls_score, gt_label, **kwargs):
num_samples = len(cls_score)
losses = dict()
# compute loss
loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples)
loss = self.compute_loss(
cls_score, gt_label, avg_factor=num_samples, **kwargs)
if self.cal_acc:
# compute accuracy
acc = self.compute_accuracy(cls_score, gt_label)
Expand All @@ -55,10 +56,10 @@ def loss(self, cls_score, gt_label):
losses['loss'] = loss
return losses

def forward_train(self, cls_score, gt_label, **kwargs):
    """Compute the training losses from classification scores.

    Args:
        cls_score: Classification scores, or a tuple of stage outputs
            (only the last element is used).
        gt_label: Ground-truth labels.
        **kwargs: Extra keyword arguments forwarded to ``self.loss``
            (e.g. a per-sample ``weight`` for the loss).

    Returns:
        dict: Dictionary of loss components.
    """
    # Multi-stage backbones hand over a tuple of features; keep the last.
    score = cls_score[-1] if isinstance(cls_score, tuple) else cls_score
    return self.loss(score, gt_label, **kwargs)

def simple_test(self, cls_score):
Expand Down
4 changes: 2 additions & 2 deletions mmcls/models/heads/linear_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def simple_test(self, x):

return self.post_process(pred)

def forward_train(self, x, gt_label, **kwargs):
    """Project features through the linear head and compute the loss.

    Args:
        x: Input features, or a tuple of stage features (only the last
            element is used).
        gt_label: Ground-truth labels.
        **kwargs: Extra keyword arguments forwarded to ``self.loss``.

    Returns:
        dict: Dictionary of loss components.
    """
    # Select the final-stage feature when a tuple is provided.
    feat = x[-1] if isinstance(x, tuple) else x
    cls_score = self.fc(feat)
    return self.loss(cls_score, gt_label, **kwargs)
4 changes: 2 additions & 2 deletions mmcls/models/heads/multi_label_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def loss(self, cls_score, gt_label):
losses['loss'] = loss
return losses

def forward_train(self, cls_score, gt_label, **kwargs):
    """Compute the multi-label training losses.

    Args:
        cls_score: Classification scores, or a tuple of stage outputs
            (only the last element is used).
        gt_label: Ground-truth label tensor.
        **kwargs: Extra keyword arguments forwarded to ``self.loss``.

    Returns:
        dict: Dictionary of loss components.
    """
    if isinstance(cls_score, tuple):
        cls_score = cls_score[-1]
    # Multi-label targets must match the score dtype (e.g. float for BCE).
    targets = gt_label.type_as(cls_score)
    return self.loss(cls_score, targets, **kwargs)

def simple_test(self, x):
Expand Down
4 changes: 2 additions & 2 deletions mmcls/models/heads/multi_label_linear_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def __init__(self,

self.fc = nn.Linear(self.in_channels, self.num_classes)

def forward_train(self, x, gt_label, **kwargs):
    """Project features through the linear layer and compute the
    multi-label loss.

    Args:
        x: Input features, or a tuple of stage features (only the last
            element is used).
        gt_label: Ground-truth label tensor.
        **kwargs: Extra keyword arguments forwarded to ``self.loss``.

    Returns:
        dict: Dictionary of loss components.
    """
    feat = x[-1] if isinstance(x, tuple) else x
    # Cast targets to the feature dtype before scoring (BCE expects float).
    targets = gt_label.type_as(feat)
    cls_score = self.fc(feat)
    return self.loss(cls_score, targets, **kwargs)

def simple_test(self, x):
Expand Down
4 changes: 2 additions & 2 deletions mmcls/models/heads/stacked_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ def simple_test(self, x):

return self.post_process(pred)

def forward_train(self, x, gt_label, **kwargs):
    """Run the features through the stacked layers and compute the loss.

    Args:
        x: Input features, or a tuple of stage features (only the last
            element is used).
        gt_label: Ground-truth labels.
        **kwargs: Extra keyword arguments forwarded to ``self.loss``.

    Returns:
        dict: Dictionary of loss components.
    """
    out = x[-1] if isinstance(x, tuple) else x
    # Feed the feature through each stacked layer in order.
    for block in self.layers:
        out = block(out)
    return self.loss(out, gt_label, **kwargs)
4 changes: 2 additions & 2 deletions mmcls/models/heads/vision_transformer_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def simple_test(self, x):

return self.post_process(pred)

def forward_train(self, x, gt_label, **kwargs):
    """Score the class token and compute the training losses.

    Args:
        x: Sequence of stage outputs; the last entry is an
            ``(patch_tokens, cls_token)`` pair and only the class token
            is scored.
        gt_label: Ground-truth labels.
        **kwargs: Extra keyword arguments forwarded to ``self.loss``.

    Returns:
        dict: Dictionary of loss components.
    """
    # Unpack the final stage output; the patch tokens are discarded.
    _, cls_token = x[-1]
    cls_score = self.layers(cls_token)
    return self.loss(cls_score, gt_label, **kwargs)
7 changes: 7 additions & 0 deletions tests/test_models/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def test_cls_head(feat):
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0

# test ClsHead with weight
weight = torch.tensor([0.5, 0.5, 0.5, 0.5])

losses_ = head.forward_train(feat, fake_gt_label)
losses = head.forward_train(feat, fake_gt_label, weight=weight)
assert losses['loss'].item() == losses_['loss'].item() * 0.5


@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
def test_linear_head(feat):
Expand Down