
Add support for broadcast multiply along batch dimension (W) #5769

Closed · Tracked by #6445
kpaigwar opened this issue Feb 28, 2024 · 10 comments

Comments

@kpaigwar
Contributor

Requirement

Need support for an op that can perform an element-wise multiply with broadcasting along the batch dim.
For example, (32, 1, 32, 1024) * (1, 1, 32, 1024).
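For reference, this is standard leading-dim broadcasting; a minimal sketch of the expected semantics, using PyTorch only as a reference model (not the tt-metal API):

```python
import torch

# Batch-dependent activations (32 batches) and batch-independent weights (1 batch).
inputs = torch.randn(32, 1, 32, 1024)
weights = torch.randn(1, 1, 32, 1024)

# Element-wise multiply; the size-1 batch dim of `weights` broadcasts to 32.
out = inputs * weights
assert out.shape == (32, 1, 32, 1024)
```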

@kpaigwar
Contributor Author

fyi @apalaguha @esmalTT

@kpaigwar
Contributor Author

kpaigwar commented Feb 29, 2024

This op is critical for us when doing element-wise multiplication of conv1D weights (batch-independent) with inputs (batch-dependent). The workaround of manually broadcasting the conv1D weights makes the implementation harder and non-performant.

@bharane-abb

bharane-abb commented Mar 12, 2024

Hi @kpaigwar @jliangTT
We have come up with two ideas for multiplication without using broadcasting:

  1. Using the repeat function, we will repeat the smaller tensor into the required shape and then proceed with the multiplication (see the sketch after this list).
  2. Using the unpad function, we can unpad the larger tensor into smaller tensors and multiply them with the smaller tensor. However, this process is time-consuming, as it involves many unpad, multiply, and concat operations.
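A minimal sketch of idea 1, again in PyTorch reference semantics; `Tensor.repeat` stands in for the tt-metal repeat function (an assumption about the eventual call, not the actual API):

```python
import torch

inputs = torch.randn(32, 1, 32, 1024)   # larger, batch-dependent tensor
weights = torch.randn(1, 1, 32, 1024)   # smaller, batch-independent tensor

# Idea 1: materialize the broadcast by repeating the smaller tensor along
# the batch dim, then fall back to a plain same-shape element-wise multiply.
weights_repeated = weights.repeat(32, 1, 1, 1)   # -> (32, 1, 32, 1024)
out = inputs * weights_repeated
```

Note that this materializes 32 copies of the weights, which is exactly the memory and bandwidth overhead that native broadcast support would avoid.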

@kpaigwar
Contributor Author

@bharane-ab the first approach would be ideal

@umadevimcw
Contributor

umadevimcw commented Mar 15, 2024

@kpaigwar Regarding the above: why are we considering only the mul operation? What about other binary operations like squared_diff, add, sub, etc.? If we make the change in a common place, it will apply to all of them.

I tried updating eltwise_binary_op.hpp to apply the change to all ops in binary, but I am getting the error shown in the image below. The suggested flag (in the image) is not helping, and I am not able to use repeat in that file.

[Screenshot from 2024-03-15: error when using repeat in eltwise_binary_op.hpp]

@kpaigwar Do you want me to go with a separate op like batch_mul or something? Any suggestions?
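For context, the same batch-dim broadcast applies uniformly to any element-wise binary op, which is the argument for fixing it in the common path; a sketch in PyTorch reference semantics (squared_diff spelled out manually, since that name is tt-metal-specific):

```python
import torch

a = torch.randn(32, 1, 32, 1024)
b = torch.randn(1, 1, 32, 1024)

# Each of these broadcasts b's size-1 batch dim identically, so a fix in the
# shared binary-op code would cover all of them at once.
added = a + b
subtracted = a - b
squared_diff = (a - b) ** 2   # equivalent of tt-metal's squared_diff
```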

@umadevimcw
Contributor

Tried the implementation in composite as a separate batch-mul op: #6442

@tt-aho
Contributor

tt-aho commented Mar 15, 2024

If you are targeting performance, are you planning to run this model/op sharded, or exactly as specified in this issue? The optimizations/support for sharding are different than for interleaved, so I want to make sure we're targeting/optimizing the right thing. Similar question for #6361.

@kpaigwar
Contributor Author

kpaigwar commented Mar 15, 2024

Since this issue is also related to repeat (#6361), I will drop the priority of this issue and address the P0 issue #6361 first.

@umadevimcw
Contributor

@kpaigwar Please find the updated PR #6587 for batch-mul support. It doesn't involve a new op. I also changed the code in a common place, so it works for the other binary ops as well.

@umadevimcw
Contributor

Support merged to main.
