【PIR API adaptor No.174】 Migrate paddle.randint_like into pir #58953

Merged (31 commits) on Nov 30, 2023
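
For context, this change lets paddle.randint_like be built and run under the new PIR static-graph form. A minimal usage sketch (not part of the PR; assumes a CPU build of a Paddle version where PIR is enabled):

import numpy as np
import paddle

# Build a small static program that calls randint_like; with PIR enabled this
# goes through the _C_ops.randint + cast lowering added in this PR.
paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[10, 12], dtype="float32")
    out = paddle.randint_like(x, low=-10, high=10, dtype="int32")

exe = paddle.static.Executor(paddle.CPUPlace())
(res,) = exe.run(
    main, feed={"x": np.zeros((10, 12), dtype="float32")}, fetch_list=[out]
)
assert res.shape == (10, 12) and ((res >= -10) & (res < 10)).all()
paddle.disable_static()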
76 changes: 55 additions & 21 deletions python/paddle/tensor/random.py
@@ -17,6 +17,7 @@
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.base.framework import _current_expected_place
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import Variable
from paddle.framework import (
in_dynamic_mode,
@@ -987,16 +988,27 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
if dtype is None:
dtype = core.VarDesc.VarType.INT64
if in_pir_mode():
from paddle.base.libpaddle import DataType

dtype = DataType.INT64
elif not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_or_pir_mode():
if in_dynamic_mode():
shape = paddle.utils.convert_shape_to_list(shape)
place = _current_expected_place()
return _C_ops.randint(low, high, shape, dtype, place)
return _C_ops.randint(
low, high, shape, dtype, _current_expected_place()
)
elif in_pir_mode():
check_type(
shape, 'shape', (list, tuple, paddle.pir.OpResult), 'randint'
)
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
if paddle.utils._contain_var(shape):
shape = paddle.utils.get_int_tensor_list(
shape, _current_expected_place()
)
return _C_ops.randint(
low, high, shape, dtype, _current_expected_place()
)
else:
check_shape(shape, 'randint')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
@@ -1172,8 +1184,9 @@ def randint_like(x, low=0, high=None, dtype=None, name=None):
low = 0
if dtype is None:
dtype = x.dtype
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
else:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
shape = paddle.shape(x)

if low >= high:
@@ -1182,20 +1195,41 @@
f"high = {high}"
)

if in_dynamic_mode():
shape = paddle.utils.convert_shape_to_list(shape)
out = _legacy_C_ops.randint(
'shape',
shape,
'low',
low,
'high',
high,
'seed',
0,
'dtype',
core.VarDesc.VarType.INT64,
)
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
shape = paddle.utils.convert_shape_to_list(shape)
out = _legacy_C_ops.randint(
'shape',
shape,
'low',
low,
'high',
high,
'seed',
0,
'dtype',
core.VarDesc.VarType.INT64,
)
else:
check_type(
shape,
'shape',
(list, tuple, paddle.pir.OpResult),
'randint_like',
)
check_dtype(
dtype,
'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'randint_like',
)
if paddle.utils._contain_var(shape):
shape = paddle.utils.get_int_tensor_list(
shape, _current_expected_place()
)
out = _C_ops.randint(
low, high, shape, DataType.INT64, _current_expected_place()
)
out = paddle.cast(out, dtype)
return out
else:
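The new PIR branch of randint_like above draws the values as INT64 and then casts to the requested dtype, mirroring the legacy dynamic-mode path. A rough standalone sketch of that pattern using only public APIs (randint_like_sketch is a hypothetical helper, not Paddle code; the diff itself uses the internal _C_ops.randint and DataType symbols):

import paddle


def randint_like_sketch(x, low=0, high=None, dtype=None):
    # Hypothetical helper mirroring the diff's logic: draw int64 values in
    # [low, high) with the same shape as x, then cast to the target dtype.
    if high is None:
        low, high = 0, low
    dtype = dtype if dtype is not None else x.dtype
    out = paddle.randint(low, high, shape=x.shape, dtype="int64")
    return paddle.cast(out, dtype)


x = paddle.rand([3, 4], dtype="float32")
y = randint_like_sketch(x, low=-5, high=5, dtype="int32")  # values in [-5, 5)
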
69 changes: 56 additions & 13 deletions test/legacy_test/test_randint_like.py
@@ -17,7 +17,7 @@
import numpy as np

import paddle
from paddle.static import Program, program_guard
from paddle.pir_utils import test_with_pir_api


# Test python API
@@ -37,9 +37,12 @@ def setUp(self):
else paddle.CPUPlace()
)

@test_with_pir_api
def test_static_api(self):
paddle.enable_static()
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# results are from [-100, 100).
x_bool = paddle.static.data(
name="x_bool", shape=[10, 12], dtype="bool"
@@ -50,11 +53,18 @@ def test_static_api(self):
paddle.randint_like(x_bool, low=-10, high=10, dtype=dtype)
for dtype in self.dtype
]
outs1 = exe.run(feed={'x_bool': self.x_bool}, fetch_list=outlist1)
outs1 = exe.run(feed={'x_bool': self.x_bool}, fetch_list=[outlist1])
for out, dtype in zip(outs1, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -10) & (out <= 10)).all(), True)
with program_guard(Program(), Program()):
paddle.disable_static()

@test_with_pir_api
def test_static_api_with_int32(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_int32 = paddle.static.data(
name="x_int32", shape=[10, 12], dtype="int32"
)
@@ -64,12 +74,22 @@ def test_static_api(self):
paddle.randint_like(x_int32, low=-5, high=10, dtype=dtype)
for dtype in self.dtype
]
outs2 = exe.run(feed={'x_int32': self.x_int32}, fetch_list=outlist2)
for out, dtype in zip(outs2, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -5) & (out <= 10)).all(), True)
outs2 = exe.run(
paddle.static.default_main_program(),
feed={'x_int32': np.zeros((10, 12)).astype(np.int32)},
fetch_list=[outlist2],
)
for out2, dtype in zip(outs2, self.dtype):
self.assertTrue(out2.dtype, np.dtype(dtype))
self.assertTrue(((out2 >= -5) & (out2 <= 10)).all(), True)
paddle.disable_static()

with program_guard(Program(), Program()):
@test_with_pir_api
def test_static_api_with_int64(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_int64 = paddle.static.data(
name="x_int64", shape=[10, 12], dtype="int64"
)
@@ -83,8 +103,15 @@ def test_static_api(self):
for out, dtype in zip(outs3, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -100) & (out <= 100)).all(), True)
paddle.disable_static()

@test_with_pir_api
def test_static_api_with_fp16(self):
paddle.enable_static()
if paddle.is_compiled_with_cuda():
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_float16 = paddle.static.data(
name="x_float16", shape=[10, 12], dtype="float16"
)
@@ -100,8 +127,14 @@ def test_static_api(self):
for out, dtype in zip(outs4, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -3) & (out <= 25)).all(), True)
paddle.disable_static()

with program_guard(Program(), Program()):
@test_with_pir_api
def test_static_api_with_float32(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_float32 = paddle.static.data(
name="x_float32", shape=[10, 12], dtype="float32"
)
@@ -117,8 +150,14 @@ def test_static_api(self):
for out, dtype in zip(outs5, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -25) & (out <= 25)).all(), True)
paddle.disable_static()

with program_guard(Program(), Program()):
@test_with_pir_api
def test_static_api_with_float64(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_float64 = paddle.static.data(
name="x_float64", shape=[10, 12], dtype="float64"
)
@@ -134,6 +173,7 @@ def test_static_api(self):
for out, dtype in zip(outs6, self.dtype):
self.assertTrue(out.dtype, dtype)
self.assertTrue(((out >= -16) & (out <= 16)).all(), True)
paddle.disable_static()

def test_dygraph_api(self):
paddle.disable_static(self.place)
@@ -169,9 +209,12 @@ def test_dygraph_api(self):
)
paddle.enable_static()

@test_with_pir_api
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x_bool = paddle.static.data(
name="x_bool", shape=[10, 12], dtype="bool"
)
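The test changes replace the direct Program/program_guard imports with fully qualified paddle.static calls, split the single static test into per-dtype cases, and wrap each in test_with_pir_api, which, as used here, re-runs the static-graph body under the PIR program form in addition to the legacy one. A trimmed sketch of the resulting pattern (assuming that decorator behaviour):

import unittest

import numpy as np
import paddle
from paddle.pir_utils import test_with_pir_api


class TestRandintLikePIRSketch(unittest.TestCase):
    @test_with_pir_api  # run the body under both the legacy IR and PIR
    def test_bounds(self):
        paddle.enable_static()
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            x = paddle.static.data(name="x", shape=[10, 12], dtype="float32")
            out = paddle.randint_like(x, low=-5, high=5, dtype="int64")
            exe = paddle.static.Executor(paddle.CPUPlace())
            (res,) = exe.run(
                feed={"x": np.zeros((10, 12), dtype="float32")}, fetch_list=[out]
            )
        self.assertTrue(((res >= -5) & (res < 5)).all())
        paddle.disable_static()
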
41 changes: 33 additions & 8 deletions test/legacy_test/test_zero_dim_tensor.py
@@ -25,6 +25,7 @@

import paddle
import paddle.nn.functional as F
from paddle.pir_utils import test_with_pir_api

unary_api_list = [
paddle.nn.functional.elu,
@@ -5808,17 +5809,41 @@ def test_randn(self):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))

def test_randint_and_randint_like(self):
out1 = paddle.randint(-10, 10, [])
out2 = paddle.randint_like(out1, -10, 10)
out3 = paddle.randint(-10, 10, self.shape)
@test_with_pir_api
def test_randint(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
out1 = paddle.randint(-10, 10, [])

shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
paddle.full([], 4, 'int32'),
]
out2 = paddle.randint(-10, 10, shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)

self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))

@test_with_pir_api
def test_randint_like(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
out1 = paddle.rand([])
out2 = paddle.randint_like(out1, -10, 10)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))

def test_standard_normal(self):
out1 = paddle.standard_normal([])
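
The zero-dim test updates check that a 0-D input yields 0-D random outputs under PIR as well. A small dynamic-mode sketch of the same property (assumes a current Paddle build):

import paddle

x = paddle.rand([])                  # 0-D tensor
y = paddle.randint_like(x, -10, 10)  # should also be 0-D
assert x.shape == [] and y.shape == []
assert tuple(y.numpy().shape) == ()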