Optimize execution for ops that have multiple outputs in eager mode #7680

Merged
2 commits merged into master on Jul 16, 2024

Conversation

JackCaoG
Collaborator

In eager mode, execution happens when we create an XLATensor with an IR: we use that IR as the root to build and execute the graph.

This is mostly fine, but for ops that have multiple outputs (like native_batch_norm), the outputs share a good amount of common HLO, so it is much faster to execute all of them in a single graph. Eager mode in PyTorch/XLA can't really execute HLO instructions one by one, so the goal is to execute the graph once (ideally) for each PyTorch op.
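
For illustration only (this snippet is mine, not part of the PR), any multi-output op takes this path; torch.max along a dim is a simple example, reusing the eager-mode setup from the benchmark further down:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

x = torch.randn(1024, 1024, device=device)
# torch.max along a dim is a multi-output op: it returns (values, indices),
# and both outputs are built from the same reduction over `x`, so executing
# them as one graph avoids redundant work.
values, indices = torch.max(x, dim=1)
xm.wait_device_ops()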

The change in this PR will:

  1. delay the eager execution for some ops when they create new XLATensors with IRs, and
  2. execute the HLO for all of those XLATensors after they are created (see the sketch below).
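
As a rough way to observe this from Python (a sketch of mine, not something this PR adds), you can clear the debug metrics, run a single multi-output op, and inspect the metrics report; picking 'ExecuteTime' as the metric to watch is my own assumption, not something the PR documents:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

x = torch.randn(16, 16, 128, 128, device=device)
xm.wait_device_ops()
met.clear_all()

# F.batch_norm lowers to native_batch_norm, which produces multiple outputs.
out = torch.nn.functional.batch_norm(x, None, None, training=True)
xm.wait_device_ops()

# With this change the op should (ideally) show up as a single execution
# rather than one per output; the 'ExecuteTime' sample count is my guess
# at the metric to check.
print(met.metrics_report())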

I will take another pass to make sure I didn't mess anything up, but I would appreciate it if someone could look closely at my changes inside tensor_method.cpp.

@JackCaoG JackCaoG added the eager label Jul 12, 2024
@JackCaoG
Collaborator Author

I also intentionally didn't handle the collectives. Collectives return an all_reduce token, which we actually don't want to execute in the eager case. I will handle that in a separate PR.

@aws-rhsoln
Contributor

Curious how much of a perf boost we expect when we fuse them into a single graph?

@JackCaoG
Collaborator Author

JackCaoG commented Jul 15, 2024

Curious how much of a perf boost we expect when we fuse them into a single graph?

For this test code:

import time

import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
m = nn.BatchNorm2d(16).to(device)
m.train()
input = torch.randn(16, 16, 1024, 1024, device=device)

start = time.time()
for _ in range(20):
  input = m(input)
xm.wait_device_ops()
end = time.time()
duration = end - start
print(f"total time = {duration}")

With my change, total time = 0.46190381050109863; without this change, total time = 14.28174352645874. I actually don't know why it is ~31x faster, but I did verify that in the HLO without my change, BatchNorm2d computes its results one by one.

@JackCaoG JackCaoG marked this pull request as ready for review July 15, 2024 18:37
@JackCaoG
Collaborator Author

@alanwaketan @wonjoolee95 This one is ready for review.

@JackCaoG JackCaoG merged commit b2c7f65 into master Jul 16, 2024
23 checks passed