Implementation of simplest quantile regression loss
  Summary:
    Fixes pytorch#38035
    Added functional.q1_loss & loss.Q1Loss
maxmarketit committed Oct 24, 2020
1 parent 789e935 commit 291c91b
Showing 4 changed files with 111 additions and 0 deletions.
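
For context, the loss being added is the standard pinball (quantile) loss: with error e = y - x, the per-element loss is max(q * e, (q - 1) * e), so under-prediction (e > 0) is penalized in proportion to q and over-prediction in proportion to 1 - q. A quick worked case: at q = 0.5 both branches reduce to 0.5 * |e|, so the mean-reduced loss is exactly half the L1 loss; the correctness test below relies on this factor of two.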
11 changes: 11 additions & 0 deletions test/test_nn.py
@@ -4398,6 +4398,7 @@ def _test_loss_equal_input_target_shape(self, cast):
        losses = {
            'mse_loss': lambda x, y: F.mse_loss(x, y),
            'l1_loss': lambda x, y: F.l1_loss(x, y),
            'q1_loss': lambda x, y: F.q1_loss(x, y, q=0.5),
            'smooth_l1_loss': lambda x, y: F.smooth_l1_loss(x, y),
            'kl_div': lambda x, y: F.kl_div(x, y),
            'poisson_nll_loss': lambda x, y: F.poisson_nll_loss(x, y),
@@ -7055,6 +7056,7 @@ def test_pointwise_loss_broadcast(self):
        losses = {
            'mse_loss': lambda x, y, r: F.mse_loss(x, y, reduction=r),
            'l1_loss': lambda x, y, r: F.l1_loss(x, y, reduction=r),
            'q1_loss': lambda x, y, r: F.q1_loss(x, y, q=0.5, reduction=r),
            'smooth_l1_loss': lambda x, y, r: F.smooth_l1_loss(x, y, reduction=r),
        }

@@ -7077,6 +7079,14 @@ def test_l1_loss_correct(self):
            self.assertEqual(
                torch.nn.L1Loss()(input, torch.zeros_like(input)),
                input.abs().mean())

    # at q = 0.5 the pinball loss is half the absolute error, so the
    # mean-reduced q1_loss equals half of l1_loss
    def test_q1_loss_correct(self):
        for N in range(1, 50, 10):
            input = torch.rand(N, 3, 1024, 1024)
            self.assertEqual(
                torch.nn.Q1Loss(q=0.5)(input, torch.zeros_like(input)),
                0.5 * input.abs().mean())

    def test_smoothl1loss_negative_beta_not_supported(self):
        with self.assertRaises(RuntimeError):
@@ -10453,6 +10463,7 @@ def v(fn):
        v(lambda: F.kl_div(input, input, reduction=reduction))
        v(lambda: F.smooth_l1_loss(input, input, reduction=reduction))
        v(lambda: F.l1_loss(input, input, reduction=reduction))
        v(lambda: F.q1_loss(input, input, q=0.5, reduction=reduction))
        v(lambda: F.mse_loss(input, input, reduction=reduction))
        v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
        v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
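As a quick sanity check of the q = 0.5 identity outside the test suite (a minimal sketch, assuming this patch is applied so that F.q1_loss is importable):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 3)
    y = torch.zeros_like(x)
    # max(0.5 * e, -0.5 * e) == 0.5 * |e|, so the mean-reduced q1_loss
    # at q = 0.5 is half of the mean absolute error
    assert torch.allclose(F.q1_loss(x, y, q=0.5), 0.5 * F.l1_loss(x, y))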
38 changes: 38 additions & 0 deletions torch/nn/functional.py
@@ -2605,6 +2605,44 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))


def q1_loss(input, target, q, reduction='mean'):
    # type: (Tensor, Tensor, float, str) -> Tensor
    r"""q1_loss(input, target, q, reduction='mean') -> Tensor

    Function that takes the mean element-wise q-quantile loss.

    See :class:`~torch.nn.Q1Loss` for details.
    """
    if not (target.size() == input.size()):
        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
                      "This will likely lead to incorrect results due to broadcasting. "
                      "Please ensure they have the same size.".format(target.size(), input.size()),
                      stacklevel=2)

    expanded_input, expanded_target = torch.broadcast_tensors(input, target)

    # Pinball loss: weight under-prediction by q and over-prediction by (1 - q).
    e = expanded_target - expanded_input
    loss = torch.max(q * e, (q - 1) * e)
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return torch.mean(loss)
    elif reduction == 'sum':
        return torch.sum(loss)
    else:
        raise ValueError("{} is not a valid value for reduction; "
                         "use one of 'none', 'mean', 'sum'.".format(reduction))


def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
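The asymmetry of the penalty is easiest to see on a single element (a small illustrative sketch, again assuming the patched functional API):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([0.0])
    y = torch.tensor([1.0])
    # e = y - x = 1 (under-prediction): loss = q * e
    print(F.q1_loss(x, y, q=0.9))  # tensor(0.9000)
    # e = -1 (over-prediction): loss = (q - 1) * e = 1 - q
    print(F.q1_loss(y, x, q=0.9))  # tensor(0.1000)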
3 changes: 3 additions & 0 deletions torch/nn/functional.pyi.in
@@ -251,6 +251,9 @@ def l1_loss(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ...,
            reduction: str = ...) -> Tensor: ...


def q1_loss(input: Tensor, target: Tensor, q: float, reduction: str = ...) -> Tensor: ...


def mse_loss(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., reduce: Optional[bool] = ...,
             reduction: str = ...) -> Tensor: ...

59 changes: 59 additions & 0 deletions torch/nn/modules/loss.py
@@ -92,6 +92,65 @@ def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.l1_loss(input, target, reduction=self.reduction)



class Q1Loss(_Loss):
    r"""Creates a criterion that measures the quantile (pinball) loss between each
    element in the input :math:`x` and target :math:`y` for quantile :math:`q`.

    For scalar :math:`q`, the unreduced (i.e. with :attr:`reduction` set to
    ``'none'``) loss can be described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = \max(q \cdot (y_n - x_n), (q-1) \cdot (y_n - x_n)),

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
        \end{cases}

    The sum operation still operates over all the elements, and divides by
    :math:`n`. The division by :math:`n` can be avoided if one sets
    ``reduction = 'sum'``.

    Args:
        q (float): the quantile to be estimated, in the interval :math:`(0, 1)`
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - Target: :math:`(N, *)`, same shape as the input
        - Output: scalar. If :attr:`reduction` is ``'none'``, then
          :math:`(N, *)`, same shape as the input

    Examples::

        >>> loss = nn.Q1Loss(q=0.3)
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    __constants__ = ['reduction']

    def __init__(self, q, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(Q1Loss, self).__init__(size_average, reduce, reduction)
        self.q = q

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.q1_loss(input, target, self.q, reduction=self.reduction)


class NLLLoss(_WeightedLoss):
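For completeness, a sketch of how the new module could drive a quantile regression fit; the synthetic data and hyperparameters are illustrative only, not part of the patch:

    import torch
    import torch.nn as nn

    model = nn.Linear(1, 1)
    criterion = nn.Q1Loss(q=0.9)  # fit the 0.9 conditional quantile
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

    x = torch.randn(256, 1)
    y = 2.0 * x + torch.randn(256, 1)  # noisy linear targets

    for _ in range(200):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()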
