[Prim][PIR] support square op backward in prim pir #64381

Merged 2 commits on Jun 21, 2024
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -67,6 +67,7 @@
     'sin_grad',
     'cos_grad',
     'tanh_grad',
+    'square_grad',
 ]
 
 # prim op with two inputs and one output, with no attribute
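For context, the grad ops grouped in this list all have simple closed-form VJPs. A rough, illustrative Python sketch (not the actual Paddle codegen) of the math behind each entry visible in this hunk:

```python
import numpy as np

# Illustrative only: the closed-form VJP behind each allowlisted grad op.
# The real rules live in paddle/fluid/primitive/rule/vjp/details.h.
VJP_RULES = {
    "sin_grad": lambda x, out, g: np.cos(x) * g,          # d/dx sin(x)  = cos(x)
    "cos_grad": lambda x, out, g: -np.sin(x) * g,         # d/dx cos(x)  = -sin(x)
    "tanh_grad": lambda x, out, g: (1.0 - out**2) * g,    # d/dx tanh(x) = 1 - tanh(x)^2
    "square_grad": lambda x, out, g: 2.0 * x * g,         # d/dx x^2     = 2x (added by this PR)
}

x = np.linspace(-1.0, 1.0, 5)
g = np.ones_like(x)
print(VJP_RULES["square_grad"](x, np.square(x), g))  # prints 2 * x
```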
8 changes: 8 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -795,6 +795,14 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   }
 }
 
+template <typename T>
+void square_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    Tensor x_grad_tmp = 2 * x * out_grad;
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+}
+
 template <typename T>
 void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
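The new rule is just the chain rule for y = x²: dy/dx = 2x, so the incoming gradient is scaled elementwise by 2x. A minimal NumPy sketch (not Paddle code) that checks the formula against a finite-difference directional derivative:

```python
import numpy as np

def square_vjp(x, out_grad):
    # Mirrors the rule in details.h: x_grad = 2 * x * out_grad.
    return 2.0 * x * out_grad

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
out_grad = rng.standard_normal((3, 4))
v = rng.standard_normal((3, 4))  # random direction for the directional derivative
eps = 1e-6

# d/dt sum(out_grad * (x + t*v)**2) at t = 0, by central differences.
numeric = (np.sum(out_grad * (x + eps * v) ** 2)
           - np.sum(out_grad * (x - eps * v) ** 2)) / (2.0 * eps)
analytic = np.sum(square_vjp(x, out_grad) * v)
np.testing.assert_allclose(numeric, analytic, rtol=1e-5)
print("square_grad rule matches finite differences")
```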
30 changes: 29 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -4297,6 +4297,7 @@ def test_check_grad(self):
             'Out',
             max_relative_error=0.007,
             check_pir=True,
+            check_prim_pir=True,
             check_pir_onednn=self.check_pir_onednn,
         )
 
@@ -4315,6 +4316,17 @@ def init_dtype(self):
     def test_check_output(self):
         self.check_output(check_pir=True)
 
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(
+            ['X'],
+            'Out',
+            max_relative_error=0.007,
+            check_pir=True,
+            check_pir_onednn=self.check_pir_onednn,
+        )
+
 
 class TestSquare_Complex128(TestSquare):
     def init_dtype(self):
@@ -4323,6 +4335,17 @@ def init_dtype(self):
     def test_check_output(self):
         self.check_output(check_pir=True)
 
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(
+            ['X'],
+            'Out',
+            max_relative_error=0.007,
+            check_pir=True,
+            check_pir_onednn=self.check_pir_onednn,
+        )
+
 
 class TestSquare_ZeroDim(TestSquare):
     def init_shape(self):

Contributor (review comment on the test_check_grad added above): Why is there no check_prim_pir here?
@@ -4365,7 +4388,12 @@ def test_check_output(self):
     def test_check_grad(self):
         place = core.CUDAPlace(0)
         self.check_grad_with_place(
-            place, ['X'], 'Out', numeric_grad_delta=0.5, check_pir=True
+            place,
+            ['X'],
+            'Out',
+            numeric_grad_delta=0.5,
+            check_pir=True,
+            check_prim_pir=True,
         )
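As a quick end-to-end sanity check of what these tests now cover, a minimal sketch, assuming a recent Paddle install. It uses eager autograd; the check_prim_pir=True flag is what routes OpTest's gradient check through the prim PIR decomposition inside the test framework:

```python
import numpy as np
import paddle

x_np = np.random.rand(3, 4).astype("float32")
x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.square(x)

# Backprop a ones cotangent; the gradient should be 2 * x elementwise.
y.backward(paddle.ones_like(y))
np.testing.assert_allclose(x.grad.numpy(), 2.0 * x_np, rtol=1e-6)
```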