
Properly lower add and mul #6731

Merged
merged 5 commits into from
Mar 13, 2024
Conversation

wonjoolee95 (Collaborator)

Partly fixes #6589

@wonjoolee95 wonjoolee95 requested a review from bhavya01 — March 13, 2024 01:00
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/lower-ops-properly branch from 9d14293 to 997c79a — March 13, 2024 01:06
torch_xla/csrc/elementwise.h (review comments: outdated, resolved)
wonjoolee95 (Collaborator, Author) commented Mar 13, 2024

I noticed that one of the SPMD unit tests that checks HLOs failed previously; it seems the HLOs changed slightly because this PR updates the lowering logic. In this simple SPMD unit test, the HLO without this change is:

ENTRY %IrToHlo.14 (p0.4: f32[1,128], p1.5: f32[1,128]) -> (f32[1,128]) {
  %p1.5 = f32[1,128]{1,0} parameter(1), sharding={replicated}
  %p0.4 = f32[1,128]{1,0} parameter(0), sharding={replicated}
  %constant.3 = f32[] constant(1)
  %broadcast.6 = f32[1,128]{1,0} broadcast(f32[] %constant.3), dimensions={}
  %multiply.7 = f32[1,128]{1,0} multiply(f32[1,128]{1,0} %p0.4, f32[1,128]{1,0} %broadcast.6)
  %add.8 = f32[1,128]{1,0} add(f32[1,128]{1,0} %p1.5, f32[1,128]{1,0} %multiply.7)
  %custom-call.9 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.8), custom_call_target="Sharding", sharding={replicated}
  %constant.2 = f32[] constant(0)
  %constant.1 = f32[] constant(1)
  %multiply.10 = f32[] multiply(f32[] %constant.2, f32[] %constant.1)
  %broadcast.11 = f32[1,128]{1,0} broadcast(f32[] %multiply.10), dimensions={}
  %add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)
  ROOT %tuple.13 = (f32[1,128]{1,0}) tuple(f32[1,128]{1,0} %add.12)
}

And the updated HLO looks like:

ENTRY %IrToHlo.14 (p0.5: f32[1,128], p1.8: f32[1,128]) -> (f32[1,128]) {
  %p1.8 = f32[1,128]{1,0} parameter(1), sharding={replicated}
  %p0.5 = f32[1,128]{1,0} parameter(0), sharding={replicated}
  %constant.4 = f32[] constant(1)
  %broadcast.6 = f32[1,128]{1,0} broadcast(f32[] %constant.4), dimensions={}
  %multiply.7 = f32[1,128]{1,0} multiply(f32[1,128]{1,0} %p0.5, f32[1,128]{1,0} %broadcast.6)
  %add.9 = f32[1,128]{1,0} add(f32[1,128]{1,0} %p1.8, f32[1,128]{1,0} %multiply.7)
  %custom-call.10 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.9), custom_call_target="Sharding", sharding={replicated}
  %constant.2 = f32[] constant(0)
  %constant.1 = f32[] constant(1)
  %multiply.3 = f32[] multiply(f32[] %constant.2, f32[] %constant.1)
  %broadcast.11 = f32[1,128]{1,0} broadcast(f32[] %multiply.3), dimensions={}
  %add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)
  ROOT %tuple.13 = (f32[1,128]{1,0}) tuple(f32[1,128]{1,0} %add.12)
}

The contents of the HLOs are the same; the only difference is the numeric suffixes, e.g. %custom-call.9 -> %custom-call.10.

Synced offline with @yeounoh; this is fine, since only the instruction numbering in the HLO has changed.
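A quick way to check that two dumps match up to numbering is to strip the auto-assigned numeric suffixes from the instruction names before comparing. A minimal sketch (normalize_hlo is a hypothetical helper, not part of the test suite):

```python
import re

def normalize_hlo(hlo_text: str) -> str:
    """Strip auto-assigned numeric suffixes from HLO instruction names
    (e.g. %custom-call.9 -> %custom-call) so two dumps can be compared
    structurally, ignoring instruction numbering."""
    return re.sub(r"(%[\w-]+?)\.\d+", r"\1", hlo_text)

# Lines from the two dumps above that differ only in numbering:
before = "%add.8 = f32[1,128]{1,0} add(f32[1,128]{1,0} %p1.5, f32[1,128]{1,0} %multiply.7)"
after = "%add.9 = f32[1,128]{1,0} add(f32[1,128]{1,0} %p1.8, f32[1,128]{1,0} %multiply.7)"

assert normalize_hlo(before) == normalize_hlo(after)
```

This only papers over renumbering; a real structural diff would also have to tolerate reordered instructions.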

@bhavya01 bhavya01 self-requested a review March 13, 2024 20:27
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/lower-ops-properly branch from a24d373 to 1306408 — March 13, 2024 21:52
wonjoolee95 (Collaborator, Author)

Previously CI was all green; the most recent commit just rebases onto master. I'll merge this now to make the rc1 branch cut.

@wonjoolee95 wonjoolee95 merged commit fe3f23c into master Mar 13, 2024
2 of 3 checks passed
lsy323 pushed a commit that referenced this pull request Mar 13, 2024
lsy323 added a commit that referenced this pull request Mar 13, 2024
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Successfully merging this pull request may close these issues.

Update IR-level lowerings to proper lowerings