-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Broadcast the propagation_expr
for vector mode AD
#93
Conversation
fe8d1b7
to
65b2ce2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will review tomorrow.
# avoiding the extra `+` operation, it is potentially | ||
# expensive for vector mode AD | ||
sumed_∂_mul_Δs = if n∂s > 1 | ||
:(@. +($(∂_mul_Δs...))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why use the macro?
.+($(...))
should be identical
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And it is more readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it less readable.
And this is kind of why -- it's not as obvious what is being broadcast.
E.g. This case where we are also broadcasting the stuff from the other line.
I'ld just make both explicit broadcasts.
--
I know the DiffEq code base uses this a ton.
I have never used it except during the transition from implict broadcast back in 0.5ish
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:(@. +($(∂_mul_Δs...))) | |
:(.+($(∂_mul_Δs...))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, now I remember. Using explicit broadcast broke tests on older versions of Julia.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks right to me when I do macroexpand
but indeed it doesn't work in 1.0 for me locally either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:(@. +($(∂_mul_Δs...))) | |
# we use `@.` to broadcast the . and + | |
# Note: we don't want to do broadcasting to occur if only 1 multiply (no +), | |
# because some arrays overload multiply with scalar, and those that don't fall back to broadcasting anyway. | |
# Note also: due to bugs in Julia 1.0 can't use `.+` or `.*` inside expression literals anyway | |
:(@. +($(∂_mul_Δs...))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This reverts commit b7d1da2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets leave a comment so we remember why we need @.
, then this can be merged.
See my suggestion
…ChainRulesCore (#93) * =Remove uses of add and mul from tests because they have been removed from ChainRulesCore * remove ref to Cassette. Slightly improve talk of ChainRulesCore * Update getting_started.md * Update Project.toml
I also added
@muladd
to optimizepropagation_expr
. It rewritesa*b + c*d
tomuladd(a, b, c*d)
.