SPMD: broadcast of replicated tensor is marked sharded #17913

Open
fhaolinaws opened this issue Oct 4, 2024 · 0 comments

Hello,
We have observed that a tensor broadcast from a scalar (rank-0) parameter is marked sharded by the XLA sharding propagator. When this sharded tensor is then used in a computation with another tensor whose sharding spec is incompatible, additional communication is inserted. These communications could be avoided, since the broadcast tensor only contains copies of the scalar parameter.

HLO example:

param.1 = bf16[] parameter(27), sharding={replicated}, metadata={op_type="xla__device_data" op_name="xla__device_data" }
broadcast.356 = bf16[256,2,256]{2,1,0} broadcast(param.1), dimensions={}, metadata={op_type="aten__lerp" op_name="aten__lerp.65/aten__lerp"}
all-to-all.1 = bf16[256,2,256]{2,1,0} all-to-all(broadcast.356), channel_id=90, replica_groups={{0,1},{8,9},{16,17},{24,25},{2,3},{10,11},{18,19},{26,27},{4,5},{12,13},{20,21},{28,29},{6,7},{14,15},{22,23},{30,31}}, dimensions={1}, metadata={op_type="aten__lerp" op_name="aten__lerp.65/aten__lerp"}
transpose.50 = bf16[2,256,256]{2,0,1} transpose(all-to-all.1), dimensions={1,0,2}, metadata={op_type="aten__lerp" op_name="aten__lerp.65/aten__lerp"}
reshape.551 = bf16[512,256]{1,0} reshape(transpose.50), metadata={op_type="aten__lerp" op_name="aten__lerp.65/aten__lerp"}
collective-permute.1 = bf16[512,256]{1,0} collective-permute(reshape.551), channel_id=91, source_target_pairs={{0,0},{1,1},{8,2},{9,3},{16,4},{17,5},{24,6},{25,7},{2,8},{3,9},{10,10},{11,11},{18,12},{19,13},{26,14},{27,15},{4,16},{5,17},{12,18},{13,19},{20,20},{21,21},{28,22},{29,23},{6,24},{7,25},{14,26},{15,27},{22,28},{23,29},{30,30},{31,31}}, metadata={op_type="aten__lerp" op_name="aten__lerp.65/aten__lerp"}

This commonly happens when the broadcast tensor is reused in computations with tensors that have different sharding specs, e.g. optimizer hyperparameters used in point-wise weight-update computations; a minimal sketch of this pattern is shown below.
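
A minimal sketch of the pattern (assumptions: a PyTorch/XLA SPMD setup; the mesh layout, tensor shapes, and variable names are illustrative and not taken from the original model). A replicated scalar, e.g. an EMA/lerp coefficient, is broadcast against a sharded weight in a point-wise update, which is roughly what produces the HLO above:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
# Illustrative 2D mesh; axis names are arbitrary.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'model'))

device = xm.xla_device()
weight = torch.randn(512, 256, dtype=torch.bfloat16, device=device)
update = torch.randn(512, 256, dtype=torch.bfloat16, device=device)
# Shard the weight and the update along their first dimension.
xs.mark_sharding(weight, mesh, ('fsdp', None))
xs.mark_sharding(update, mesh, ('fsdp', None))

# beta is a replicated scalar; torch.lerp broadcasts it to the weight's shape.
# The broadcast result is what ends up marked sharded by the propagator,
# which triggers the extra collectives shown in the HLO above.
beta = torch.tensor(0.9, dtype=torch.bfloat16, device=device)
new_weight = torch.lerp(weight, update, beta)
xm.mark_step()  # materialize the traced graph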

We would like to propose changing the sharding spec of a broadcast tensor to replicated whenever the input of the broadcast is replicated. This would be done in the SPMD partitioner pass, after shape partitioning is finished; a toy sketch of the rule follows.
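
For illustration only, a toy sketch of the proposed rule (assumption: this is not the actual XLA SPMD partitioner, which is implemented in C++; the instruction and sharding representations here are simplified stand-ins):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Instruction:
    opcode: str
    sharding: Optional[str] = None          # e.g. "replicated" or a device assignment string
    operands: List["Instruction"] = field(default_factory=list)


def replicate_broadcasts_of_replicated(instructions: List[Instruction]) -> None:
    """After shape partitioning, force broadcasts of replicated inputs to be
    replicated instead of keeping a propagated (sharded) spec."""
    for instr in instructions:
        if (instr.opcode == "broadcast"
                and instr.operands
                and instr.operands[0].sharding == "replicated"):
            instr.sharding = "replicated"


# Example mirroring the HLO above: param.1 is replicated, broadcast.356 was
# marked sharded by propagation; the rule rewrites it to replicated.
param = Instruction("parameter", sharding="replicated")
broadcast = Instruction("broadcast", sharding="{devices=[1,2,1]0,1}", operands=[param])
replicate_broadcasts_of_replicated([param, broadcast])
assert broadcast.sharding == "replicated"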
