
add deprecated for softmax_with_cross_entropy #31722

Conversation

Contributor

@chajchaj chajchaj commented Mar 18, 2021

PR types

Function optimization

PR changes

APIs

Describe

1. softmax_with_cross_entropy is an API kept only for historical reasons and is not recommended for users in 2.0, so print a deprecated warning (a sketch of the decorator approach follows this list).
2. Improve the English documentation of cross_entropy in nn/functional/loss.py.
3. Improve the English documentation of CrossEntropyLoss in nn/layer/loss.py.
4. Fix the error raised by cross_entropy when soft_label is True and the weight parameter is specified.
5. Add a softmax_switch toggle on the Python side.
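
A minimal sketch of the decorator approach from item 1, assuming Paddle's paddle.utils.deprecated helper; the since/update_to values and the body stub are illustrative, not the exact PR diff:

# Hedged sketch: attaching a deprecation warning to the legacy API.
# Assumes paddle.utils.deprecated; argument values are illustrative.
from paddle.utils import deprecated

kIgnoreIndex = -100  # fluid's default ignore_index sentinel


@deprecated(
    since="2.0.0",
    update_to="paddle.nn.functional.cross_entropy")
def softmax_with_cross_entropy(logits,
                               label,
                               soft_label=False,
                               ignore_index=kIgnoreIndex,
                               numeric_stable_mode=True,
                               return_softmax=False,
                               axis=-1):
    # ... original fluid implementation left unchanged ...
    pass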

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@@ -1162,7 +1162,8 @@ def softmax_with_cross_entropy(logits,
 ignore_index=kIgnoreIndex,
 numeric_stable_mode=True,
 return_softmax=False,
-axis=-1):
+axis=-1,
+softmax_switch=True):
Contributor

softmax_with_cross_entropy is an API pending deprecation and does not need the upgrade, so we suggest not modifying this part of the Python code under fluid.
The new functionality can be implemented in the paddle.nn.functional.loss.py file by calling the underlying core.ops.softmax_with_cross_entropy and appending the op, which avoids depending on code that is about to be deprecated.
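
A hedged sketch of the suggested pattern; the names mirror Paddle 2.0-era internals (core.ops, LayerHelper), and the exact attribute list passed to the op is an assumption:

from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper


def _fused_softmax_cross_entropy(logits, label, soft_label=False, axis=-1):
    if in_dygraph_mode():
        # Dygraph: invoke the fused C++ op directly.
        softmax, loss = core.ops.softmax_with_cross_entropy(
            logits, label, 'soft_label', soft_label, 'axis', axis)
        return loss
    # Static graph: append the op by hand instead of going through
    # the to-be-deprecated fluid wrapper.
    helper = LayerHelper('softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='softmax_with_cross_entropy',
        inputs={'Logits': logits, 'Label': label},
        outputs={'Softmax': softmax, 'Loss': loss},
        attrs={'soft_label': soft_label, 'axis': axis})
    return loss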

Contributor Author

Fixed.

@@ -1119,94 +1137,203 @@ def cross_entropy(input,
 reduction='mean',
 soft_label=False,
 axis=-1,
+softmax_switch=True,
Contributor

The word "switch" also carries the sense of "to convert", so it may not be intuitive here;
consider use_softmax=True or enable_softmax=True, along the lines of use_cudnn and use_global_stat.
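
For illustration, a hedged usage sketch of the renamed flag, assuming the use_softmax spelling this thread converges on:

import paddle
import paddle.nn.functional as F

logits = paddle.rand([4, 10])
labels = paddle.randint(0, 10, [4], dtype='int64')

# Default: the input holds unscaled logits; softmax is applied internally.
loss = F.cross_entropy(logits, labels, use_softmax=True)

# use_softmax=False: the input is already a probability distribution.
probs = F.softmax(logits)
loss2 = F.cross_entropy(probs, labels, use_softmax=False)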

Contributor Author

Fixed.

Contributor

@jzhang533 jzhang533 left a comment

Please also improve the issues in the documentation.

Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
to provide a more numerically stable computing.
Contributor

Now that the use_softmax parameter has been added, please update this passage accordingly; it currently only describes the default behavior.

Contributor Author

Updated.

The equation is as follows:
This operator can be used to calculate the softmax cross entropy loss with soft and hard labels.
Where, the hard labels mean the actual label value, 0, 1, 2, etc. And the soft labels
mean the probability of the actual label, 0.6, 0.8, 0.2, etc.
Contributor

Several users have said they cannot understand what this passage means. See whether it can be improved further; at a minimum it should convey that hard label values lie in the range [0, C-1], while the soft label values given for a sample sum to 1.0.
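
To make the distinction concrete, a hedged sketch of how the two label kinds could be constructed (the shapes are illustrative):

import paddle

N, C = 4, 5

# Hard labels: one class index per sample, each value in [0, C-1].
hard_labels = paddle.randint(0, C, [N], dtype='int64')

# Soft labels: a per-sample probability distribution; each row sums to 1.0.
soft_labels = paddle.rand([N, C])
soft_labels = soft_labels / soft_labels.sum(axis=1, keepdim=True)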

Contributor Author

Updated.


loss_j = -\\text{logits}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
1. Hard label (each sample can only be assigned into one category)
Contributor

Should the formula for the use_softmax=False case also be shown?
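
For reference, when use_softmax=False the input is taken as an existing probability distribution P, and the hard-label loss reduces to plain cross entropy. A hedged LaTeX sketch in the docstring's notation (P and N are illustrative symbols, not the PR's exact wording):

loss_j = -\log\left(P_{label_j}\right), \quad j = 1, ..., N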

Contributor Author

Added.

:math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes , ``k >= 1`` .
Note: it expects unscaled logits. This operator should not be used with the
output of softmax operator, which will produce incorrect results.

Contributor

When use_softmax=False, the input can be the output of softmax.

Contributor Author

Added.

reduction='mean'
input_np = np.random.random([N, C]).astype(np.float64)
label_np = np.random.randint(0, C, size=(N)).astype(np.int64)
weight_np = np.random.random([C]).astype(np.float64)
Contributor

You can use paddle.rand directly when constructing the example.
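
A hedged sketch of the suggested construction; the N and C values are illustrative stand-ins for the example's dimensions:

import paddle

N, C = 100, 200
input = paddle.rand([N, C], dtype='float64')
label = paddle.randint(0, C, [N], dtype='int64')
weight = paddle.rand([C], dtype='float64')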

Contributor Author

Fixed.


2. reduction

2.1 if the ``reduction`` parameter is ``none``
Contributor

Suggest changing none to None.

Contributor Author

Changed weight to None.

labels = np.random.uniform(0.1, 1.0, shape).astype(dtype)
labels /= np.sum(labels, axis=axis, keepdims=True)
paddle.set_device("cpu")
paddle.disable_static()
Contributor

Why is the call to paddle.disable_static needed?

Contributor Author

Fixed.

XiaoguangHu01 previously approved these changes Mar 26, 2021
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

chalsliu previously approved these changes Mar 26, 2021
.. code-block:: python

import paddle
import numpy as np
Contributor

This line shouldn't be needed.

shape = [N, C]
reduction='mean'
weight = None
logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0, seed=99999)
Contributor

Why isn't the seed set with paddle.seed here, as in example 1?
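
A hedged sketch of the global-seed alternative the reviewer refers to; paddle.seed sets the generator once instead of passing seed to every call (the shape is illustrative):

import paddle

paddle.seed(99999)  # set the global RNG seed once
logits = paddle.uniform([2, 3], dtype='float64', min=0.1, max=1.0)
labels = paddle.uniform([2, 3], dtype='float64', min=0.1, max=1.0)
labels /= paddle.sum(labels, axis=1, keepdim=True)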

logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0, seed=99999)
labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0, seed=99999)
labels /= paddle.sum(labels, axis=axis, keepdim=True)
paddle.set_device("cpu")
Contributor

Isn't this line also unnecessary?

- **soft_label** (bool, optional)

Indicate whether label is soft.
If soft_label=False, the label is hard. If soft_label=True, the label is soft.
Contributor

This sentence reads very oddly.

- **axis** (int, optional)

The index of dimension to perform softmax calculations.
It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the rank of
Contributor

rank could be changed to "number of dimensions" or the abbreviation ndim (see the API guidelines).

raindrops2sea previously approved these changes Mar 28, 2021
@chajchaj chajchaj dismissed stale reviews from raindrops2sea, chalsliu, and XiaoguangHu01 via be38f31 March 29, 2021 02:34
jzhang533 previously approved these changes Mar 29, 2021
Contributor

@jzhang533 jzhang533 left a comment

lgtm

@XiaoguangHu01 XiaoguangHu01 merged commit 73a6fa3 into PaddlePaddle:develop Mar 30, 2021
chajchaj added a commit to chajchaj/Paddle that referenced this pull request Mar 30, 2021
* add deprecated for softmax_with_cross_entropy, test=develop

* test for deprecated in english doc, test=develop

* test deprecated for softmax_with_cross_entropy in english doc, test=develop

* fix readme and English doc for cross_entropy, test=develop

* rm test for softmax_with_cross_entropy deprecated, test=develop

* update readme for CrossEntropyLoss, test=develop

* fix readme format, test=develop

* fix readme format, test=develop

* fix readme format for cross_entropy, test=develop

* add softmax_switch and fix softlabel for cross_entropy, test=develop

* 1)recovery softmax_with_cross_entropy in fluid 2) change softmax_switch to use_softmax 3) add example for softlabel for cross_entropy, test=develop

* fix Example number for cross_entropy, test=develop

* fix code format, test=develop

* fix for CI-Coverage, test=develop

* fix for CI-Coverage, test=develop

* fix ci-coverage for Non-ASCII character '\xe2' in file, test=develop

* fix ci-coverage for Non-ASCII character '\xe2' in nn.layer.loss.py, test=develop

* update description for doc when use_softmax=Fasle, test=develop

* fix some docs and code example for cross_entropy, test=develop

* delete redundant description for soft_label parameter of cross_entropy, test=develop

* fix some comment for test_cross_entropy_loss.py, test=develop
XiaoguangHu01 pushed a commit that referenced this pull request Mar 30, 2021
* add deprecated for softmax_with_cross_entropy (#31722)

* cherry-pick for add deprecated,test=develop