diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py index 852a9a791d53..bc447b0f1c55 100644 --- a/python/mxnet/gluon/loss.py +++ b/python/mxnet/gluon/loss.py @@ -74,6 +74,26 @@ def _reshape_like(F, x, y): return F.reshape_like(x, y) +def _batch_mean(F, loss, batch_axis): + """Return mean on the specified batch axis, not keeping the axis""" + if is_np_array(): + axes = list(range(loss.ndim)) + del axes[batch_axis] + return F.np.mean(loss, axis=axes) + else: + return F.mean(loss, axis=batch_axis, exclude=True) + +def _batch_sum(F, loss, batch_axis): + """Return sum on the specified batch axis, not keeping the axis""" + if is_np_array(): + axes = list(range(loss.ndim)) + del axes[batch_axis] + return F.np.sum(loss, axis=axes) + else: + return F.sum(loss, axis=batch_axis, exclude=True) + + + class Loss(HybridBlock): """Base class for loss. @@ -142,16 +162,11 @@ def __init__(self, weight=1., batch_axis=0, **kwargs): super(L2Loss, self).__init__(weight, batch_axis, **kwargs) def hybrid_forward(self, F, pred, label, sample_weight=None): + square_fn = F.np.square if is_np_array() else F.square label = _reshape_like(F, label, pred) - loss = F.np.square(label - pred) if is_np_array() else F.square(label - pred) + loss = square_fn(label - pred) loss = _apply_weighting(F, loss, self._weight / 2, sample_weight) - if is_np_array(): - if F is ndarray: - return F.np.mean(loss, axis=tuple(range(1, loss.ndim))) - else: - return F.npx.batch_flatten(loss).mean(axis=1) - else: - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class L1Loss(Loss): @@ -187,16 +202,11 @@ def __init__(self, weight=None, batch_axis=0, **kwargs): super(L1Loss, self).__init__(weight, batch_axis, **kwargs) def hybrid_forward(self, F, pred, label, sample_weight=None): + abs_fn = F.np.abs if is_np_array() else F.abs label = _reshape_like(F, label, pred) - loss = F.np.abs(label - pred) if is_np_array() else F.abs(label - pred) + loss = abs_fn(label - pred) loss = _apply_weighting(F, loss, self._weight, sample_weight) - if is_np_array(): - if F is ndarray: - return F.np.mean(loss, axis=tuple(range(1, loss.ndim))) - else: - return F.npx.batch_flatten(loss).mean(axis=1) - else: - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class SigmoidBinaryCrossEntropyLoss(Loss): @@ -262,7 +272,6 @@ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs): self._from_sigmoid = from_sigmoid def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None): - label = _reshape_like(F, label, pred) if is_np_array(): relu_fn = F.npx.relu act_fn = F.npx.activation @@ -275,6 +284,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None): abs_fn = F.abs mul_fn = F.broadcast_mul log_fn = F.log + label = _reshape_like(F, label, pred) if not self._from_sigmoid: if pos_weight is None: # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x))) @@ -295,13 +305,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None): loss = -(mul_fn(log_fn(pred + eps) * label, pos_weight) + log_fn(1. - pred + eps) * (1. - label)) loss = _apply_weighting(F, loss, self._weight, sample_weight) - if is_np_array(): - if F is ndarray: - return F.np.mean(loss, axis=tuple(range(1, loss.ndim))) - else: - return F.npx.batch_flatten(loss).mean(axis=1) - else: - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss @@ -379,26 +383,20 @@ def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None, def hybrid_forward(self, F, pred, label, sample_weight=None): if is_np_array(): - log_softmax = F.npx.log_softmax - pick = F.npx.pick + log_softmax_fn = F.npx.log_softmax + pick_fn = F.npx.pick else: - log_softmax = F.log_softmax - pick = F.pick + log_softmax_fn = F.log_softmax + pick_fn = F.pick if not self._from_logits: - pred = log_softmax(pred, self._axis) + pred = log_softmax_fn(pred, self._axis) if self._sparse_label: - loss = -pick(pred, label, axis=self._axis, keepdims=True) + loss = -pick_fn(pred, label, axis=self._axis, keepdims=True) else: label = _reshape_like(F, label, pred) loss = -(pred * label).sum(axis=self._axis, keepdims=True) loss = _apply_weighting(F, loss, self._weight, sample_weight) - if is_np_array(): - if F is ndarray: - return loss.mean(axis=tuple(range(1, loss.ndim))) - else: - return F.npx.batch_flatten(loss).mean(axis=1) - else: - return loss.mean(axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) SoftmaxCELoss = SoftmaxCrossEntropyLoss @@ -472,11 +470,17 @@ def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0, self._axis = axis def hybrid_forward(self, F, pred, label, sample_weight=None): + if is_np_array(): + log_softmax_fn = F.npx.log_softmax + log_fn = F.np.log + else: + log_softmax_fn = F.log_softmax + log_fn = F.log if not self._from_logits: - pred = F.log_softmax(pred, self._axis) - loss = label * (F.log(label + 1e-12) - pred) + pred = log_softmax_fn(pred, self._axis) + loss = label * (log_fn(label + 1e-12) - pred) loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.sum(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class CTCLoss(Loss): @@ -549,14 +553,20 @@ def __init__(self, layout='NTC', label_layout='NT', weight=None, **kwargs): def hybrid_forward(self, F, pred, label, pred_lengths=None, label_lengths=None, sample_weight=None): + if is_np_array(): + swapaxes_fn = F.np.swapaxes + ctc_fn = F.npx.ctc_loss + else: + swapaxes_fn = F.swapaxes + ctc_fn = F.ctc_loss if self._layout == 'NTC': - pred = F.swapaxes(pred, 0, 1) + pred = swapaxes_fn(pred, 0, 1) if self._batch_axis == 1: - label = F.swapaxes(label, 0, 1) - loss = F.CTCLoss(pred, label, pred_lengths, label_lengths, - use_data_lengths=pred_lengths is not None, - use_label_lengths=label_lengths is not None, - blank_label='last') + label = swapaxes_fn(label, 0, 1) + loss = ctc_fn(pred, label, pred_lengths, label_lengths, + use_data_lengths=pred_lengths is not None, + use_label_lengths=label_lengths is not None, + blank_label='last') return _apply_weighting(F, loss, self._weight, sample_weight) @@ -602,12 +612,20 @@ def __init__(self, rho=1, weight=None, batch_axis=0, **kwargs): self._rho = rho def hybrid_forward(self, F, pred, label, sample_weight=None): + if is_np_array(): + abs_fn = F.np.abs + where_fn = F.np.where + square_fn = F.np.square + else: + abs_fn = F.abs + where_fn = F.where + square_fn = F.square label = _reshape_like(F, label, pred) - loss = F.abs(label - pred) - loss = F.where(loss > self._rho, loss - 0.5 * self._rho, - (0.5 / self._rho) * F.square(loss)) + loss = abs_fn(label - pred) + loss = where_fn(loss > self._rho, loss - 0.5 * self._rho, + (0.5 / self._rho) * square_fn(loss)) loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class HingeLoss(Loss): @@ -649,10 +667,11 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): self._margin = margin def hybrid_forward(self, F, pred, label, sample_weight=None): + relu_fn = F.npx.relu if is_np_array() else F.relu label = _reshape_like(F, label, pred) - loss = F.relu(self._margin - pred * label) + loss = relu_fn(self._margin - pred * label) loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class SquaredHingeLoss(Loss): @@ -694,10 +713,16 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): self._margin = margin def hybrid_forward(self, F, pred, label, sample_weight=None): + if is_np_array(): + relu_fn = F.npx.relu + square_fn = F.np.square + else: + relu_fn = F.relu + square_fn = F.square label = _reshape_like(F, label, pred) - loss = F.square(F.relu(self._margin - pred * label)) + loss = square_fn(relu_fn(self._margin - pred * label)) loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class LogisticLoss(Loss): @@ -743,14 +768,22 @@ def __init__(self, weight=None, batch_axis=0, label_format='signed', **kwargs): % label_format) def hybrid_forward(self, F, pred, label, sample_weight=None): + if is_np_array(): + relu_fn = F.npx.relu + act_fn = F.npx.activation + abs_fn = F.np.abs + else: + relu_fn = F.relu + act_fn = F.Activation + abs_fn = F.abs label = _reshape_like(F, label, pred) if self._label_format == 'signed': label = (label + 1.0) / 2.0 # Transform label to be either 0 or 1 # Use a stable formula in computation - loss = F.relu(pred) - pred * label + \ - F.Activation(-F.abs(pred), act_type='softrelu') + loss = relu_fn(pred) - pred * label + \ + act_fn(-abs_fn(pred), act_type='softrelu') loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss, axis=self._batch_axis, exclude=True) + return _batch_mean(F, loss, self._batch_axis) class TripletLoss(Loss): @@ -790,13 +823,18 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): super(TripletLoss, self).__init__(weight, batch_axis, **kwargs) self._margin = margin - def hybrid_forward(self, F, pred, positive, negative): + def hybrid_forward(self, F, pred, positive, negative, sample_weight=None): + if is_np_array(): + relu_fn = F.npx.relu + square_fn = F.np.square + else: + relu_fn = F.relu + square_fn = F.square positive = _reshape_like(F, positive, pred) negative = _reshape_like(F, negative, pred) - loss = F.sum(F.square(positive - pred) - F.square(negative - pred), - axis=self._batch_axis, exclude=True) - loss = F.relu(loss + self._margin) - return _apply_weighting(F, loss, self._weight, None) + loss = _batch_sum(F, square_fn(positive - pred) - square_fn(negative - pred), self._batch_axis) + loss = relu_fn(loss + self._margin) + return _apply_weighting(F, loss, self._weight, sample_weight) class PoissonNLLLoss(Loss): @@ -845,20 +883,26 @@ def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=Fal self._compute_full = compute_full def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08): + if is_np_array(): + exp_fn = F.np.exp + log_fn = F.np.log + else: + exp_fn = F.exp + log_fn = F.log target = _reshape_like(F, target, pred) if self._from_logits: - loss = F.exp(pred) - target * pred + loss = exp_fn(pred) - target * pred else: - loss = pred - target * F.log(pred + epsilon) + loss = pred - target * log_fn(pred + epsilon) if self._compute_full: # Using numpy's pi value stirling_factor = target * \ - F.log(target) - target + 0.5 * F.log(2 * target * np.pi) + log_fn(target) - target + 0.5 * log_fn(2 * target * np.pi) target_gt_1 = target > 1 stirling_factor *= target_gt_1 loss += stirling_factor loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss) + return _batch_mean(F, loss, self._batch_axis) class CosineEmbeddingLoss(Loss): @@ -902,33 +946,42 @@ def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs): self._margin = margin def hybrid_forward(self, F, input1, input2, label, sample_weight=None): + if is_np_array(): + where_fn = F.np.where + clip_fn = F.np.clip + else: + where_fn = F.where + clip_fn = F.clip + input1 = _reshape_like(F, input1, input2) - label = label.reshape((-1, 1)) cos_sim = self._cosine_similarity(F, input1, input2) - y_1 = label == 1 - y_minus_1 = label == -1 - cos_sim_a = (1 - cos_sim) * y_1 + label = _reshape_like(F, label, cos_sim) + loss = where_fn(label == 1, + 1 - cos_sim, + clip_fn(cos_sim - self._margin, 0, 1 - self._margin)) - if F is ndarray: - z_array = F.array([0]) - else: - z_array = F.zeros((1, 1)) - cos_sim_b = F.broadcast_maximum( - z_array, y_minus_1 * (cos_sim - self._margin), axis=1) - loss = cos_sim_a + cos_sim_b loss = _apply_weighting(F, loss, self._weight, sample_weight) - return loss + return _batch_mean(F, loss, self._batch_axis) def _cosine_similarity(self, F, x, y, axis=-1): - # Calculates the cosine similarity between 2 vectors - x_norm = F.norm(x, axis=axis).reshape((-1, 1)) - y_norm = F.norm(y, axis=axis).reshape((-1, 1)) - x_dot_y = F.sum(x * y, axis=axis).reshape((-1, 1)) - if F is ndarray: - eps_arr = F.array([1e-12]) + if is_np_array(): + reshape_fn = F.npx.reshape + norm_fn = F.npx.norm + sum_fn = F.np.sum + full_fn = F.np.full + max_fn = F.np.maximum else: - eps_arr = F.full((1, 1), 1e-12) - return (x_dot_y / F.broadcast_maximum(x_norm * y_norm, eps_arr)) + reshape_fn = F.reshape + norm_fn = F.norm + sum_fn = F.sum + full_fn = F.full + max_fn = F.broadcast_maximum + # Calculates the cosine similarity between 2 vectors + x_norm = reshape_fn(norm_fn(x, axis=axis), (-1, 1)) + y_norm = reshape_fn(norm_fn(y, axis=axis), (-1, 1)) + x_dot_y = reshape_fn(sum_fn(x * y, axis=axis), (-1, 1)) + eps_arr = full_fn((1, 1), 1e-12) + return (x_dot_y / max_fn(x_norm * y_norm, eps_arr)) class SDMLLoss(Loss): @@ -972,18 +1025,24 @@ def __init__(self, smoothing_parameter=0.3, weight=1., batch_axis=0, **kwargs): self.kl_loss = KLDivLoss(from_logits=True) self.smoothing_parameter = smoothing_parameter # Smoothing probability mass - def _compute_distances(self, x1, x2): + def _compute_distances(self, F, x1, x2): """ This function computes the euclidean distance between every vector in the two batches in input. """ + if is_np_array(): + expand_dims_fn = F.np.expand_dims + broadcast_to_fn = F.np.broadcast_to + else: + expand_dims_fn = F.expand_dims + broadcast_to_fn = F.broadcast_to # extracting sizes expecting [batch_size, dim] assert x1.shape == x2.shape batch_size, dim = x1.shape # expanding both tensor form [batch_size, dim] to [batch_size, batch_size, dim] - x1_ = x1.expand_dims(1).broadcast_to([batch_size, batch_size, dim]) - x2_ = x2.expand_dims(0).broadcast_to([batch_size, batch_size, dim]) + x1_ = broadcast_to_fn(expand_dims_fn(x1, 1), [batch_size, batch_size, dim]) + x2_ = broadcast_to_fn(expand_dims_fn(x2, 0), [batch_size, batch_size, dim]) # pointwise squared differences squared_diffs = (x1_ - x2_)**2 # sum of squared differences distance @@ -1015,7 +1074,7 @@ def _compute_labels(self, F, batch_size): return labels - def _loss(self, F, x1, x2): + def hybrid_forward(self, F, x1, x2): """ the function computes the kl divergence between the negative distances (internally it compute a softmax casting into probabilities) and the @@ -1033,15 +1092,15 @@ def _loss(self, F, x1, x2): learn to predict french president comparing it with all the other vectors in batch 2 """ + if is_np_array(): + log_softmax_fn = F.npx.log_softmax + else: + log_softmax_fn = F.log_softmax batch_size = x1.shape[0] labels = self._compute_labels(F, batch_size) - distances = self._compute_distances(x1, x2) - log_probabilities = F.log_softmax(-distances, axis=1) + distances = self._compute_distances(F, x1, x2) + log_probabilities = log_softmax_fn(-distances, axis=1) # multiply for the number of labels to obtain the correct loss (gluon kl_loss averages instead of sum) # PR#18423:multiply for the number of labels should multiply x1.shape[1] rather than x1.shape[0]) # After PR#18423, it is no need to multiply it anymore. return self.kl_loss(log_probabilities, labels.as_in_context(distances.context)) - - - def hybrid_forward(self, F, x1, x2): - return self._loss(F, x1, x2) diff --git a/src/operator/nn/ctc_loss.cc b/src/operator/nn/ctc_loss.cc index 096ef8c0d7b4..59f89f0576f3 100644 --- a/src/operator/nn/ctc_loss.cc +++ b/src/operator/nn/ctc_loss.cc @@ -50,6 +50,7 @@ DMLC_REGISTER_PARAMETER(CTCLossOpParam); NNVM_REGISTER_OP(CTCLoss) .add_alias("ctc_loss") +.add_alias("_npx_ctc_loss") .add_alias("_contrib_CTCLoss") .add_alias("_contrib_ctc_loss") .describe(R"code(Connectionist Temporal Classification Loss. diff --git a/src/operator/nn/ctc_loss.cu b/src/operator/nn/ctc_loss.cu index a4491bf6986e..c6952d33c86a 100644 --- a/src/operator/nn/ctc_loss.cu +++ b/src/operator/nn/ctc_loss.cu @@ -51,6 +51,7 @@ namespace op { NNVM_REGISTER_OP(CTCLoss) .add_alias("ctc_loss") +.add_alias("_npx_ctc_loss") .add_alias("_contrib_ctc_loss") .add_alias("_contrib_CTCLoss") .set_attr("FCompute", CTCLossOpForward); diff --git a/src/operator/tensor/broadcast_reduce_norm_value.cc b/src/operator/tensor/broadcast_reduce_norm_value.cc index 557c4d9e7746..43a22b2a8314 100644 --- a/src/operator/tensor/broadcast_reduce_norm_value.cc +++ b/src/operator/tensor/broadcast_reduce_norm_value.cc @@ -87,6 +87,7 @@ Examples:: norm(csr) = [5.47722578] )code" ADD_FILELINE) +.add_alias("_npx_norm") .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser(ParamParser) diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py index 9c1d496d715e..80d8cde5050f 100644 --- a/tests/python/unittest/test_loss.py +++ b/tests/python/unittest/test_loss.py @@ -56,16 +56,6 @@ def test_loss_ndarray(): assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4) -def get_net(num_hidden, flatten=True): - data = mx.symbol.Variable('data') - fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128, flatten=flatten) - act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") - fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64, flatten=flatten) - act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") - fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=num_hidden, flatten=flatten) - return fc3 - - @with_seed() def test_bce_equal_ce2(): N = 100 @@ -163,7 +153,7 @@ def test_cosine_loss(): denominator = mx.nd.sqrt(mx.nd.sum(input1**2, axis=1, keepdims=True)) \ * mx.nd.sqrt(mx.nd.sum(input2**2, axis=1, keepdims=True)) numpy_loss = mx.nd.where(label == 1, 1-numerator/denominator, \ - mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1)) + mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1)).reshape((-1,)) assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5) @xfail_when_nonstandard_decimal_separator @@ -186,25 +176,23 @@ def test_poisson_nllloss(): #Calculating by brute formula for default value of from_logits = True # 1) Testing for flag logits = True - brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy()) + brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy(), axis=1) loss_withlogits = Loss(pred, target) - assert_almost_equal(brute_loss, loss_withlogits.asscalar()) + assert_almost_equal(brute_loss, loss_withlogits) #2) Testing for flag logits = False loss_no_logits = Loss_no_logits(pred, target) - np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08)) - if np.isnan(loss_no_logits.asscalar()): - assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar())) - else: - assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar()) + np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08), + axis=1) + assert_almost_equal(np_loss_no_logits, loss_no_logits.asnumpy()) #3) Testing for Sterling approximation shape=(2, 3) np_pred = np.random.uniform(1, 5, shape) np_target = np.random.uniform(1, 5, shape) np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\ - np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1))) + np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1) Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True) loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target)) - assert_almost_equal(np_compute_full, loss_compute_full.asscalar()) + assert_almost_equal(np_compute_full, loss_compute_full) diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py new file mode 100644 index 000000000000..6c63546f85b1 --- /dev/null +++ b/tests/python/unittest/test_numpy_loss.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +import numpy as np +from mxnet import gluon, autograd +from mxnet.test_utils import assert_almost_equal, default_context, use_np +from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator +import unittest + + +@xfail_when_nonstandard_decimal_separator +@with_seed() +@use_np +def test_loss_np_ndarray(): + output = mx.np.array([1, 2, 3, 4]) + label = mx.np.array([1, 3, 5, 7]) + weighting = mx.np.array([0.5, 1, 0.5, 1]) + + loss = gluon.loss.L1Loss() + assert mx.np.sum(loss(output, label)) == 6. + loss = gluon.loss.L1Loss(weight=0.5) + assert mx.np.sum(loss(output, label)) == 3. + loss = gluon.loss.L1Loss() + assert mx.np.sum(loss(output, label, weighting)) == 5. + + loss = gluon.loss.L2Loss() + assert mx.np.sum(loss(output, label)) == 7. + loss = gluon.loss.L2Loss(weight=0.25) + assert mx.np.sum(loss(output, label)) == 1.75 + loss = gluon.loss.L2Loss() + assert mx.np.sum(loss(output, label, weighting)) == 6 + + loss = gluon.loss.HuberLoss() + assert mx.np.sum(loss(output, label)) == 4.5 + loss = gluon.loss.HuberLoss(weight=0.25) + assert mx.np.sum(loss(output, label)) == 1.125 + loss = gluon.loss.HuberLoss() + assert mx.np.sum(loss(output, label, weighting)) == 3.75 + + loss = gluon.loss.HingeLoss(margin=10) + assert mx.np.sum(loss(output, label)) == 13. + loss = gluon.loss.HingeLoss(margin=8, weight=0.25) + assert mx.np.sum(loss(output, label)) == 2.25 + loss = gluon.loss.HingeLoss(margin=7) + assert mx.np.sum(loss(output, label, weighting)) == 4. + + loss = gluon.loss.SquaredHingeLoss(margin=10) + assert mx.np.sum(loss(output, label)) == 97. + loss = gluon.loss.SquaredHingeLoss(margin=8, weight=0.25) + assert mx.np.sum(loss(output, label)) == 13.25 + loss = gluon.loss.SquaredHingeLoss(margin=7) + assert mx.np.sum(loss(output, label, weighting)) == 19. + + loss = gluon.loss.TripletLoss(margin=10) + assert mx.np.sum(loss(output, label, -label)) == 6. + loss = gluon.loss.TripletLoss(margin=8, weight=0.25) + assert mx.np.sum(loss(output, label, -label)) == 1. + loss = gluon.loss.TripletLoss(margin=7) + assert mx.np.sum(loss(output, label, -label, weighting)) == 1.5 + + output = mx.np.array([[0, 2], [1, 4]]) + label = mx.np.array([0, 1]) + weighting = mx.np.array([[0.5], [1.0]]) + + loss = gluon.loss.SoftmaxCrossEntropyLoss() + L = loss(output, label).asnumpy() + assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4) + + L = loss(output, label, weighting).asnumpy() + assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4) + + +@with_seed() +@use_np +def test_bce_equal_ce2(): + N = 100 + loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True) + loss2 = gluon.loss.SoftmaxCELoss(from_logits=True) + out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1)) + out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8) + label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1))) + assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy()) + +@use_np +def test_logistic_loss_equal_bce(): + N = 100 + loss_binary = gluon.loss.LogisticLoss(label_format='binary') + loss_signed = gluon.loss.LogisticLoss(label_format='signed') + loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False) + data = mx.np.random.uniform(-10, 10, size=(N, 1)) + label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1))) + assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6) + assert_almost_equal(loss_signed(data, 2 * label - 1), loss_bce(data, label), atol=1e-6) + + +@with_seed() +@use_np +def test_ctc_loss(): + loss = gluon.loss.CTCLoss() + l = loss(mx.np.ones((2,20,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss(layout='TNC') + l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN') + l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]).T) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss() + l = loss(mx.np.ones((2,20,4)), mx.np.array([[2,1,2,2],[3,2,2,2]]), None, mx.np.array([2,3])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss() + l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,-1,-1],[3,2,2,-1]]), mx.np.array([20,20])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss() + l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,3,3],[3,2,2,3]]), mx.np.array([20,20]), mx.np.array([2,3])) + assert_almost_equal(l, np.array([18.82820702, 16.50581741])) + + +@xfail_when_nonstandard_decimal_separator +@with_seed() +@use_np +def test_sdml_loss(): + + N = 5 # number of samples + DIM = 10 # Dimensionality + EPOCHS = 20 + + # Generate randomized data and 'positive' samples + data = mx.np.random.uniform(-1, 1, size=(N, DIM)) + pos = data + mx.np.random.uniform(-0.1, 0.1, size=(N, DIM)) # correlated paired data + data_iter = mx.io.NDArrayIter({'data' : data, 'pos' : pos}, batch_size=N) + + # Init model and trainer + sdml_loss = gluon.loss.SDMLLoss() + model = gluon.nn.Dense(DIM, activation='tanh') # Simple NN encoder + model.initialize(mx.init.Xavier(), ctx=mx.current_context()) + trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate' : 0.1}) + + for i in range(EPOCHS): # Training loop + data_iter.reset() + for iter_batch in data_iter: + batch = [datum.as_in_ctx(mx.current_context()).as_np_ndarray() for datum in iter_batch.data] + with autograd.record(): + data, pos = batch + z_data, z_pos = model(data), model(pos) + loss = sdml_loss(z_data, z_pos) + loss.backward() + trainer.step(1) + + # After training euclidean distance between aligned pairs should be lower than all non-aligned pairs + avg_loss = loss.sum()/len(loss) + assert(avg_loss < 0.05) + +@with_seed() +@use_np +def test_cosine_loss(): + #Generating samples + input1 = mx.np.random.randn(3, 2) + input2 = mx.np.random.randn(3, 2) + label = mx.np.sign(mx.np.random.randn(input1.shape[0])) + #Calculating loss from cosine embedding loss function in Gluon + Loss = gluon.loss.CosineEmbeddingLoss() + loss = Loss(input1, input2, label) + + # Calculating the loss Numpy way + numerator = mx.np.sum(input1 * input2, keepdims=True, axis=1) + denominator = mx.np.sqrt(mx.np.sum(input1**2, axis=1, keepdims=True)) \ + * mx.np.sqrt(mx.np.sum(input2**2, axis=1, keepdims=True)) + x = numerator/denominator + label = mx.npx.reshape(label, (-1, 1)) + numpy_loss = mx.npx.reshape( + mx.np.where(label == 1, 1-x, mx.npx.relu(x)), (-1,)) + assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5) + +@xfail_when_nonstandard_decimal_separator +@use_np +def test_poisson_nllloss(): + shape=(3, 4) + not_axis0 = tuple(range(1, len(shape))) + pred = mx.np.random.normal(size=shape) + min_pred = mx.np.min(pred) + #This is necessary to ensure only positive random values are generated for prediction, + # to avoid ivalid log calculation + pred[:] = pred + mx.np.abs(min_pred) + target = mx.np.random.normal(size=shape) + min_target = mx.np.min(target) + #This is necessary to ensure only positive random values are generated for prediction, + # to avoid ivalid log calculation + target[:] += mx.np.abs(min_target) + + Loss = gluon.loss.PoissonNLLLoss(from_logits=True) + Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False) + #Calculating by brute formula for default value of from_logits = True + + # 1) Testing for flag logits = True + brute_loss = mx.np.mean(mx.np.exp(pred) - target * pred, axis=1) + loss_withlogits = Loss(pred, target) + assert_almost_equal(brute_loss, loss_withlogits) + + #2) Testing for flag logits = False + loss_no_logits = Loss_no_logits(pred, target) + np_loss_no_logits = mx.np.mean(pred - target * mx.np.log(pred + 1e-08), + axis=1) + assert_almost_equal(np_loss_no_logits, loss_no_logits) + + #3) Testing for Sterling approximation + shape=(2, 3) + np_pred = mx.np.random.uniform(1, 5, shape) + np_target = mx.np.random.uniform(1, 5, shape) + np_compute_full = mx.np.mean((np_pred - np_target * mx.np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\ + np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1) + Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True) + loss_compute_full = Loss_compute_full(np_pred, np_target) + assert_almost_equal(np_compute_full, loss_compute_full) +