Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adjust lax.convert_element_type bind to avoid H2D transfers during tracing #6014

Merged
merged 1 commit into from
Mar 19, 2021

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Mar 10, 2021

@jekbradbury and others noticed that we were performing H2D transfers while tracing, e.g. while tracing random.split. @zhangqiaorjc and I tracked it down to how lax.convert_element_type would sometimes call a device_put function directly. The fix was to bind the convert_element_type primitive in more cases. (See also the discussion thread on #5998 about the fix, which was moved from that PR to this one.)

A downside of this fix is that it led to more convert_element_type primitive applications in jaxprs, especially for literals. That added visual clutter. This PR fixes that issue by tweaking _inline_literals in partial_eval.py to constant-fold dtype conversion of literals at jaxpr formation time. (To do this, I moved convert_element_type_p to core.py so that partial_eval.py knows about it.)

(This PR currently contains two commits from #5998, but once that's merged into master they'll disappear; only the diff from the last commit is relevant to this PR.)

fixes #5308

@mattjj mattjj requested a review from zhangqiaorjc March 10, 2021 21:11
@google-cla google-cla bot added the cla: yes label Mar 10, 2021
@mattjj mattjj force-pushed the convert-element-type-bind branch 2 times, most recently from 848ac7e to c3e08e6 Compare March 10, 2021 21:15
@mattjj mattjj force-pushed the convert-element-type-bind branch 3 times, most recently from 2d05aa2 to 0803111 Compare March 17, 2021 01:38
@mattjj mattjj added the pull ready Ready for copybara import and testing label Mar 17, 2021
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
@mattjj mattjj force-pushed the convert-element-type-bind branch from 0803111 to bf15ba5 Compare March 19, 2021 20:42
@copybara-service copybara-service bot merged commit 64e851f into master Mar 19, 2021
@mattjj mattjj deleted the convert-element-type-bind branch March 19, 2021 21:55
copybara-service bot pushed a commit that referenced this pull request Mar 21, 2021
copybara-service bot pushed a commit that referenced this pull request Mar 21, 2021
copybara-service bot pushed a commit that referenced this pull request Mar 21, 2021
copybara-service bot pushed a commit that referenced this pull request Mar 21, 2021
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Apr 1, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

NumPy constants device_put multiple times during tracing
2 participants