From 08cf2084b0d4390ae6031a9f6425c88b85b7647a Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 18:56:42 -0500 Subject: [PATCH 1/7] Broadcast the `propagation_expr` for vector mode AD --- src/rule_definition_tools.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d631f6131..2a530cc0e 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -208,9 +208,19 @@ end function propagation_expr(Δs, ∂s) # This is basically Δs ⋅ ∂s ∂s = map(esc, ∂s) + n∂s = length(∂s) - ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s)) - return :(+($(∂_mul_Δs...))) + ∂_mul_Δs = ntuple(i->:($(∂s[i]) .* $(Δs[i])), n∂s) + + # avoiding the extra `+` operation, it is potentially + # expensive for vector mode AD + sumed_∂_mul_Δs = if n∂s > 1 + :(.+($(∂_mul_Δs...))) + else + ∂_mul_Δs + end + + return sumed_∂_mul_Δs end """ From 77447cfcb1f0ae9be685eb74646a1f3a3ce05395 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 19:01:27 -0500 Subject: [PATCH 2/7] Use the `muladd` macro to optimize `propagation_expr` --- Project.toml | 3 +++ src/rule_definition_tools.jl | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 18bb57a35..f8511736a 100644 --- a/Project.toml +++ b/Project.toml @@ -2,6 +2,9 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "0.5.1" +[deps] +MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" + [compat] julia = "^1.0" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 2a530cc0e..f68d77646 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,4 +1,5 @@ # These are some macros (and supporting functions) to make it easier to define rules. 
+using MuladdMacro: @muladd """ @scalar_rule(f(x₁, x₂, ...), @@ -220,7 +221,7 @@ function propagation_expr(Δs, ∂s) ∂_mul_Δs end - return sumed_∂_mul_Δs + return :(@muladd $sumed_∂_mul_Δs) end """ From 700b619b4e767a679a6d7d1e54e80a2c1164df6d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 19:07:08 -0500 Subject: [PATCH 3/7] New release --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f8511736a..600e1c9ed 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.5.1" +version = "0.5.2" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" [compat] julia = "^1.0" +MuladdMacro = "0.2.1" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 65b2ce27b789af574f6f1952d23b2174cd6bcf1c Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 19:10:01 -0500 Subject: [PATCH 4/7] Fix propagation_expr --- src/rule_definition_tools.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index f68d77646..ba80b42d3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -211,14 +211,14 @@ function propagation_expr(Δs, ∂s) ∂s = map(esc, ∂s) n∂s = length(∂s) - ∂_mul_Δs = ntuple(i->:($(∂s[i]) .* $(Δs[i])), n∂s) + ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) # avoiding the extra `+` operation, it is potentially # expensive for vector mode AD sumed_∂_mul_Δs = if n∂s > 1 - :(.+($(∂_mul_Δs...))) + :(@. 
+($(∂_mul_Δs...))) else - ∂_mul_Δs + ∂_mul_Δs[1] end return :(@muladd $sumed_∂_mul_Δs) From b7d1da2d18a1814284e5182381c6a5e699c79c84 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 13 Jan 2020 12:43:32 -0500 Subject: [PATCH 5/7] Explicit broadcast --- src/rule_definition_tools.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ba80b42d3..d4bccbcd4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -211,12 +211,12 @@ function propagation_expr(Δs, ∂s) ∂s = map(esc, ∂s) n∂s = length(∂s) - ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) + ∂_mul_Δs = ntuple(i->:($(∂s[i]) .* $(Δs[i])), n∂s) - # avoiding the extra `+` operation, it is potentially + # avoiding the extra `.+` operation, it is potentially # expensive for vector mode AD sumed_∂_mul_Δs = if n∂s > 1 - :(@. +($(∂_mul_Δs...))) + :(.+($(∂_mul_Δs...))) else ∂_mul_Δs[1] end From d2876cff6a7722048145ecad2739b47b3dc631ac Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 13 Jan 2020 13:04:31 -0500 Subject: [PATCH 6/7] Revert "Explicit broadcast" This reverts commit b7d1da2d18a1814284e5182381c6a5e699c79c84. --- src/rule_definition_tools.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d4bccbcd4..ba80b42d3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -211,12 +211,12 @@ function propagation_expr(Δs, ∂s) ∂s = map(esc, ∂s) n∂s = length(∂s) - ∂_mul_Δs = ntuple(i->:($(∂s[i]) .* $(Δs[i])), n∂s) + ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) - # avoiding the extra `.+` operation, it is potentially + # avoiding the extra `+` operation, it is potentially # expensive for vector mode AD sumed_∂_mul_Δs = if n∂s > 1 - :(.+($(∂_mul_Δs...))) + :(@. 
+($(∂_mul_Δs...))) else ∂_mul_Δs[1] end From 91b0be437b5234dbca02f823a97bf15cbd53215f Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 13 Jan 2020 13:29:47 -0500 Subject: [PATCH 7/7] Add comments about `at .` --- src/rule_definition_tools.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ba80b42d3..39f9d6673 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -211,13 +211,19 @@ function propagation_expr(Δs, ∂s) ∂s = map(esc, ∂s) n∂s = length(∂s) + # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression + # literals. ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) - # avoiding the extra `+` operation, it is potentially - # expensive for vector mode AD + # Skip the extra `+` operation when there is only one term; it is potentially + # expensive for vector mode AD. sumed_∂_mul_Δs = if n∂s > 1 + # We use `@.` to broadcast both `*` and `+`. :(@. +($(∂_mul_Δs...))) else + # Note: we don't want to broadcast when there is only 1 multiply (no `+`), + # because some array types overload multiplication by a scalar. Avoiding + # broadcasting also saves compilation time. ∂_mul_Δs[1] end