[Pallas GPU] Enable Pallas `OpsExtraTest` in 64-bit mode
#23798
Merged
This is a follow-up of #23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests needs to be modified. There are three kinds of modifications:

- Change `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR, "Enable Pallas `ops_test` on GPU in 64-bit mode" (#23747). `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
- In `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduce `out_dtype` by executing the operation on a single element first.
- In `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation produces an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce an `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result.
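
For readers who want to see the dtype behaviour behind these three changes, here is a minimal sketch in plain JAX. It is not the actual test code: the `intx`/`floatx` definitions, the `deduce_out_dtype` helper, and the `pid` stand-in for `pl.program_id(0)` are all assumptions made for illustration, and the sketch runs outside of any Pallas kernel.

```python
import jax

# Assumption: turn on 64-bit mode up front so the dtype differences are visible.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# 1. Width-dependent aliases. Hypothetical definitions for this sketch;
#    the Pallas test suite has its own.
intx = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
floatx = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32


# 2. Deduce the output dtype of a reduction by running it on one element,
#    instead of hard-coding a dtype -> out_dtype mapping. The result may be
#    wider than the input dtype in 64-bit mode for some reductions.
def deduce_out_dtype(op, dtype):
    """Hypothetical helper: dtype produced by `op` for inputs of `dtype`."""
    return op(jnp.zeros((1,), dtype=dtype)).dtype


sum_out_dtype = deduce_out_dtype(jnp.sum, jnp.int32)
max_out_dtype = deduce_out_dtype(jnp.max, jnp.int32)

# 3. Index computation. `pid` is a stand-in for pl.program_id(0), which the
#    PR documents as always returning int32.
block_size = 8
pid = jnp.int32(1)

# jnp.arange without a dtype follows the default integer width, so the sum
# is promoted to int64 in 64-bit mode.
idx_default = pid * block_size + jnp.arange(block_size)

# Pinning the arange dtype keeps the whole expression in int32.
idx_int32 = pid * block_size + jnp.arange(block_size, dtype=jnp.int32)

print(idx_default.dtype, idx_int32.dtype)
```

Pinning the dtype of `jnp.arange` is only one possible way to keep `idx` in `int32`; casting the final result with `.astype(jnp.int32)` would be another.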