Skip to content

Commit

Permalink
【PIR API adaptor No.166】 Migrate paddle.polygamma into pir (PaddlePaddle…)
Browse files Browse the repository at this point in the history
  • Loading branch information
PommesPeter authored and zeroRains committed Nov 8, 2023
1 parent b1adc64 commit e302c60
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6828,7 +6828,7 @@ def polygamma(x, n, name=None):
if n == 0:
return digamma(x)
else:
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.polygamma(x, n)
else:
check_variable_and_dtype(
Expand Down
5 changes: 4 additions & 1 deletion test/legacy_test/test_polygamma_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

np.random.seed(100)
paddle.seed(100)
Expand Down Expand Up @@ -64,6 +65,7 @@ def setUp(self):
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

@test_with_pir_api
def test_api_static(self):
def run(place):
paddle.enable_static()
Expand Down Expand Up @@ -197,7 +199,7 @@ def init_config(self):
self.target = ref_polygamma(self.inputs['x'], self.order)

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
Expand All @@ -206,6 +208,7 @@ def test_check_grad(self):
user_defined_grads=[
ref_polygamma_grad(self.case, 1 / self.case.size, self.order)
],
check_pir=True,
)


Expand Down

0 comments on commit e302c60

Please sign in to comment.