
loss for np/nd array
szha committed Dec 31, 2019
1 parent c020f37 commit b8da16a
Showing 3 changed files with 133 additions and 89 deletions.
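The recurring pattern in this commit is to pick the numpy-style operator (F.np.* / F.npx.*) when is_np_array() is true and to fall back to the legacy F.* operator otherwise, with the per-sample reduction factored into the new _batch_mean / _batch_sum helpers. A minimal usage sketch (assuming an MXNet build that includes this commit; everything outside the diff below is illustrative):

import mxnet as mx
from mxnet import gluon, npx

npx.set_np()                                 # enable numpy (np/npx) semantics for Gluon
loss_fn = gluon.loss.L2Loss()
pred = mx.np.random.uniform(size=(4, 10))    # batch of 4 predictions
label = mx.np.random.uniform(size=(4, 10))
loss = loss_fn(pred, label)                  # mx.np array of shape (4,)
print(loss.shape)                            # one loss value per sample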
203 changes: 124 additions & 79 deletions python/mxnet/gluon/loss.py
@@ -75,6 +75,26 @@ def _reshape_like(F, x, y):
return F.reshape_like(x, y)


def _batch_mean(F, loss, batch_axis):
"""Return mean on the specified batch axis, not keeping the axis"""
if is_np_array():
axes = list(range(loss.ndim))
del axes[batch_axis]
return F.np.mean(loss, axis=axes)
else:
return F.mean(loss, axis=batch_axis, exclude=True)

def _batch_sum(F, loss, batch_axis):
"""Return sum on the specified batch axis, not keeping the axis"""
if is_np_array():
axes = list(range(loss.ndim))
del axes[batch_axis]
return F.np.sum(loss, axis=axes)
else:
return F.sum(loss, axis=batch_axis, exclude=True)



class Loss(HybridBlock):
"""Base class for loss.
@@ -143,16 +163,11 @@ def __init__(self, weight=1., batch_axis=0, **kwargs):
super(L2Loss, self).__init__(weight, batch_axis, **kwargs)

def hybrid_forward(self, F, pred, label, sample_weight=None):
square_fn = F.np.square if is_np_array() else F.square
label = _reshape_like(F, label, pred)
loss = F.np.square(label - pred) if is_np_array() else F.square(label - pred)
loss = square_fn(label - pred)
loss = _apply_weighting(F, loss, self._weight / 2, sample_weight)
if is_np_array():
if F is ndarray:
return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
else:
return F.npx.batch_flatten(loss).mean(axis=1)
else:
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class L1Loss(Loss):
@@ -188,16 +203,11 @@ def __init__(self, weight=None, batch_axis=0, **kwargs):
super(L1Loss, self).__init__(weight, batch_axis, **kwargs)

def hybrid_forward(self, F, pred, label, sample_weight=None):
abs_fn = F.np.abs if is_np_array() else F.abs
label = _reshape_like(F, label, pred)
loss = F.np.abs(label - pred) if is_np_array() else F.abs(label - pred)
loss = abs_fn(label - pred)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
if is_np_array():
if F is ndarray:
return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
else:
return F.npx.batch_flatten(loss).mean(axis=1)
else:
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class SigmoidBinaryCrossEntropyLoss(Loss):
@@ -263,7 +273,6 @@ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
self._from_sigmoid = from_sigmoid

def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
label = _reshape_like(F, label, pred)
if is_np_array():
relu_fn = F.npx.relu
act_fn = F.npx.activation
@@ -276,6 +285,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
abs_fn = F.abs
mul_fn = F.broadcast_mul
log_fn = F.log
label = _reshape_like(F, label, pred)
if not self._from_sigmoid:
if pos_weight is None:
# We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
@@ -296,13 +306,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
loss = -(mul_fn(log_fn(pred + eps) * label, pos_weight)
+ log_fn(1. - pred + eps) * (1. - label))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
if is_np_array():
if F is ndarray:
return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
else:
return F.npx.batch_flatten(loss).mean(axis=1)
else:
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
@@ -380,26 +384,20 @@ def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,

def hybrid_forward(self, F, pred, label, sample_weight=None):
if is_np_array():
log_softmax = F.npx.log_softmax
pick = F.npx.pick
log_softmax_fn = F.npx.log_softmax
pick_fn = F.npx.pick
else:
log_softmax = F.log_softmax
pick = F.pick
log_softmax_fn = F.log_softmax
pick_fn = F.pick
if not self._from_logits:
pred = log_softmax(pred, self._axis)
pred = log_softmax_fn(pred, self._axis)
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
loss = -pick_fn(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
if is_np_array():
if F is ndarray:
return loss.mean(axis=tuple(range(1, loss.ndim)))
else:
return F.npx.batch_flatten(loss).mean(axis=1)
else:
return loss.mean(axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


SoftmaxCELoss = SoftmaxCrossEntropyLoss
@@ -473,11 +471,17 @@ def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0,
self._axis = axis

def hybrid_forward(self, F, pred, label, sample_weight=None):
if is_np_array():
log_softmax_fn = F.npx.log_softmax
log_fn = F.np.log
else:
log_softmax_fn = F.log_softmax
log_fn = F.log
if not self._from_logits:
pred = F.log_softmax(pred, self._axis)
loss = label * (F.log(label + 1e-12) - pred)
pred = log_softmax_fn(pred, self._axis)
loss = label * (log_fn(label + 1e-12) - pred)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class CTCLoss(Loss):
@@ -603,12 +607,18 @@ def __init__(self, rho=1, weight=None, batch_axis=0, **kwargs):
self._rho = rho

def hybrid_forward(self, F, pred, label, sample_weight=None):
if is_np_array():
abs_fn = F.np.abs
where_fn = F.np.where
else:
abs_fn = F.abs
where_fn = F.where
label = _reshape_like(F, label, pred)
loss = F.abs(label - pred)
loss = F.where(loss > self._rho, loss - 0.5 * self._rho,
(0.5 / self._rho) * F.square(loss))
loss = abs_fn(label - pred)
loss = where_fn(loss > self._rho, loss - 0.5 * self._rho,
(0.5 / self._rho) * F.square(loss))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class HingeLoss(Loss):
@@ -650,10 +660,11 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
self._margin = margin

def hybrid_forward(self, F, pred, label, sample_weight=None):
relu_fn = F.npx.relu if is_np_array() else F.relu
label = _reshape_like(F, label, pred)
loss = F.relu(self._margin - pred * label)
loss = relu_fn(self._margin - pred * label)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class SquaredHingeLoss(Loss):
@@ -695,10 +706,16 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
self._margin = margin

def hybrid_forward(self, F, pred, label, sample_weight=None):
if is_np_array():
relu_fn = F.npx.relu
square_fn = F.np.square
else:
relu_fn = F.relu
square_fn = F.square
label = _reshape_like(F, label, pred)
loss = F.square(F.relu(self._margin - pred * label))
loss = square_fn(relu_fn(self._margin - pred * label))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class LogisticLoss(Loss):
@@ -744,14 +761,22 @@ def __init__(self, weight=None, batch_axis=0, label_format='signed', **kwargs):
% label_format)

def hybrid_forward(self, F, pred, label, sample_weight=None):
if is_np_array():
relu_fn = F.npx.relu
act_fn = F.npx.activation
abs_fn = F.np.abs
else:
relu_fn = F.relu
act_fn = F.Activation
abs_fn = F.abs
label = _reshape_like(F, label, pred)
if self._label_format == 'signed':
label = (label + 1.0) / 2.0 # Transform label to be either 0 or 1
# Use a stable formula in computation
loss = F.relu(pred) - pred * label + \
F.Activation(-F.abs(pred), act_type='softrelu')
loss = relu_fn(pred) - pred * label + \
act_fn(-abs_fn(pred), act_type='softrelu')
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
return _batch_mean(F, loss, self._batch_axis)


class TripletLoss(Loss):
@@ -792,11 +817,16 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
self._margin = margin

def hybrid_forward(self, F, pred, positive, negative):
if is_np_array():
relu_fn = F.npx.relu
square_fn = F.np.square
else:
relu_fn = F.relu
square_fn = F.square
positive = _reshape_like(F, positive, pred)
negative = _reshape_like(F, negative, pred)
loss = F.sum(F.square(positive - pred) - F.square(negative - pred),
axis=self._batch_axis, exclude=True)
loss = F.relu(loss + self._margin)
loss = _batch_sum(F, square_fn(positive - pred) - square_fn(negative - pred), self._batch_axis)
loss = relu_fn(loss + self._margin)
return _apply_weighting(F, loss, self._weight, None)


@@ -846,20 +876,26 @@ def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=Fal
self._compute_full = compute_full

def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
if is_np_array():
exp_fn = F.np.exp
log_fn = F.np.log
else:
exp_fn = F.exp
log_fn = F.log
target = _reshape_like(F, target, pred)
if self._from_logits:
loss = F.exp(pred) - target * pred
loss = exp_fn(pred) - target * pred
else:
loss = pred - target * F.log(pred + epsilon)
loss = pred - target * log_fn(pred + epsilon)
if self._compute_full:
# Using numpy's pi value
stirling_factor = target * \
F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
log_fn(target) - target + 0.5 * log_fn(2 * target * np.pi)
target_gt_1 = target > 1
stirling_factor *= target_gt_1
loss += stirling_factor
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss)
return _batch_mean(F, loss, self._batch_axis)


class CosineEmbeddingLoss(Loss):
Expand Down Expand Up @@ -903,30 +939,39 @@ def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs):
self._margin = margin

def hybrid_forward(self, F, input1, input2, label, sample_weight=None):
if is_np_array():
where_fn = F.np.where
clip_fn = F.np.clip
else:
where_fn = F.where
clip_fn = F.clip

input1 = _reshape_like(F, input1, input2)
label = label.reshape((-1, 1))
cos_sim = self._cosine_similarity(F, input1, input2)
y_1 = label == 1
y_minus_1 = label == -1
cos_sim_a = (1 - cos_sim) * y_1
label = _reshape_like(F, label, cos_sim)
loss = where_fn(label == 1,
1 - cos_sim,
clip_fn(cos_sim - self._margin, 0, 1 - self._margin))

if F is ndarray:
z_array = F.array([0])
else:
z_array = F.zeros((1, 1))
cos_sim_b = F.broadcast_maximum(
z_array, y_minus_1 * (cos_sim - self._margin), axis=1)
loss = cos_sim_a + cos_sim_b
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return loss
return _batch_mean(F, loss, self._batch_axis)

def _cosine_similarity(self, F, x, y, axis=-1):
# Calculates the cosine similarity between 2 vectors
x_norm = F.norm(x, axis=axis).reshape(-1, 1)
y_norm = F.norm(y, axis=axis).reshape(-1, 1)
x_dot_y = F.sum(x * y, axis=axis).reshape(-1, 1)
if F is ndarray:
eps_arr = F.array([1e-12])
def _cosine_similarity(self, F, x, y):
if is_np_array():
reshape_fn = F.npx.reshape
norm_fn = F.npx.norm
sum_fn = F.np.sum
full_fn = F.np.full
max_fn = F.np.maximum
else:
eps_arr = F.full((1, 1), 1e-12)
return (x_dot_y / F.broadcast_maximum(x_norm * y_norm, eps_arr))
reshape_fn = F.reshape
norm_fn = F.norm
sum_fn = F.sum
full_fn = F.full
max_fn = F.broadcast_maximum
# Calculates the cosine similarity between 2 vectors
x_norm = reshape_fn(norm_fn(x, axis=-1), (-1, 1))
y_norm = reshape_fn(norm_fn(y, axis=-1), (-1, 1))
x_dot_y = reshape_fn(sum_fn(x * y, axis=-1), (-1, 1))
eps_arr = full_fn((1, 1), 1e-12)
return (x_dot_y / max_fn(x_norm * y_norm, eps_arr))
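For reference, a plain-NumPy sketch of the reduction that the new _batch_mean helper performs (illustrative only; the real helper dispatches to F.np.mean or F.mean(..., exclude=True)):

import numpy as np

def batch_mean(loss, batch_axis=0):
    # Average over every axis except the batch axis, so the result
    # keeps one value per sample, mirroring _batch_mean above.
    axes = list(range(loss.ndim))
    del axes[batch_axis]
    return loss.mean(axis=tuple(axes))

loss = np.random.rand(4, 3, 2)   # per-element loss for a batch of 4 samples
print(batch_mean(loss).shape)    # (4,)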
1 change: 1 addition & 0 deletions src/operator/tensor/broadcast_reduce_norm_value.cc
@@ -87,6 +87,7 @@ Examples::
norm(csr) = [5.47722578]
)code" ADD_FILELINE)
.add_alias("_npx_norm")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NormParam>)
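Registering the _npx_norm alias is what exposes the existing norm operator under the numpy-extension namespace, which the reworked _cosine_similarity above relies on via F.npx.norm. A small sketch of the intended behaviour (assuming a build with this alias registered):

import mxnet as mx
from mxnet import npx

npx.set_np()
x = mx.np.array([[3.0, 4.0], [6.0, 8.0]])
row_norms = npx.norm(x, axis=-1)   # L2 norm along the last axis, as in _cosine_similarity
print(row_norms)                   # expected: [ 5. 10.]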
18 changes: 8 additions & 10 deletions tests/python/unittest/test_loss.py
@@ -363,7 +363,7 @@ def test_cosine_loss():
denominator = mx.nd.sqrt(mx.nd.sum(input1**2, axis=1, keepdims=True)) \
* mx.nd.sqrt(mx.nd.sum(input2**2, axis=1, keepdims=True))
numpy_loss = mx.nd.where(label == 1, 1-numerator/denominator, \
mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1))
mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1)).reshape((-1,))
assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)

def test_poisson_nllloss():
@@ -385,27 +385,25 @@ def test_poisson_nllloss():
#Calculating by brute formula for default value of from_logits = True

# 1) Testing for flag logits = True
brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy(), axis=1)
loss_withlogits = Loss(pred, target)
assert_almost_equal(brute_loss, loss_withlogits.asscalar())
assert_almost_equal(brute_loss, loss_withlogits)

#2) Testing for flag logits = False
loss_no_logits = Loss_no_logits(pred, target)
np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
if np.isnan(loss_no_logits.asscalar()):
assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
else:
assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())
np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08),
axis=1)
assert_almost_equal(np_loss_no_logits, loss_no_logits.asnumpy())

#3) Testing for Sterling approximation
shape=(2, 3)
np_pred = np.random.uniform(1, 5, shape)
np_target = np.random.uniform(1, 5, shape)
np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)))
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
assert_almost_equal(np_compute_full, loss_compute_full)

@with_seed()
def test_poisson_nllloss_mod():
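The updated assertions above reflect that the Gluon losses now return one value per sample rather than a scalar mean over the whole batch; a plain-NumPy sketch of the shape the Poisson NLL test now expects (shapes and values are illustrative only):

import numpy as np

pred = np.random.uniform(size=(1, 5))
target = np.random.uniform(size=(1, 5))

per_sample = np.mean(np.exp(pred) - target * pred, axis=1)   # shape (1,): new expectation
batch_scalar = np.mean(np.exp(pred) - target * pred)         # old scalar expectation
print(per_sample.shape)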
