Commit
fix misbehavior of KLDivLoss (apache#18423)
* fix misbehavior of KLDivLoss

  In the current version of KLDivLoss, the return value does not match the value computed by SoftmaxCrossEntropyLoss, and this difference is not documented. It appears to be due to an incorrect setting that uses the mean rather than the sum when reducing the return value. This change fixes that setting, so the return values of `KLDivLoss` and `SoftmaxCrossEntropyLoss` stay almost identical when `from_logits=False` and `sparse_label=False` are passed to the respective functions. The behavior of KLDivLoss now matches exactly what the documentation says.

  ```
  import mxnet as mx
  a = mx.nd.array([[-1, 1], [1, -1]])
  b = mx.nd.array([1, 0]).one_hot(2)
  TrueLoss = mx.gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
  FalseLoss = mx.gluon.loss.KLDivLoss(from_logits=False)
  c = TrueLoss(a, b)
  d = FalseLoss(a, b) * a.shape[-1]
  assert (c - d).abs().sum() == 0 and a.shape[-1] == 2
  ```

* update SDML loss

  The documentation of the current version of SDMLLoss says to `multiply for the number of labels`, but the code actually multiplies by `batch_size`. After this change, there is no need to multiply by `batch_size` or by the number of labels any more.

* remove outdated comment
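Not part of the original commit message: a minimal NumPy sketch (the `log_softmax` helper and the sample arrays are illustrative, not MXNet internals) showing why the `* a.shape[-1]` factor appears in the example above. With a one-hot target, the sum-form KL divergence equals softmax cross-entropy, while taking the mean over the class axis divides the result by the number of classes.

```
import numpy as np

def log_softmax(x):
    # numerically stable log-softmax over the last axis
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

pred = np.array([[-1.0, 1.0], [1.0, -1.0]])
target = np.eye(2)[[1, 0]]  # one-hot labels [1, 0]

# cross-entropy with a one-hot target
ce = -(target * log_softmax(pred)).sum(axis=-1)

# KL divergence KL(target || softmax(pred)); the zero entries of the
# one-hot target contribute nothing, so the sum form equals ce
kl_sum = (target * (np.log(target + 1e-12) - log_softmax(pred))).sum(axis=-1)
kl_mean = kl_sum / pred.shape[-1]  # mean over the class axis

assert np.allclose(ce, kl_sum)                     # sum form matches CE
assert np.allclose(ce, kl_mean * pred.shape[-1])   # mean form needs * num_classes
```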