
Faster Metal unary and binary for general case #1431

Merged: awni merged 4 commits into main from faster_unary_binary on Sep 25, 2024
Conversation

awni (Member) commented on Sep 23, 2024:

  • Add work per thread specialization for binary / binary_two (the idea is sketched later in the thread)
  • Add dimension collapsing for the general unary op (see the sketch just below)
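
For context, dimension collapsing merges adjacent dimensions whose strides line up, so the per-element index computation has fewer dimensions to walk. A minimal Python sketch of the idea, assuming a row-major shape/strides representation (names are illustrative; the actual implementation lives in MLX's C++ backend):

def collapse_contiguous_dims(shape, strides):
    # Fold dim i into dim i+1 whenever strides[i] == shape[i+1] * strides[i+1],
    # i.e. the two dims together describe one regularly-strided block.
    out_shape, out_strides = [shape[0]], [strides[0]]
    for dim, stride in zip(shape[1:], strides[1:]):
        if out_strides[-1] == dim * stride:
            out_shape[-1] *= dim      # merge into a single larger dim
            out_strides[-1] = stride
        else:
            out_shape.append(dim)
            out_strides.append(stride)
    return out_shape, out_strides

# A contiguous (4, 6, 2) array with strides (12, 2, 1) collapses all the
# way to shape [48], strides [1]; the transposed arrays in the benchmarks
# below barely collapse at all, which is what makes the general path slow.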

Benchmarks (timings on M1 Max, lower is better):

Bench     Pre      Post
Unary     10.145    6.466
Binary    18.701   10.256

Benchmark code:

import mlx.core as mx

# timeit below is the author's timing helper; its definition is not shown
# in the PR description.

def unary(x):
    return mx.abs(x)

# The unary input setup was elided in the original; assuming the same 5-D
# array as the binary case below:
D = 48
x = mx.random.uniform(shape=(D, D, D, D, D))
x = x[..., ::2]                 # strided slice -> non-contiguous, general path
x = x.transpose(1, 0, 3, 2, 4)  # permuted axes keep the dims from collapsing

timeit(unary, x)

def binary(x, y):
    return x + y

D = 48
x = mx.random.uniform(shape=(D, D, D, D, D))
y = mx.random.uniform(shape=(D, D, D, D, D))
y = y.transpose(2, 1, 3, 0, 4)  # mismatched layouts force the general kernel

timeit(binary, x, y)

angeloskath (Member) left a comment:


Looks great!

Just to be clear, the unary speedup comes from using a smaller elem_to_loc? There is no extra work per thread happening. Did that not help, or did you not want to bloat the PR?

awni (Member, Author) commented on Sep 24, 2024:

> Just to be clear, the unary speedup comes from using a smaller elem_to_loc?

Yes, just from the dimension collapsing. I didn't add the work-per-thread specialization because the general unary op is already somewhat rare (you have to do a slice with a gap to get it), and it's rarer still to get routed to the version that uses more than 3 dims after dimension collapsing.
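
For readers following along: elem_to_loc converts a flat element index into a strided memory offset, peeling off one dimension per loop iteration, so its cost grows with the number of dimensions. A rough Python rendering of the usual formulation (the real version is a device-side function in the Metal kernels):

def elem_to_loc(elem, shape, strides):
    # Walk dims from innermost to outermost; O(ndim) work per element,
    # which is why collapsing dims speeds up the general kernels.
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (elem % dim) * stride
        elem //= dim
    return loc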

I can check the impact on binary size and on the speedup to see whether it's worth adding.

angeloskath (Member) commented:

Agreed. The next optimization along those lines would be to check whether adding the work per thread in the fused ops has as dramatic an effect. Perhaps it matters much less there, since we almost never call elem_to_loc ...

awni (Member, Author) commented on Sep 25, 2024:

For the following benchmark, work_per_thread for unary speeds things up by a factor of 2: on my M1 Max it goes from 40 ms to 20 ms.

In terms of binary size, it's an increase of 700,000 bytes. So it's not small... but it's not big either. For the JIT it doesn't really matter. I'm tempted to include the optimization, but I could go either way. I added the diff here; let me know if you have thoughts on it.

D = 32
x = mx.random.uniform(shape=(D, D, D, D, D, D))
x = x[..., ::2]
x = x.transpose(1, 0, 3, 2, 5, 4)

timeit(unary, x)  # assumed: timed with the same unary / timeit helper as above
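
For reference, the work-per-thread specialization amortizes that index arithmetic: each thread calls elem_to_loc once and then processes several consecutive elements along the innermost dimension. A hedged Python sketch of the kernel structure (the real kernel is Metal; the names, the op, and the group size are illustrative, and the group size is assumed to divide the innermost extent):

def unary_general_wpt(inp, out, shape, strides, op, work_per_thread=4):
    # out is assumed contiguous; inp is strided. Each loop iteration plays
    # the role of one GPU thread, handling work_per_thread consecutive
    # innermost elements from a single elem_to_loc call.
    total = 1
    for d in shape:
        total *= d
    inner_stride = strides[-1]
    for start in range(0, total, work_per_thread):
        loc = elem_to_loc(start, shape, strides)
        for i in range(work_per_thread):
            out[start + i] = op(inp[loc + i * inner_stride])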

angeloskath (Member) commented:

I would merge it :-)

awni merged commit 4f9f9eb into main on Sep 25, 2024 (4 checks passed), and deleted the faster_unary_binary branch on September 25, 2024 at 19:07.