diff --git a/.travis.yml b/.travis.yml index 2028bd62f..8cc900185 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,10 +8,21 @@ julia: - 1.2 - 1.3 - nightly -matrix: +jobs: allow_failures: - - julia: 1.3 - - julia: nightly + - julia: 1.3 + - julia: nightly + include: + - stage: "Documentation" + julia: 1.0 + os: linux + script: + - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + - julia --project=docs/ docs/make.jl + after_success: skip + +after_success: + - julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())' notifications: email: recipients: @@ -19,10 +30,3 @@ notifications: on_success: never on_failure: always if: type = cron -# uncomment the following lines to override the default test script -#script: -# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi -# - julia -e 'Pkg.clone(pwd()); Pkg.build("ChainRulesCore"); Pkg.test("ChainRulesCore"; coverage=true)' -after_success: - # push coverage results to Coveralls - - julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())' diff --git a/README.md b/README.md index 63a84c257..fc284c7aa 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ [![Coveralls](https://coveralls.io/repos/github/JuliaDiff/ChainRulesCore.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaDiff/ChainRulesCore.jl?branch=master) [![PkgEval](https://juliaci.github.io/NanosoldierReports/pkgeval_badges/C/ChainRulesCore.svg)](https://juliaci.github.io/NanosoldierReports/pkgeval_badges/report.html) -[![Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaDiff.github.io/ChainRules.jl/latest) +**Docs:** +[![](https://img.shields.io/badge/docs-master-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/dev) +[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/stable) The ChainRulesCore package provides a light-weight 
dependency for defining sensitivities for functions in your packages, without you needing to depend on ChainRules itself. diff --git a/docs/Manifest.toml b/docs/Manifest.toml new file mode 100644 index 000000000..a40288d7e --- /dev/null +++ b/docs/Manifest.toml @@ -0,0 +1,133 @@ +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[ChainRules]] +deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"] +git-tree-sha1 = "906cb2ae273ddbc559490117faa7abd36c98f51a" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "0.3.2" + +[[ChainRulesCore]] +deps = ["MuladdMacro"] +git-tree-sha1 = "2d67fd76f99ffba4059e55be324b24bf38582a38" +repo-rev = "ox/movedoctsD2" +repo-url = ".." +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "0.6.1" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[Distributed]] +deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[DocStringExtensions]] +deps = ["LibGit2", "Markdown", "Pkg", "Test"] +git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.1" + +[[Documenter]] +deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "d45c163c7a3ae293c15361acc52882c0f853f97c" +uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +version = "0.23.4" + +[[InteractiveUtils]] +deps = ["LinearAlgebra", "Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.0" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = 
"56ddb016-857b-54e1-b83d-db4d58db5568" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MuladdMacro]] +git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" +uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +version = "0.2.2" + +[[Parsers]] +deps = ["Dates", "Test"] +git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "0.3.10" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "999513b7dea8ac17359ed50ae8ea089e4464e35e" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.0.0" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 000000000..7f3398d5e --- 
/dev/null +++ b/docs/Project.toml @@ -0,0 +1,7 @@ +[deps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" + +[compat] +Documenter = "~0.23" diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 000000000..8107eaf43 --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,40 @@ +using ChainRules +using ChainRulesCore +using Documenter + +@show ENV + +makedocs( + modules=[ChainRules, ChainRulesCore], + format=Documenter.HTML(prettyurls=false, assets = ["assets/chainrules.css"]), + sitename="ChainRules", + authors="Jarrett Revels and other contributors", + pages=[ + "Introduction" => "index.md", + "FAQ" => "FAQ.md", + "Writing Good Rules" => "writing_good_rules.md", + "API" => "api.md", + ], +) + +const repo = "github.com/JuliaDiff/ChainRulesCore.jl.git" +const PR = get(ENV, "TRAVIS_PULL_REQUEST", "false") +if PR == "false" + # Normal case, only deploy docs if merging to master or release tagged + deploydocs(repo=repo) +else + @info "Deploying review docs for PR #$PR" + # TODO: remove most of this once https://github.com/JuliaDocs/Documenter.jl/issues/1131 is resolved + + # Overwrite Documenter's function for generating the versions.js file + foreach(Base.delete_method, methods(Documenter.Writers.HTMLWriter.generate_version_file)) + Documenter.Writers.HTMLWriter.generate_version_file(_, _) = nothing + # Overwrite necessary environment variables to trick Documenter to deploy + ENV["TRAVIS_PULL_REQUEST"] = "false" + ENV["TRAVIS_BRANCH"] = "master" + + deploydocs( + devurl="preview-PR$(PR)", + repo=repo, + ) +end diff --git a/docs/src/FAQ.md b/docs/src/FAQ.md new file mode 100644 index 000000000..db4626a3a --- /dev/null +++ b/docs/src/FAQ.md @@ -0,0 +1,69 @@ +# FAQ + +## What is up with the different symbols? + +### `Δx`, `∂x`, `dx` +ChainRules uses these perhaps atypically. 
+As a notation that is the same across propagators, regardless of direction (in contrast see `ẋ` and `x̄` below). + + - `Δx` is the input to a propagator, (i.e. a _seed_ for a _pullback_; or a _perturbation_ for a _pushforward_) + - `∂x` is the output of a propagator + - `dx` could be either `input` or `output` + + +### dots and bars: ``\dot{y} = \dfrac{∂y}{∂x} = \overline{x}`` + - `v̇` is a derivative of the input moving forward: ``v̇ = \frac{∂v}{∂x}`` for input ``x``, intermediate value ``v``. + - `v̄` is a derivative of the output moving backward: ``v̄ = \frac{∂y}{∂v}`` for output ``y``, intermediate value ``v``. + +### others + - `Ω` is often used as the return value of the function. Especially, but not exclusively, for scalar functions. + - `ΔΩ` is thus a seed for the pullback. + - `∂Ω` is thus the output of a pushforward. + + +## Why does `rrule` return the primal function evaluation? +You might wonder why `frule(f, x)` returns `f(x)` and the derivative of `f` at `x`, and similarly for `rrule` returning `f(x)` and the pullback for `f` at `x`. +Why not just return the pushforward/pullback, and let the user call `f(x)` to get the answer separately? + +There are three reasons the rules also calculate the `f(x)`. +1. For some rules an alternative way of calculating `f(x)` can give the same answer while also generating intermediate values that can be used in the calculations required to propagate the derivative. +2. For many `rrule`s the output value is used in the definition of the pullback. For example `tan`, `sigmoid` etc. +3. For some `frule`s there exists a single, non-separable operation that will compute both derivative and primal result. For example many of the methods for [differential equation sensitivity analysis](https://docs.juliadiffeq.org/stable/analysis/sensitivity/#sensitivity-1). + +## Where are the derivatives for keyword arguments? 
+_pullbacks_ do not return a sensitivity for keyword arguments; +similarly _pushforwards_ do not accept a perturbation for keyword arguments. +This is because in practice functions are very rarely differentiable with respect to keyword arguments. +As a rule keyword arguments tend to control side-effects, like logging verbosity, +or to be functionality changing to perform a different operation, e.g. `dims=3`, and thus not differentiable. +To the best of our knowledge no Julia AD system, with support for the definition of custom primitives, supports differentiating with respect to keyword arguments. +At some point in the future ChainRules may support these. Maybe. + + +## What is the difference between `Zero` and `DoesNotExist` ? +`Zero` and `DoesNotExist` act almost exactly the same in practice: they result in no change whenever added to anything. +Odds are if you write a rule that returns the wrong one everything will just work fine. +We provide both to allow for clearer writing of rules, and easier debugging. + +`Zero()` represents the fact that if one perturbs (adds a small change to) the matching primal there will be no change in the behaviour of the primal function. +For example in `fst(x,y) = x`, then the derivative of `fst` with respect to `y` is `Zero()`. +`fst(10, 5) == 10` and if we add `0.1` to `5` we still get `fst(10, 5.1)=10`. + +`DoesNotExist()` represents the fact that if one perturbs the matching primal, the primal function will now error. +For example in `access(xs, n) = xs[n]` then the derivative of `access` with respect to `n` is `DoesNotExist()`. +`access([10, 20, 30], 2) = 20`, but if we add `0.1` to `2` we get `access([10, 20, 30], 2.1)` which errors as indexing can't be applied at fractional indexes. + + +## When to use ChainRules vs ChainRulesCore? 
+ +[ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) is a light-weight dependency for defining rules for functions in your packages, without you needing to depend on ChainRules itself. It has no dependencies of its own. + +[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides the full functionality, in particular it has all the rules for Base Julia and the standard libraries. Its thus a much heavier package to load. + +If you only want to define rules, not use them then you probably only want to load ChainRulesCore. +AD systems making use of ChainRules should load ChainRules (rather than ChainRulesCore). + +## Where should I put my rules? +In general, we recommend adding custom sensitivities to your own packages with ChainRulesCore, rather than adding them to ChainRules.jl. + +A few packages currently SpecialFunctions.jl and NaNMath.jl are in ChainRules.jl but this is a short-term measure. diff --git a/docs/src/api.md b/docs/src/api.md new file mode 100644 index 000000000..8aaf4e044 --- /dev/null +++ b/docs/src/api.md @@ -0,0 +1,5 @@ +# API Documentation + +```@autodocs +Modules = [ChainRulesCore] +``` diff --git a/docs/src/assets/chainrules.css b/docs/src/assets/chainrules.css new file mode 100644 index 000000000..f4b0a164f --- /dev/null +++ b/docs/src/assets/chainrules.css @@ -0,0 +1,78 @@ +/* Links */ + +a { + color: #4595D1; +} + +a:hover, a:focus { + color: #194E82; +} + +/* Navigation */ + +nav.toc ul a:hover, +nav.toc ul.internal a:hover { + color: #FFFFFF; + background-color: #4595D1; +} + +nav.toc ul .toctext { + color: #FFFFFF; +} + +nav.toc { + box-shadow: none; + color: #FFFFFF; + background-color: #194E82; +} + +nav.toc li.current > .toctext { + color: #FFFFFF; + background-color: #4595D1; + border-top-width: 0px; + border-bottom-width: 0px; +} + +nav.toc ul.internal a { + color: #194E82; + background-color: #FFFFFF; +} + +/* Text */ + +article#docs a.nav-anchor { + color: #194E82; +} + +article#docs blockquote { + 
font-style: italic; +} + +/* Terminology Block */ + +div.admonition.terminology div.admonition-title:before { + content: "Terminology: "; + font-family: inherit; + font-weight: bold; +} +div.admonition.terminology div.admonition-title { + background-color: #FFEC8B; +} + +div.admonition.terminology div.admonition-text { + background-color: #FFFEDD; +} + +/* Code */ + +code .hljs-meta { + color: #4595D1; +} + +code .hljs-keyword { + color: #194E82; +} + +pre, code { + font-family: "Liberation Mono", "Consolas", "DejaVu Sans Mono", "Ubuntu Mono", "andale mono", "lucida console", monospace; +} diff --git a/docs/src/assets/logo.svg b/docs/src/assets/logo.svg new file mode 100644 index 000000000..c2c58f1a2 --- /dev/null +++ b/docs/src/assets/logo.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl new file mode 100644 index 000000000..63077601d --- /dev/null +++ b/docs/src/assets/make_logo.jl @@ -0,0 +1,82 @@ +using Pkg: @pkg_str +# For reproducability only dependency this has is Luxor, +# and it was created with Luxor v1.6.0 +pkg"add Luxor@v1.6" + +using Luxor +using Random + +const bridge_len = 50 + +function chain(jiggle=0) + shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) + + ### 1 + shaky_rotate(0) + sethue(Luxor.julia_red) + link() + m1 = getmatrix() + + + ### 2 + sethue(Luxor.julia_green) + translate(-50, 130); + shaky_rotate(π/3); + link() + m2 = getmatrix() + + setmatrix(m1) + sethue(Luxor.julia_red) + overlap(-1.3π) + setmatrix(m2) + + ### 3 + shaky_rotate(-π/3); + translate(-120,80); + sethue(Luxor.julia_purple) + link() + + setmatrix(m2) + setcolor(Luxor.julia_green) + overlap(-1.5π) +end + + +function link() + sector(50, 90, π, 0, :fill) + sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) + + + rect(50,-3,40, bridge_len+6, :fill) + rect(-50-40,-3,40, bridge_len+6, :fill) + + sethue("black") + move(Point(-50, bridge_len)) + arc(Point(0,0), 50, π, 0, :stoke) + 
arc(Point(0, bridge_len), 50, 0, -π, :stroke) + + move(Point(-90, bridge_len)) + arc(Point(0,0), 90, π, 0, :stroke) + arc(Point(0, bridge_len), 90, 0, -π, :stroke) + strokepath() +end + +function overlap(ang_end) + sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) + sethue("black") + arc(Point(0, bridge_len), 50, 0, ang_end, :stroke) + move(Point(90, bridge_len)) + arc(Point(0, bridge_len), 90, 0, ang_end, :stroke) + + strokepath() +end + +# Actually draw it + +Random.seed!(16) +Drawing(450,450, "logo.svg") +origin() +translate(50, -130); +chain(0.5) +finish() +preview() diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 000000000..6ef034755 --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,347 @@ +```@meta +DocTestSetup = :(using ChainRulesCore, ChainRules) +``` + +# ChainRules + +[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) provides a variety of common utilities that can be used by downstream [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools to define and execute forward-, reverse-, and mixed-mode primitives. + +## Introduction + +ChainRules is all about providing a rich set of rules for differentiation. +When a person learns introductory calculus, they learn that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc. +And they learn how to combine simple rules, via [the chain rule](https://en.wikipedia.org/wiki/Chain_rule), to differentiate complicated functions. +ChainRules is a programmatic repository of that knowledge, with the generalizations to higher dimensions. + +[Autodiff (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools roughly work by reducing a problem down to simple parts that they know the rules for, and then combining those rules. +Knowing rules for more complicated functions speeds up the autodiff process as it doesn't have to break things down as much. 
+ +**ChainRules is an AD-independent collection of rules to use in a differentiation system.** + + +!!! note "The whole field is a mess for terminology" + It isn't just ChainRules, it is everyone. + Internally ChainRules tries to be consistent. + Help with that is always welcomed. + +!!! terminology "Primal" + Often we will talk about something as _primal_. + That means it is related to the original problem, not its derivative. + For example for `y = foo(x)` + `foo` is the _primal_ function, + computing `foo(x)` is doing the _primal_ computation. + `y` is the _primal_ return, and `x` is a _primal_ argument. + `typeof(y)` and `typeof(x)` are both _primal_ types. + + +## `frule` and `rrule` + +!!! terminology "`frule` and `rrule`" + `frule` and `rrule` are ChainRules specific terms. + Their exact functioning is fairly ChainRules specific, though other tools have similar functions. + The core notion is sometimes called _custom AD primitives_, _custom adjoints_, _custom gradients_, _custom sensitivities_. + +The rules are encoded as `frule`s and `rrule`s, for use in forward-mode and reverse-mode differentiation respectively. + +The `rrule` for some function `foo`, which takes the positional arguments `args` and keyword arguments `kwargs`, is written: + +```julia +function rrule(::typeof(foo), args...; kwargs...) + ... + return y, pullback +end +``` +where `y` (the primal result) must be equal to `foo(args...; kwargs...)`. +`pullback` is a function to propagate the derivative information backwards at that point. +That pullback function is used like: +`∂self, ∂args... = pullback(Δy)` + + +Almost always the _pullback_ will be declared locally within the `rrule`, and will be a _closure_ over some of the other arguments, and potentially over the primal result too. + +The `frule` is written: +```julia +function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...) + ... 
+ return y, ∂Y +end +``` +where again `y = foo(args...; kwargs...)`, +and `∂Y` is the result of propagating the derivative information forwards at that point. +This propagation is called the pushforward. +One could think of writing `∂Y = pushforward(Δself, Δargs)`, and often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)`, and the push-forward `∂Y = pushforward(Δself, Δargs...)`. + + +!!! note "Why `rrule` returns a pullback but `frule` doesn't return a pushforward" + While `rrule` takes only the arguments to the original function (the primal arguments) and returns a function (the pullback) that operates with the derivative information, the `frule` does it all at once. + This is because the `frule` fuses the primal computation and the pushforward. + This is an optimization that allows `frule`s to contain single large operations that perform both the primal computation and the pushforward at the same time (for example solving an ODE). + This operation is only possible in forward mode (where `frule` is used) because the derivative information needed by the pushforward is available when the `frule` is invoked -- it is about the primal function's inputs. + In contrast, in reverse mode the derivative information needed by the pullback is about the primal function's output. + Thus the reverse mode returns the pullback function which the caller (usually an AD system) keeps hold of until derivative information about the output is available. + + +## The propagators: pushforward and pullback + + +!!! terminology "pushforward and pullback" + + _Pushforward_ and _pullback_ are fancy words that the autodiff community recently adopted from Differential Geometry. + They are broadly in agreement with the use of [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) and [pushforward](https://en.wikipedia.org/wiki/Pushforward_(differential)) in differential geometry. 
+ But any geometer will tell you these are the super-boring flat cases. Some will also frown at you. + They are also sometimes described in terms of the jacobian: + The _pushforward_ is _jacobian vector product_ (`jvp`), and _pullback_ is _jacobian transpose vector product_ (`j'vp`). + Other terms that may be used include for _pullback_ the **backpropagator**, and by analogy for _pushforward_ the **forwardpropagator**, thus these are the _propagators_. + These are also good names because effectively they propagate wiggles and wobbles through them, via the chain rule. + (the term **backpropagator** may originate with ["Lambda The Ultimate Backpropagator"](http://www-bcl.cs.may.ie/~barak/papers/toplas-reverse.pdf) by Pearlmutter and Siskind, 2008) + +### Core Idea + +#### Less formally + + - The **pushforward** takes a wiggle in the _input space_, and tells what wobble you would create in the output space, by passing it through the function. + - The **pullback** takes wobbliness information with respect to the function's output, and tells the equivalent wobbliness with respect to the functions input. + +#### More formally +The **pushforward** of ``f`` takes the _sensitivity_ of the input of ``f`` to a quantity, and gives the _sensitivity_ of the output of ``f`` to that quantity +The **pullback** of ``f`` takes the _sensitivity_ of a quantity to the output of ``f``, and gives the _sensitivity_ of that quantity to the input of ``f``. + +### Math +This is all a bit simplified by talking in 1D. + +#### Lighter Math +For a chain of expressions: +``` +a = f(x) +b = g(a) +c = h(b) +``` + +The pullback of `g`, which incorporates the knowledge of `∂b/∂a`, +applies the chain rule to go from `∂c/∂b` to `∂c/∂a`. + +The pushforward of `g`, which also incorporates the knowledge of `∂b/∂a`, +applies the chain rule to go from `∂a/∂x` to `∂b/∂x`. 
+ +### Heavier Math +If I have some functions: ``g(a)``, ``h(b)`` and ``f(x)=g(h(x))``, and I know +the pullback of ``g``, at ``h(x)`` written: ``\mathrm{pullback}_{g(a)|a=h(x)}``, +and I know the derivative of ``h`` with respect to its input ``b`` at ``g(x)``, +written: ``\left.\dfrac{∂h}{∂b}\right|_{b=g(x)}`` Then I can use the pullback to +find: ``\dfrac{∂f}{∂x}``: + +``\dfrac{∂f}{∂x}=\mathrm{\mathrm{pullback}_{g(a)|a=h(x)}}\left(\left.\dfrac{∂h}{∂b}\right|_{b=g(x)}\right).`` + +If I know the derivative of ``g`` with respect to its input a at ``x``, written: +``\left.\dfrac{∂g}{∂a}\right|_{a=x}``, and I know the pushforward of ``h`` at +``g(x)`` written: ``\mathrm{pushforward}_{h(b)|b=g(x)}``. Then I can use the +pushforward to find ``\dfrac{∂f}{∂x}``: + +``\dfrac{∂f}{∂x}=\mathrm{pushforward}_{h(b)|b=g(x)}\left(\left.\dfrac{∂g}{∂a}\right|_{a=x}\right)`` + + +### The anatomy of pullback and pushforward + +For our function `foo(args...; kwargs...) = y`: + + +```julia +function pullback(Δy) + ... + return ∂self, ∂args... +end +``` + +The input to the pullback is often called the _seed_. +If the function is `y = f(x)` often the pullback will be written `s̄elf, x̄ = pullback(ȳ)`. + +!!! note + + The pullback returns one `∂arg` per `arg` to the original function, plus one `∂self` for the fields of the function itself (explained below). + +!!! terminology "perturbation, seed, sensitivity" + Sometimes _perturbation_, _seed_, and even _sensitivity_ will be used interchangeably. + They are not generally synonymous, and ChainRules shouldn't mix them up. + One must be careful when reading literature. + At the end of the day, they are all _wiggles_ or _wobbles_. + + +The pushforward is a part of the `frule` function. +Considered alone it would look like: + +```julia +function pushforward(Δself, Δargs...) + ... + return ∂y +end +``` +But because it is fused into frule we see it as part of: +```julia +function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...) + ... 
+ return y, ∂y +end +``` + + +The input to the pushforward is often called the _perturbation_. +If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule(f, x, ṡelf, ẋ))`. +`ẏ` is commonly used to represent the perturbation for `y`. + +!!! note + + In the `frule`/pushforward, + there is one `Δarg` per `arg` to the original function. + The `Δargs` are similar in type/structure to the corresponding inputs `args` (`Δself` is explained below). + The `∂y` are similar in type/structure to the original function's output `Y`. + In particular if that function returned a tuple then `∂y` will be a tuple of the same size. + +### Self derivative `Δself`, `∂self`, `s̄elf`, `ṡelf` etc. + +!!! terminology "Δself, ∂self, s̄elf, ṡelf" + It is the derivatives with respect to the internal fields of the function. + To the best of our knowledge there is no standard terminology for this. + Other good names might be `Δinternal`/`∂internal`. + +From the mathematical perspective, one may have been wondering what all this `Δself`, `∂self` is. +Given that a function with two inputs, say `f(a, b)`, only has two partial derivatives: +``\dfrac{∂f}{∂a}``, ``\dfrac{∂f}{∂b}``. +Why then does a `pushforward` take in this extra `Δself`, and why does a `pullback` return this extra `∂self`? + +The reason is that in Julia the function `f` may itself have internal fields. +For example a closure has the fields it closes over; a callable object (i.e. a functor) like a `Flux.Dense` has the fields of that object. + +**Thus every function is treated as having the extra implicit argument `self`, which captures those fields.** +So every `pushforward` takes in an extra argument, which is ignored unless the original function has fields. +It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. 
+Similarly every `pullback` returns an extra `∂self`, which for things without fields is the constant `NO_FIELDS`, indicating there are no fields within the function itself. + + +### Pushforward / Pullback summary + +- **Pullback** + - returned by `rrule` + - takes output space wobbles, gives input space wiggles + - 1 argument per original function return + - 1 return per original function argument + 1 for the function itself + +- **Pushforward:** + - part of `frule` + - takes input space wiggles, gives output space wobbles + - 1 argument per original function argument + 1 for the function itself + - 1 return per original function return + + +### Pullback/Pushforward and Directional Derivative/Gradient + +The most trivial use of the `pushforward` from within `frule` is to calculate the directional derivative: + +If we would like to know the directional derivative of `f` for an input change of `(1.5, 0.4, -1)` + +```julia +direction = (1.5, 0.4, -1) # (ȧ, ḃ, ċ) +y, ẏ = frule(f, a, b, c, Zero(), direction) +``` + +On the basis directions one gets the partial derivatives of `y`: +```julia +y, ∂y_∂a = frule(f, a, b, c, Zero(), 1, 0, 0) +y, ∂y_∂b = frule(f, a, b, c, Zero(), 0, 1, 0) +y, ∂y_∂c = frule(f, a, b, c, Zero(), 0, 0, 1) +``` + +Similarly, the most trivial use of `rrule` and returned `pullback` is to calculate the [Gradient](https://en.wikipedia.org/wiki/Gradient): + +```julia +y, f_pullback = rrule(f, a, b, c) +∇f = f_pullback(1) # for appropriate `1`-like seed. +s̄elf, ā, b̄, c̄ = ∇f +``` +Then we have that `∇f` is the _gradient_ of `f` at `(a, b, c)`. +And we thus have the partial derivatives ``\overline{\mathrm{self}} = \dfrac{∂f}{∂\mathrm{self}}``, ``\overline{a} = \dfrac{∂f}{∂a}``, ``\overline{b} = \dfrac{∂f}{∂b}``, ``\overline{c} = \dfrac{∂f}{∂c}``, including the self-partial derivative, ``\overline{\mathrm{self}}``. 
+ +## Differentials + +The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function. +They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types. +A differential might be a regular type, like a `Number`, or a `Matrix`, matching to the original type; +or it might be one of the `AbstractDifferential` subtypes. + +Differentials support a number of operations. +Most importantly: `+` and `*`, which let them act as mathematical objects. + +The most important `AbstractDifferential`s when getting started are the ones about avoiding work: + + - `Thunk`: this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until `unthunk` is called on the thunk. `unthunk` is a no-op on non-thunked inputs. + - `One`, `Zero`: These are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition. + +### Other `AbstractDifferential`s: + - `Composite{P}`: this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type. + - `DoesNotExist`: Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`. + - `InplaceableThunk`: it is like a `Thunk` but it can do in-place `add!`. + + ------------------------------- + +## Example of using ChainRules directly. + +While ChainRules is largely intended as a backend for autodiff systems, it can be used directly. +In fact, this can be very useful if you can constrain the code you need to differentiate to only use things that have rules defined for them. +This was once how all neural network code worked. + +Using ChainRules directly also helps get a feel for it. 
+ +```julia +using ChainRules + +function foo(x) + a = sin(x) + b = 2a + c = asin(b) + return c +end + +#### Find dfoo/dx via rrules + +# First the forward pass, accumulating rules +x = 3; +a, a_pullback = rrule(sin, x); +b, b_pullback = rrule(*, 2, a); +c, c_pullback = rrule(asin, b) + +# Then the backward pass calculating gradients +c̄ = 1; # ∂c/∂c +_, b̄ = c_pullback(extern(c̄)); # ∂c/∂b +_, _, ā = b_pullback(extern(b̄)); # ∂c/∂a +_, x̄ = a_pullback(extern(ā)); # ∂c/∂x = ∂f/∂x +extern(x̄) +# -2.0638950738662625 + +#### Find dfoo/dx via frules + +x = 3; +ẋ = 1; # ∂x/∂x +nofields = Zero(); # ∂self/∂self + +a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x +b, ḃ = frule(*, 2, nofields, unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x + +c, ċ = frule(asin, b, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x +unthunk(ċ) +# -2.0638950738662625 + +#### Find dfoo/dx via finite-differences + +using FiniteDifferences +central_fdm(5, 1)(foo, x) +# -2.0638950738670734 + +#### Find dfoo/dx via ForwardDiff.jl +using ForwardDiff +ForwardDiff.derivative(foo, x) +# -2.0638950738662625 + +#### Find dfoo/dx via Zygote.jl +using Zygote +Zygote.gradient(foo, x) +# (-2.0638950738662625,) +``` diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md new file mode 100644 index 000000000..f5738d183 --- /dev/null +++ b/docs/src/writing_good_rules.md @@ -0,0 +1,79 @@ +# On writing good `rrule` / `frule` methods + +## Use `Zero()` or `One()` as return value + +The `Zero()` and `One()` differential objects exist as an alternative to directly returning +`0` or `zeros(n)`, and `1` or `I`. +They allow more optimal computation when chaining pullbacks/pushforwards, to avoid work. +They should be used where possible. + +## Use `Thunk`s appropriately + +If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block). + +If there are multiple return values, their computation should almost always be wrapped in a `@thunk`. 
+ +Do _not_ wrap _variables_ in a `@thunk`; wrap the _computations_ that fill those variables in `@thunk`: + +```julia +# good: +∂A = @thunk(foo(x)) +return ∂A + +# bad: +∂A = foo(x) +return @thunk(∂A) +``` +In the bad example `foo(x)` gets computed eagerly, and all that the thunk is doing is wrapping the already calculated result in a function that returns it. + +## Be careful with using `adjoint` when you mean `transpose` + +Remember for complex numbers `a'` (i.e. `adjoint(a)`) takes the complex conjugate. +Instead you probably want `transpose(a)`, unless you've already restricted `a` to be an `AbstractMatrix{<:Real}`. + +## Code Style + +Use named local functions for the `pushforward`/`pullback`: + +```julia +# good: +function frule(::typeof(foo), x) + Y = foo(x) + function foo_pushforward(_, ẋ) + return bar(ẋ) + end + return Y, foo_pushforward +end +#== output +julia> frule(foo, 2) +(4, var"#foo_pushforward#11"()) +==# + +# bad: +function frule(::typeof(foo), x) + return foo(x), (_, ẋ) -> bar(ẋ) +end +#== output: +julia> frule(foo, 2) +(4, var"##9#10"()) +==# +``` + +While this is more verbose, it ensures that if an error is thrown during the `pullback`/`pushforward` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it. +This makes it a lot simpler to debug from the stacktrace. + +## Write tests + +There are fairly decent tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl). +They are in [`test/test_util.jl`](https://github.com/JuliaDiff/ChainRules.jl/blob/master/test/test_util.jl). +Take a look at existing tests and you should see how to do stuff. + +!!! warning + Use finite differencing to test derivatives. + Don't use analytical derivations for derivatives in the tests. + Those are what you use to define the rules, and so can not be confidently used in the test. 
+ If you misread/misunderstood them, then your tests/implementation will have the same mistake. + +## CAS systems are your friends. + +It is very easy to check gradients or derivatives with a computer algebra system (CAS) like [WolframAlpha](https://www.wolframalpha.com/input/?i=gradient+atan2%28x%2Cy%29).