Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode #23798

Merged
merged 1 commit into from
Sep 20, 2024

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Sep 20, 2024

[Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode

This is a follow-up of #23747, which enables Pallas OpsTest in 64-bit mode.

In order to enable Pallas OpsExtraTest in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications:

  1. Most of the modifications are just changing jnp.int32 to intx and jnp.float32 to floatx, which uses the same approach as the previous PR Enable Pallas ops_test on GPU in 64-bit mode. #23747. intx and floatx are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
  2. For the test test_array_reduce, the original code uses a simple approach to determine out_dtype from dtype, which no longer works in 64-bit mode. Therefore, I modified the code to deduct out_dtype by executing the operation on a single element first.
  3. For the test test_masked_load_store, the idx variable is expected to be an int32 array, which is calculated based on pl.program_id() and block_size. In 64-bit mode, the computation will give out an int64 array instead. Since pl.program_id() always returns an int32 result, I modified the computation to produce int32 result. I also modified the pl.program_id() docstring to document the behaviour that pl.program_id() always returns an int32 result.

This is a follow-up of #23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications:

1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR #23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduct `out_dtype` by executing the operation on a single element first.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation will give out an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result.

PiperOrigin-RevId: 677007613
@copybara-service copybara-service bot merged commit d63afd8 into main Sep 20, 2024
@copybara-service copybara-service bot deleted the test_676838304 branch September 20, 2024 23:18
copybara-service bot pushed a commit that referenced this pull request Sep 22, 2024
…t mode

These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by #23798

PiperOrigin-RevId: 677483206
copybara-service bot pushed a commit that referenced this pull request Sep 22, 2024
…t mode

These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by #23798

PiperOrigin-RevId: 677483206
copybara-service bot pushed a commit that referenced this pull request Sep 23, 2024
…t mode

These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by #23798

PiperOrigin-RevId: 677483206
copybara-service bot pushed a commit that referenced this pull request Sep 23, 2024
…t mode

These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by #23798

PiperOrigin-RevId: 677583606
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant