[Prim][PIR] binary_cross_entropy_with_logits forward decomp #61613

Merged
30 commits merged on May 21, 2024

Commits
b7650fa
sigmoid_cross_entropy_with_logits forward decomp
zeroRains Feb 5, 2024
b05d121
mean_all forward decomp
zeroRains Feb 5, 2024
013cf4f
add the test case for binary_cross_entropy_with_logits
zeroRains Feb 6, 2024
976ac31
create a new test file
zeroRains Feb 6, 2024
617aa99
modify the assert method
zeroRains Feb 6, 2024
4ce41f5
fix conflict
zeroRains Feb 29, 2024
4950de4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains Feb 29, 2024
7395204
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains Mar 17, 2024
cb9c8a0
modify the test
zeroRains Mar 18, 2024
135b544
fix code style
zeroRains Mar 18, 2024
b72d42d
add prim in check grad for test and handle the optional tensor
zeroRains Mar 18, 2024
b1532ac
fix conflict
zeroRains Mar 22, 2024
fb315d2
fix conflict
zeroRains Mar 22, 2024
b9d874f
fix conflict
zeroRains Mar 29, 2024
02a9896
do not modify the third_party package
zeroRains Mar 29, 2024
c62da1c
fix conflict
zeroRains Apr 25, 2024
ae5b59f
fix merge bug
zeroRains Apr 25, 2024
32989f7
Merge branch 'develop' into logit
zeroRains Apr 28, 2024
9f8d513
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains Apr 29, 2024
8529a00
fix conflict
zeroRains Apr 29, 2024
a5b28a3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains May 2, 2024
f718e0c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains May 10, 2024
af41ecc
fix bug
zeroRains May 16, 2024
0026451
modify the test data and change the file name
zeroRains May 16, 2024
c2edfa0
roll back
zeroRains May 17, 2024
6283cb5
fix bug
zeroRains May 17, 2024
c55191d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains May 18, 2024
2235976
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zeroRains May 19, 2024
6c2bfe7
support mean_all for dynamic shape
zeroRains May 21, 2024
f954cd6
modify the type
zeroRains May 21, 2024
@@ -42,13 +42,15 @@
"leaky_relu",
"log_softmax",
"mean",
"mean_all",
"meshgrid",
"one_hot",
"p_norm",
"pow",
"reciprocal",
"relu",
"relu6",
"sigmoid_cross_entropy_with_logits",
"silu",
"swiglu",
"softmax",
@@ -82,12 +84,14 @@
"leaky_relu",
"log_softmax",
"mean",
"mean_all",
"meshgrid",
"p_norm",
"pow",
"reciprocal",
"relu",
"relu6",
"sigmoid_cross_entropy_with_logits",
"silu",
"swiglu",
"softmax",
67 changes: 67 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -1161,6 +1161,73 @@ Tensor square_decomp(const Tensor& x) {
}
}

template <typename T>
Tensor sigmoid_cross_entropy_with_logits_decomp(
const Tensor& x,
const Tensor& label,
const paddle::optional<Tensor>& pos_weight,
bool normalize,
int ignore_index) {
auto dims = x.shape();
const Tensor zero = full<T>(dims, 0, x.type());
const Tensor one = full<T>(dims, 1, x.type());
Tensor pos_weight_tensor;
if (pos_weight) {
pos_weight_tensor = pos_weight.get();
} else {
pos_weight_tensor = one;
}
auto term1 = where<T>(x > zero, x, zero);
auto term2 = x * label;
auto term3 = log<T>(1 + exp<T>(-abs<T>(x)));
const Tensor tmp_out = term1 - term2 + term3 * pos_weight_tensor;
const Tensor ignore_index_tensor = full<T>(dims, ignore_index, label.type());
auto out = where<T>(label == ignore_index_tensor, zero, tmp_out);
if (normalize) {
// Follow the implementation in
// paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_kernel.cc
const Tensor eps1 = full<T>(dims, 1e-6, x.type());
auto diff = label - ignore_index_tensor;
const Tensor tmp_norm = sum<T>(where<T>(abs<T>(diff) > eps1, one, zero));
// Follow the implementation in
// paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_kernel.cc
const Tensor eps2 = full<T>(empty_shape, 1e-5, x.type());
Contributor:
1e-5, same as above?

Contributor Author:
Same as above.

auto norm = where<T>(tmp_norm > eps2, tmp_norm, eps2);
out = out / norm;
}
return out;
}

template <typename T>
Tensor mean_all_decomp(const Tensor& x) {
auto org_dtype = x.dtype();
auto x_cast = x;
auto x_shape = x.shape();
bool need_cast = is_half_dtype(org_dtype);
if (need_cast) {
x_cast = cast<T>(x, DataType::FLOAT32);
}

Tensor ans;
if (has_dynamic_shape(x_shape)) {
Tensor x_shape_tensor = shape<T>(x_cast);
Tensor value = get_slice<T>(x_shape_tensor, 0);
for (size_t i = 1; i < x_shape.size(); i++) {
value = value * get_slice<T>(x_shape_tensor, i);
}
value = reshape<T>(value, {});
ans = sum<T>(x_cast) / cast<T>(value, x_cast.dtype());
} else {
ans = sum<T>(x_cast) / x_cast.numel();
}

if (need_cast) {
return cast<T>(ans, org_dtype);
} else {
return ans;
}
}

template <typename T>
Tensor embedding_decomp(const Tensor& x,
const Tensor& weight,
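For reference, the two composite rules added above implement the following math: sigmoid_cross_entropy_with_logits computes, element-wise, max(x, 0) - x * label + log(1 + exp(-|x|)) * pos_weight, zeroes out positions whose label equals ignore_index, and optionally normalizes by the number of non-ignored labels; mean_all is simply sum(x) / numel(x). Below is a minimal NumPy sketch of the same computation, for illustration only (the helper names are not part of the PR):

import numpy as np

def sigmoid_cross_entropy_with_logits_ref(x, label, pos_weight, normalize, ignore_index):
    # Element-wise loss, mirroring sigmoid_cross_entropy_with_logits_decomp.
    if pos_weight is None:
        pos_weight = np.ones_like(x)
    out = np.maximum(x, 0) - x * label + np.log1p(np.exp(-np.abs(x))) * pos_weight
    out = np.where(label == ignore_index, 0.0, out)  # mask ignored labels
    if normalize:
        # Divide by the count of non-ignored labels, clamped from below at 1e-5,
        # matching the eps values used in the kernel and in the decomposition.
        count = (np.abs(label - ignore_index) > 1e-6).sum()
        out = out / np.maximum(count, 1e-5)
    return out

def mean_all_ref(x):
    # Mirrors mean_all_decomp: the mean over every element of x.
    return x.sum() / x.size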
@@ -38,6 +38,8 @@ class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = 64
num_classes = 20
self.inputs = {
@@ -60,10 +62,10 @@ def setUp(self):
self.outputs = {'Out': -term1 - term2}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
@@ -72,6 +74,8 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = 64
num_classes = 20
ignore_index = -1
@@ -99,10 +103,10 @@ def setUp(self):
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
@@ -111,6 +115,8 @@ class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = 64
num_classes = 20
self.inputs = {
@@ -133,10 +139,10 @@ def setUp(self):
self.outputs = {'Out': -term1 - term2}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithLogitsOp4(OpTest):
@@ -145,6 +151,8 @@ class TestSigmoidCrossEntropyWithLogitsOp4(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = 64
num_classes = 20

@@ -171,7 +179,7 @@ def setUp(self):
self.outputs = {'Out': term1 - term2 + term3}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.0005, check_pir=True)
@@ -181,6 +189,8 @@ class TestSigmoidCrossEntropyWithNorm(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = 64
num_classes = 20
ignore_index = -1
@@ -207,10 +217,10 @@ def setUp(self):
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
Expand All @@ -219,6 +229,8 @@ class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = [10, 10]
num_classes = 20
self.inputs = {
@@ -241,16 +253,18 @@ def setUp(self):
self.outputs = {'Out': -term1 - term2}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithNorm2(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = [10, 10]
num_classes = 20
ignore_index = -1
@@ -277,10 +291,10 @@ def setUp(self):
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
@@ -289,6 +303,8 @@ class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
self.prim_op_type = "comp"
self.public_python_api = loss_wrapper
batch_size = [10, 10]
num_classes = 20
self.inputs = {
@@ -311,10 +327,10 @@ def setUp(self):
self.outputs = {'Out': -term1 - term2}

def test_check_output(self):
self.check_output(check_pir=True)
self.check_output(check_pir=True, check_prim_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_pir=True)
self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestSigmoidCrossEntropyWithLogitsOpError(unittest.TestCase):
85 changes: 85 additions & 0 deletions test/legacy_test/test_binary_cross_entropy_with_logits_op.py
@@ -0,0 +1,85 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class TestBinaryCrossEntropyWithLogits(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.x = np.random.randn(300, 1000).astype("float32")
self.y = np.random.randint(0, 2, (300, 1000)).astype("float32")
self.logits = paddle.to_tensor(self.x)
self.labels = paddle.to_tensor(self.y)
self.weight = paddle.to_tensor(
np.random.randn(300, 1000).astype("float32")
)
self.reduction = ["none", "mean", "sum"]
self.pos_weight = paddle.to_tensor(
np.random.randn(1000).astype("float32")
)

def test_binary_cross_entropy_with_logits(self):
for reduction in self.reduction:
dynamic_result = (
paddle.nn.functional.binary_cross_entropy_with_logits(
self.logits,
self.labels,
weight=self.weight,
reduction=reduction,
pos_weight=self.pos_weight,
)
)
paddle.core._set_prim_all_enabled(True)
static_result = paddle.jit.to_static(
paddle.nn.functional.binary_cross_entropy_with_logits,
full_graph=True,
)(
self.logits,
self.labels,
weight=self.weight,
reduction=reduction,
pos_weight=self.pos_weight,
)
paddle.core._set_prim_all_enabled(False)
np.testing.assert_allclose(
dynamic_result.numpy(), static_result.numpy(), rtol=1e-4
)


class TestBinaryCrossEntropyWithLogits1(TestBinaryCrossEntropyWithLogits):
def setUp(self):
super().setUp()
self.weight = None


class TestBinaryCrossEntropyWithLogits2(TestBinaryCrossEntropyWithLogits):
def setUp(self):
super().setUp()
self.pos_weight = None


class TestBinaryCrossEntropyWithLogits3(TestBinaryCrossEntropyWithLogits):
def setUp(self):
super().setUp()
self.weight = None
self.pos_weight = None


if __name__ == "__main__":
unittest.main()
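The test above compares the eager result of paddle.nn.functional.binary_cross_entropy_with_logits with the result of the same call compiled via paddle.jit.to_static while prim decompositions are enabled, so the static path exercises the newly decomposed sigmoid_cross_entropy_with_logits and mean_all primitives. As a rough sketch of the quantity being compared (illustrative only; it assumes pos_weight scales the log term exactly as in the decomposition above, and the function name is made up for this example):

import numpy as np

def bce_with_logits_ref(logit, label, weight=None, reduction="mean", pos_weight=None):
    # Per-element loss, as in sigmoid_cross_entropy_with_logits_decomp.
    if pos_weight is None:
        pos_weight = np.ones_like(logit)
    out = (np.maximum(logit, 0) - logit * label
           + np.log1p(np.exp(-np.abs(logit))) * pos_weight)
    if weight is not None:        # optional element-wise rescaling
        out = out * weight
    if reduction == "mean":       # corresponds to the mean_all primitive
        return out.mean()
    if reduction == "sum":
        return out.sum()
    return out                    # reduction == "none"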
17 changes: 17 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -90,6 +90,10 @@ def dropout_net1(x):
return paddle.nn.functional.dropout(x, 0.5)


def mean_all_net1(x):
return paddle._C_ops.mean_all(x)


group_norm1 = paddle.nn.GroupNorm(num_channels=128, num_groups=32)


@@ -509,5 +513,18 @@ def setUp(self):
self.tol = 1e-6


class TestPrimMeanAll(TestPrimBase):
def setUp(self):
np.random.seed(2023)
paddle.seed(2023)
self.shape_x = [300, 4096]
self.dtype_x = "float32"
self.init_x_shape = [None, 4096]
self.x = np.random.random(self.shape_x).astype(self.dtype_x)
self.net = mean_all_net1
self.necessary_ops = "pd_op.mean_all"
self.enable_cinn = False


if __name__ == "__main__":
unittest.main()
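TestPrimMeanAll feeds mean_all an input whose leading dimension is dynamic (init_x_shape = [None, 4096]), so mean_all_decomp cannot rely on a compile-time element count; it instead multiplies the entries of the runtime shape tensor, as in the dynamic-shape branch added to composite.h. A short illustrative sketch of that idea in Paddle Python (not the code used by the decomposition itself):

import paddle

def mean_all_dynamic_ref(x):
    # Derive numel from the runtime shape so the result stays correct
    # even when some dimensions are unknown at compile time.
    runtime_shape = paddle.shape(x)                     # integer tensor of dims
    numel = paddle.prod(runtime_shape.astype(x.dtype))  # product of all dims
    return paddle.sum(x) / numel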