Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use convert_element_type instead of device_put_raw. #6400

Merged
merged 4 commits into from
May 6, 2021

Conversation

pschuh
Copy link
Collaborator

@pschuh pschuh commented Apr 9, 2021

No description provided.

@google-cla google-cla bot added the cla: yes label Apr 9, 2021
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 10, 2021

Hi - out of curiosity, what was the impetus for this change? I ask because this kind of change tends to have subtle unintented consequences, so I'd prefer to avoid it unless there is good reason.

@mattjj
Copy link
Collaborator

mattjj commented Apr 10, 2021

Also, this code was written this way for performance, so we'd want to check some benchmarks for any change! #3350 has some basic microbenchmarks for jnp.array. (Now that we have some basic benchmark infrastructure, we should add this kind of stuff! )

@mattjj mattjj self-assigned this Apr 10, 2021
@mattjj mattjj self-requested a review April 10, 2021 03:37
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great find, Parker! (See jax-dev chat thread for more discussion.)

@mattjj
Copy link
Collaborator

mattjj commented Apr 10, 2021

I ran the simple timeit jnp.array([0] * int(1e6)) benchmark from #3350 and there was no change (214 ms ± 1.95 on master, 213 ms ± 1.88 on this branch).

@mattjj mattjj added the pull ready Ready for copybara import and testing label Apr 10, 2021
pschuh added a commit to pschuh/flax that referenced this pull request Apr 20, 2021
pschuh added a commit to pschuh/flax that referenced this pull request Apr 20, 2021
pschuh added a commit to pschuh/trax that referenced this pull request Apr 20, 2021
copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 20, 2021
copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 20, 2021
copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 20, 2021
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Apr 22, 2021
pschuh and others added 2 commits April 22, 2021 12:46
also add a test for not performing H2D transfers while tracing jnp.array
@mattjj
Copy link
Collaborator

mattjj commented Apr 22, 2021

@jakevdp the main motivation is to avoid host-to-device (H2D) transfers during jit tracing, very much analogous to #6014. See the jax-dev chat thread for more discussion.

copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 25, 2021
With jax-ml/jax#6400,
jnp.array delays resolving the axis name until after tracing is complete.

PiperOrigin-RevId: 370356687
copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 25, 2021
With jax-ml/jax#6400,
jnp.array delays resolving the axis name until after tracing is complete.

PiperOrigin-RevId: 370356687
copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 28, 2021
With jax-ml/jax#6400,
jnp.array delays resolving the axis name until after tracing is complete.

PiperOrigin-RevId: 370356687
copybara-service bot pushed a commit to google/trax that referenced this pull request Apr 28, 2021
With jax-ml/jax#6400,
jnp.array delays resolving the axis name until after tracing is complete.

PiperOrigin-RevId: 370798828
tensorflow-copybara pushed a commit to google-research/google-research that referenced this pull request May 6, 2021
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels May 6, 2021
@google-cla
Copy link

google-cla bot commented May 6, 2021

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.

@google-cla google-cla bot added cla: no and removed cla: yes labels May 6, 2021
@mattjj mattjj added cla: yes and removed cla: no labels May 6, 2021
@mattjj
Copy link
Collaborator

mattjj commented May 6, 2021

@googlebot I consent.

@google-cla
Copy link

google-cla bot commented May 6, 2021

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.

@google-cla google-cla bot added cla: no and removed cla: yes labels May 6, 2021
@mattjj mattjj added cla: yes and removed cla: no labels May 6, 2021
@copybara-service copybara-service bot merged commit d0aa875 into jax-ml:master May 6, 2021
copybara-service bot pushed a commit that referenced this pull request Sep 21, 2021
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 395998020
copybara-service bot pushed a commit that referenced this pull request Sep 21, 2021
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 395998020
copybara-service bot pushed a commit that referenced this pull request Sep 21, 2021
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 398008511
SaturdayGenfo pushed a commit to SaturdayGenfo/jax that referenced this pull request Sep 28, 2021
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 398008511
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants