Add broadcast in the frule for * #148

Merged
merged 3 commits into from
Jan 13, 2020

@YingboMa (Member) commented Jan 13, 2020

@@ -103,7 +103,7 @@
 # product rule requires special care for arguments where `mul` is non-commutative

 function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
-    return x * y, Δx * y + x * Δy
+    return x * y, @. muladd(Δx, y, x * Δy)

@oxinabox (Member) commented Jan 13, 2020

I find @. surprisingly hard to read.
Took me 10 seconds to realize the muladd was getting broadcast also 😂

So this is the same as:

Suggested change
- return x * y, @. muladd(Δx, y, x * Δy)
+ return x * y, muladd.(Δx, y, x .* Δy) # optimized version of `Δx .* y .+ x .* Δy`
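
To make the equivalence concrete, here is a minimal sketch with made-up inputs (the values of `x`, `y`, `Δx`, and `Δy` are illustrative assumptions, not taken from the package's tests). It checks that the `@.` form, the explicit dotted form, and the plain product-rule broadcast agree when the tangents are vectors and the primals are scalars.

```julia
# Illustrative inputs only: scalar primals, vector tangents.
x, y = 2.0, 3.0
Δx = [1.0, 0.0, 0.5]
Δy = [0.0, 1.0, 0.25]

# `@.` adds a dot to every call in the expression, including `muladd` and `*`.
with_macro = @. muladd(Δx, y, x * Δy)

# The same computation with the dots written out by hand.
explicit = muladd.(Δx, y, x .* Δy)

# The plain product rule, broadcast elementwise.
plain = Δx .* y .+ x .* Δy

@assert with_macro == explicit
@assert explicit == plain
```

These inputs were picked so that every intermediate value is exactly representable, which means the comparison does not depend on whether `muladd` is fused on a given machine.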

For interest, I did some benchmarking.
Writing out the broadcast seems consistently slightly faster, though the difference is pretty small.
But I think for the Number case, which this always is, it's just going to hit the
muladd(x,y,z) = x*y+z definition.

julia> @btime $dx .* $y .+ $x .* $dy;
  49.495 μs (2 allocations: 781.33 KiB)

julia> @btime muladd.($dx, $y, $x .* $dy);
  49.468 μs (2 allocations: 781.33 KiB)
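
The inputs behind these timings are not shown in the thread. The sketch below is one plausible setup, assuming scalar `x` and `y` and 100_000-element `Float64` tangents (a size that roughly matches the reported 781.33 KiB allocation). It relies on the third-party BenchmarkTools package for `@btime`; timings will differ between machines.

```julia
using BenchmarkTools  # third-party package providing @btime

# Assumed inputs: scalar primals and 100_000-element Float64 tangents.
x, y = rand(), rand()
dx, dy = rand(100_000), rand(100_000)

# Plain product rule, fused into a single broadcast.
@btime $dx .* $y .+ $x .* $dy;

# Same computation routed through muladd.
@btime muladd.($dx, $y, $x .* $dy);
```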

So I think we might as well do the clearer:

Suggested change
- return x * y, @. muladd(Δx, y, x * Δy)
+ return x * y, (Δx .* y .+ x .* Δy)

Member Author

It is potentially more accurate on machines with FMA instructions, since there are only two roundings instead of three: one in muladd/fma, the other in *.
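
A small sketch of the rounding effect, using `fma` directly so the fusion is guaranteed (`muladd` is only permitted to fuse, not required to). The value of `a` is an illustrative assumption, chosen so that the exact product needs more bits than a `Float64` significand holds.

```julia
# Illustrative value: a = 1 + 2^-27, so exactly a*a = 1 + 2^-26 + 2^-54.
a = 1.0 + ldexp(1.0, -27)

# Unfused: the product is rounded to 1 + 2^-26 first, so the 2^-54 term is lost.
unfused = a * a - 1.0

# Fused: a*a - 1 is computed exactly and rounded once, keeping the 2^-54 term.
fused = fma(a, a, -1.0)

@assert unfused == ldexp(1.0, -26)
@assert fused == ldexp(1.0, -26) + ldexp(1.0, -54)
@assert fused != unfused
```

Whether the difference shows up in practice depends on the inputs; for many values the two forms round to the same result.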

Member

🤷‍♂

They come out identical on my machine (which has fma) over 100_000 values in Δy and Δx.

I think the accuracy only comes into play if y were a matrix,
which it is not permitted to be here.

Member Author

Having muladd certainly won't hurt. Also, people could define their own number type, and overload muladd.
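
As a sketch of that last point, the hypothetical `Fused` type below (invented here for illustration, not part of any package) overloads `muladd`. A call counter shows that the broadcast expression from the diff reaches the overload rather than falling back to `x*y+z`.

```julia
# Hypothetical number type, defined only for this illustration.
struct Fused <: Number
    val::Float64
end

Base.:*(a::Fused, b::Fused) = Fused(a.val * b.val)
Base.:+(a::Fused, b::Fused) = Fused(a.val + b.val)

# Counts how often the custom muladd method is called.
const MULADD_CALLS = Ref(0)

# User-defined muladd that always uses a true fused multiply-add.
function Base.muladd(a::Fused, b::Fused, c::Fused)
    MULADD_CALLS[] += 1
    return Fused(fma(a.val, b.val, c.val))
end

x, y = Fused(2.0), Fused(3.0)
Δx, Δy = Fused(1.0), Fused(0.5)

# Same expression as in the diff; broadcasting over scalars returns a scalar.
tangent = @. muladd(Δx, y, x * Δy)

@assert tangent.val == 4.0        # 1*3 + 2*0.5
@assert MULADD_CALLS[] == 1       # the overload was used
```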

@YingboMa YingboMa merged commit d7be84a into master Jan 13, 2020
@YingboMa YingboMa deleted the myb/vector_mode branch January 13, 2020 19:07