Skip to content

Commit

Permalink
feat: add loss reduction when using metric as loss
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Feb 13, 2023
1 parent 53d176f commit f03d7f1
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions pyiqa/models/inference_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch
import torchvision as tv

from collections import OrderedDict
from pyiqa.default_model_configs import DEFAULT_CONFIGS
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.utils.img_util import imread2tensor

from pyiqa.losses.loss_util import weight_reduce_loss


class InferenceModel(torch.nn.Module):
"""Common interface for quality inference of images with default setting of each metric."""
Expand All @@ -14,6 +15,8 @@ def __init__(
self,
metric_name,
as_loss=False,
loss_weight=None,
loss_reduction='mean',
device=None,
**kwargs # Other metric options
):
Expand All @@ -33,7 +36,10 @@ def __init__(
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device

self.as_loss = as_loss
self.loss_weight = loss_weight
self.loss_reduction = loss_reduction

# =========== define metric model ===============
net_opts = OrderedDict()
Expand Down Expand Up @@ -68,4 +74,7 @@ def forward(self, target, ref=None, **kwargs):
elif self.metric_mode == 'NR':
output = self.net(target.to(self.device))

return output
if self.as_loss:
return weight_reduce_loss(output, self.loss_weight, self.loss_reduction)
else:
return output

0 comments on commit f03d7f1

Please sign in to comment.