Broadcast the propagation_expr for vector mode AD #93

Merged 7 commits on Jan 13, 2020
6 changes: 5 additions & 1 deletion Project.toml
@@ -1,9 +1,13 @@
 name = "ChainRulesCore"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.5.1"
+version = "0.5.2"
+
+[deps]
+MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

 [compat]
 julia = "^1.0"
+MuladdMacro = "0.2.1"

 [extras]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
21 changes: 19 additions & 2 deletions src/rule_definition_tools.jl
@@ -1,4 +1,5 @@
 # These are some macros (and supporting functions) to make it easier to define rules.
+using MuladdMacro: @muladd

 """
     @scalar_rule(f(x₁, x₂, ...),
@@ -208,9 +209,25 @@ end
 function propagation_expr(Δs, ∂s)
     # This is basically Δs ⋅ ∂s
     ∂s = map(esc, ∂s)
+    n∂s = length(∂s)

-    ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s))
-    return :(+($(∂_mul_Δs...)))
+    # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
+    # literals.
+    ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
+
+    # Avoiding the extra `+` operation, it is potentially expensive for vector
+    # mode AD.
+    sumed_∂_mul_Δs = if n∂s > 1
+        # we use `@.` to broadcast `*` and `+`
+        :(@. +($(∂_mul_Δs...)))
Member

Why use the macro? `.+($(...))` should be identical.

Member Author

And it is more readable.

Member

I find it less readable.
And this is kind of why: it's not as obvious what is being broadcast.
E.g., this case, where we are also broadcasting the stuff from the other line.

I'd just make both explicit broadcasts.

--

I know the DiffEq code base uses this a ton.
I have never used it except during the transition from implicit broadcast back in 0.5ish.
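
For readers comparing the two spellings discussed here, a minimal sketch follows. The variable names and values are made up for illustration and do not come from the PR, and the snippet assumes a Julia version where both forms work in ordinary top-level code.

```julia
# Illustrative values only; these names are assumptions, not taken from the PR.
∂a, ∂b = 2.0, 3.0                   # scalar partial derivatives
Δa, Δb = [1.0, 2.0], [3.0, 4.0]     # vector-mode perturbations

# `@.` dots every call in the expression, including both `*` calls.
with_macro = @. +(∂a * Δa, ∂b * Δb)

# The same computation with each broadcast written out explicitly.
explicit = ∂a .* Δa .+ ∂b .* Δb

println(with_macro == explicit)
```

Both forms compute the same elementwise result. The difference is only whether the dots are visible at each call site, which is the readability point being debated.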

Member

Suggested change
-:(@. +($(∂_mul_Δs...)))
+:(.+($(∂_mul_Δs...)))

Member Author

Oh, now I remember. Using explicit broadcast broke tests on older versions of Julia.

Member

It looks right to me when I run `macroexpand`, but indeed it doesn't work on 1.0 for me locally either.

Member

Suggested change
-:(@. +($(∂_mul_Δs...)))
+# we use `@.` to broadcast the `*` and `+`
+# Note: we don't want broadcasting to occur if there is only 1 multiply (no `+`),
+# because some arrays overload multiply with scalar, and those that don't fall back to broadcasting anyway.
+# Note also: due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals anyway
+:(@. +($(∂_mul_Δs...)))

+    else
+        # Note: we don't want to do broadcasting with only 1 multiply (no `+`),
+        # because some arrays overload multiply with scalar. Avoiding
+        # broadcasting saves compilation time.
+        ∂_mul_Δs[1]
+    end
+
+    return :(@muladd $sumed_∂_mul_Δs)
 end

 """
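
As a rough illustration of the expression the changed function builds, here is a simplified sketch. `sketch_propagation_expr` is a hypothetical stand-in, not the package's function: it omits `esc` and `@muladd` so that it runs without MuladdMacro, and the symbol names passed to it are assumptions chosen for the example.

```julia
# Hypothetical, simplified stand-in for `propagation_expr`:
# no `esc` and no `@muladd`, so it runs without extra packages.
function sketch_propagation_expr(Δs, ∂s)
    n = length(∂s)
    products = ntuple(i -> :($(∂s[i]) * $(Δs[i])), n)
    # Broadcast only when there is a sum to fuse; a lone product stays a plain `*`.
    return n > 1 ? :(@. +($(products...))) : products[1]
end

two_terms = sketch_propagation_expr((:Δ1, :Δ2), (:∂1, :∂2))
one_term = sketch_propagation_expr((:Δ1,), (:∂1,))

# Wrap each expression in an anonymous function so it can be run on sample inputs.
f2 = eval(:((Δ1, Δ2, ∂1, ∂2) -> $two_terms))
f1 = eval(:((Δ1, ∂1) -> $one_term))

# `invokelatest` avoids world-age issues when calling freshly `eval`ed functions.
sum_result = Base.invokelatest(f2, [1.0, 2.0], [3.0, 4.0], 2.0, 3.0)
single_result = Base.invokelatest(f1, [1.0, 2.0], 2.0)

println(two_terms)
println(one_term)
```

With more than one partial, the generated expression is a single `@.` call, so the multiplications and the sum are broadcast together. With exactly one partial, the result is an ordinary `*` call, which leaves room for array types that define their own scalar multiplication.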