From 0cfbf78229beec60e398cfccd8d241070bc98b3d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 19:21:09 -0500 Subject: [PATCH 1/3] Broadcast frule for `*` --- src/rulesets/Base/base.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index b36307b9d..e13a09c17 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -103,7 +103,7 @@ # product rule requires special care for arguments where `mul` is non-commutative function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy) - return x * y, Δx * y + x * Δy + return x * y, @. muladd(Δx, y, x * Δy) end function rrule(::typeof(*), x::Number, y::Number) From 9a0e8a1623c0fc2097f86d85df478bfca9f4d5e5 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 19:22:33 -0500 Subject: [PATCH 2/3] New release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index abaa65b2b..d905a5570 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.3" +version = "0.3.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From a23f0e97cbce822b5e6d3f7ddf3c8dbf8bd361ce Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 13 Jan 2020 13:46:17 -0500 Subject: [PATCH 3/3] Add comments and use explicit broadcast --- src/rulesets/Base/base.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index e13a09c17..0b5d90954 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -103,7 +103,11 @@ # product rule requires special care for arguments where `mul` is non-commutative function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy) - return x * y, @. muladd(Δx, y, x * Δy) + # Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more + # accurate on machines with FMA instructions, since there are only two + # rounding operations, one in `muladd/fma` and the other in `*`. + ∂xy = muladd.(Δx, y, x .* Δy) + return x * y, ∂xy end function rrule(::typeof(*), x::Number, y::Number)