diff --git a/Manifest.toml b/Manifest.toml index fc5edf9865..91ed508ac3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,3 +1,5 @@ +# This file is machine-generated - editing it directly is not advised + [[AbstractTrees]] deps = ["Markdown", "Test"] git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" @@ -27,9 +29,9 @@ version = "0.5.3" [[CodecZlib]] deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] -git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" +git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.5.1" +version = "0.5.2" [[ColorTypes]] deps = ["FixedPointNumbers", "Random", "Test"] @@ -51,9 +53,9 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a" +git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "1.5.1" +version = "2.0.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] @@ -71,18 +73,18 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" [[DiffResults]] deps = ["Compat", "StaticArrays"] -git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7" +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "0.0.3" +version = "0.0.4" [[DiffRules]] deps = ["Random", "Test"] -git-tree-sha1 = "09d69da75967ec48a8b1ad0897ec9144ee052bf9" +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "0.0.8" +version = "0.0.10" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[FixedPointNumbers]] @@ -93,12 +95,12 @@ version = "0.5.3" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] -git-tree-sha1 = "e393bd3b9102659fb24fe88caedec41f2bc2e7de" +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.2" +version = "0.10.3" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[Juno]] @@ -122,9 +124,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[MacroTools]] deps = ["Compat"] -git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b" +git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.4.4" +version = "0.4.5" [[Markdown]] deps = ["Base64"] @@ -147,7 +149,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "5a8ed87d61b1ccb71d99235c2a96287addebbb9f" +git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd" repo-rev = "master" repo-url = "https://github.com/FluxML/NNlib.jl.git" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -228,9 +230,9 @@ version = "0.7.2" [[StaticArrays]] deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] -git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898" +git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.10.2" +version = "0.10.3" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -238,19 +240,25 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598" +git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.27.0" +version = "0.29.0" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] +git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.1.0" + [[TranscodingStreams]] deps = ["Pkg", "Random", "Test"] -git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" +git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.8.1" +version = "0.9.1" [[URIParser]] deps = ["Test", "Unicode"] @@ -259,7 +267,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random"] +deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] diff --git a/Project.toml b/Project.toml index 331d683991..ebb2670108 100644 --- a/Project.toml +++ b/Project.toml @@ -6,21 +6,18 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" -DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" diff --git a/REQUIRE b/REQUIRE index edfe56bb57..455a1c1578 100644 --- a/REQUIRE +++ b/REQUIRE @@ -10,9 +10,3 @@ ZipFile AbstractTrees Reexport StatsBase - -# AD -ForwardDiff 0.5.0 -DiffRules -SpecialFunctions -NaNMath diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 0bb294e1d8..6a5c6ca221 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,3 +1,5 @@ +# This file is machine-generated - editing it directly is not advised + [[AbstractTrees]] deps = ["Markdown", "Test"] git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" @@ -6,9 +8,9 @@ version = "0.2.1" [[Adapt]] deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06" +git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "0.4.1" +version = "0.4.2" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -27,9 +29,9 @@ version = "0.5.3" [[CodecZlib]] deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] -git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" +git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.5.1" +version = "0.5.2" [[ColorTypes]] deps = ["FixedPointNumbers", "Random", "Test"] @@ -51,9 +53,9 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "ec61a16eed883ad0cfa002d7489b3ce6d039bb9a" +git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "1.4.0" +version = "2.0.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] @@ -71,18 +73,18 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" [[DiffResults]] deps = ["Compat", "StaticArrays"] -git-tree-sha1 = "db8acf46717b13d6c48deb7a12007c7f85a70cf7" +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "0.0.3" +version = "0.0.4" [[DiffRules]] deps = ["Random", "Test"] -git-tree-sha1 = "c49ec69428ffea0c1d1bbdc63d1a70f5df5860ad" +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "0.0.7" +version = "0.0.10" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] @@ -93,9 +95,9 @@ version = "0.6.0" [[Documenter]] deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "LibGit2", "Logging", "Markdown", "Pkg", "REPL", "Random", "Test", "Unicode"] -git-tree-sha1 = "a6db1c69925cdc53aafb38caec4446be26e0c617" +git-tree-sha1 = "a8c41ba3d0861240dbec942ee1d0f86c57c37c1c" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.21.0" +version = "0.21.5" [[FixedPointNumbers]] deps = ["Test"] @@ -104,26 +106,26 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.5.3" [[Flux]] -deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Test", "Tracker", "ZipFile"] path = ".." uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.6.10+" +version = "0.7.3+" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] -git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e" +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.1" +version = "0.10.3" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[Juno]] deps = ["Base64", "Logging", "Media", "Profile", "Test"] -git-tree-sha1 = "3c29a199713e7ec62cfdc11f44d7760219d5f658" +git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.5.3" +version = "0.5.4" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -140,9 +142,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[MacroTools]] deps = ["Compat"] -git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b" +git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.4.4" +version = "0.4.5" [[Markdown]] deps = ["Base64"] @@ -156,9 +158,9 @@ version = "0.5.0" [[Missings]] deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] -git-tree-sha1 = "adc26d2ee85a49c413464110d922cf21efc9d233" +git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.3.1" +version = "0.4.0" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -244,9 +246,9 @@ version = "0.7.2" [[StaticArrays]] deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] -git-tree-sha1 = "1eb114d6e23a817cd3e99abc3226190876d7c898" +git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.10.2" +version = "0.10.3" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -254,19 +256,25 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "7b596062316c7d846b67bf625d5963a832528598" +git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.27.0" +version = "0.29.0" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] +git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.1.0" + [[TranscodingStreams]] deps = ["Pkg", "Random", "Test"] -git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" +git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.8.1" +version = "0.9.1" [[URIParser]] deps = ["Test", "Unicode"] @@ -275,7 +283,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random"] +deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] diff --git a/src/Flux.jl b/src/Flux.jl index c806716d94..3a862fb5d8 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -12,9 +12,8 @@ export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, @reexport using NNlib -include("tracker/Tracker.jl") -using .Tracker -using .Tracker: data +using Tracker +using Tracker: data export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl deleted file mode 100644 index adceea6140..0000000000 --- a/src/tracker/Tracker.jl +++ /dev/null @@ -1,114 +0,0 @@ -module Tracker - -using MacroTools -using MacroTools: @q, @forward - -import Base: == - -export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient, - jacobian, hessian, param, back! - -tracker(x) = nothing - -istracked(x) = tracker(x) ≠ nothing -isleaf(x) = !istracked(x) || isleaf(tracker(x)) -grad(x) = grad(tracker(x)) -grad(::Nothing) = nothing -data(x) = x - -struct Call{F,As<:Tuple} - func::F - args::As -end - -Call(f::F, args::T) where {F,T} = Call{F,T}(f, args) -Call() = Call(nothing, ()) - -# When deserialising, the object_id changes -a::Call == b::Call = a.func == b.func && a.args == b.args - -@inline (c::Call)() = c.func(data.(c.args)...) - -mutable struct Tracked{T} - ref::UInt32 - f::Call - isleaf::Bool - grad::T - Tracked{T}(f::Call) where T = new(0, f, false) - Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad) - Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad) -end - -istracked(x::Tracked) = true -isleaf(x::Tracked) = x.f == Call() -grad(x::Tracked) = x.grad - -track(f::Call, x) = Tracked{typeof(x)}(f) - -function _forward end - -function track(f::F, xs...; kw...) where F - y, back = _forward(f, xs...; kw...) - track(Call(back, tracker.(xs)), y) -end - -macro grad(ex) - @capture(shortdef(ex), (name_(args__) = body_) | - (name_(args__) where {T__} = body_)) || error("Need a function definition") - T == nothing && (T = []) - isexpr(name, :(::)) || (name = :(::typeof($name))) - insert!(args, 1+isexpr(args[1], :parameters) , name) - @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc -end - -include("idset.jl") -include("params.jl") -include("back.jl") -include("numeric.jl") -include("lib/real.jl") -include("lib/array.jl") -include("forward.jl") - -""" - hook(f, x) -> x′ - -Hook into gradient backpropagation. `x` is unmodified, but when backpropagating -`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse -the sign of the gradient applied to `x`. -""" -hook(f, x) = istracked(x) ? track(hook, f, x) : x -@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ)) - -""" - checkpoint(f, args...) - -Behaves like `f(args...)`, but avoids storing the intermediate values needed for -calculating gradients. Instead, `f(args...)` will be called again during the -backward pass. This can be used to save memory in larger models. -""" -checkpoint(f, args...) = track(checkpoint, f, args...) - -@grad function checkpoint(f, args...) - data(f(args...)), function (Δ) - y, back = forward(f, args...) - (nothing, back(Δ)...) - end -end - -nobacksies(f, x) = track(nobacksies, f, x) -nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) -@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f") -@grad nobacksies(f::String, x) = data(x), Δ -> error(f) - -param(x::Number) = TrackedReal(float(x)) -param(xs::AbstractArray) = TrackedArray(float.(xs)) - -@grad identity(x) = data(x), Δ -> (Δ,) -param(x::TrackedReal) = track(identity, x) -param(x::TrackedArray) = track(identity, x) - -import Adapt: adapt, adapt_structure - -adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs))) - -end diff --git a/src/tracker/back.jl b/src/tracker/back.jl deleted file mode 100644 index 2825a92ca5..0000000000 --- a/src/tracker/back.jl +++ /dev/null @@ -1,190 +0,0 @@ -# The AD generates fairly large backtraces that are unhelpful if you interrupt -# while training; this just cleans that up. -macro interrupts(ex) - :(try $(esc(ex)) - catch e - e isa InterruptException || rethrow() - throw(e) - end) -end - -# In-place gradients - -init_grad(x) = zero(x) -zero_grad!(x) = zero(x) -zero_grad!(x::AbstractArray) = (x .= 0) - -scan(c::Call) = foreach(scan, c.args) - -function scan(x::Tracked) - x.isleaf && return - ref = x.ref += 1 - if ref == 1 - scan(x.f) - isdefined(x, :grad) && (x.grad = zero_grad!(x.grad)) - end - return -end - -function scan(x) - istracked(x) && scan(tracker(x)) - return -end - -function back_(c::Call, Δ, once) - Δs = c.func(Δ) - (Δs isa Tuple && length(Δs) >= length(c.args)) || - error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, d) -> back(x, d, once), c.args, data.(Δs)) -end - -back_(::Call{Nothing}, Δ, once) = nothing -back_(::Call{Missing}, Δ, once) = error("`back!` was already used") - -accum!(x, Δ) = x .+ Δ -accum!(x::AbstractArray, Δ) = (x .+= Δ) - -function back(x::Tracked, Δ, once) - x.isleaf && (x.grad = accum!(x.grad, Δ); return) - ref = x.ref -= 1 - grad = if isdefined(x, :grad) - x.grad = accum!(x.grad, Δ) - elseif ref > 0 - x.grad = Δ - else - Δ - end - if ref == 0 - back_(x.f, grad, once) - once && !x.isleaf && (x.f = Call(missing, ())) - end - return -end - -back(::Nothing, Δ, once) = return - -# Interface methods - -# TODO: if an error occurs in `back` the refcounts will be broken -# and `back` will silently fail to update. -# (but only if you re-use intermediate values between passes) -# Refcounts are also probably not safe in some situations (e.g. back called -# from within a backpropagator) - -function back!(x, Δ; once = true) - istracked(x) || return - scan(x) - back(tracker(x), Δ, once) - return -end - -function extract_grad!(x) - x̄ = copy(grad(x)) - x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄) - tracker(x).grad = zero_grad!(grad(x)) - return x̄ -end - -function gradient_(f, xs...) - xs = param.(data.(xs)) - l = f(xs...) - losscheck(l) - @interrupts back!(l) - extract_grad!.(xs) -end - -function gradient_(f, xs::Params) - l = f() - losscheck(l) - @interrupts back!(l) - gs = Grads() - for x in xs - gs[tracker(x)] = extract_grad!(x) - end - return gs -end - -# Out-of-place gradients - -function back_(g::Grads, c::Call, Δ) - Δs = c.func(Δ) - (Δs isa Tuple && length(Δs) >= length(c.args)) || - error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) -end - -back_(g::Grads, ::Call{Nothing}, Δ) = nothing - -function back(g::Grads, x::Tracked, Δ) - x.isleaf && (accum!(g, x, Δ); return) - ref = x.ref -= 1 - if ref > 0 || haskey(g, x) - accum!(g, x, Δ) - ref == 0 && back_(g, x.f, g[x]) - else - ref == 0 && back_(g, x.f, Δ) - end - return -end - -back(::Grads, ::Nothing, _) = return - -collectmemaybe(xs) = xs - -function forward(f, ps::Params) - y = collectmemaybe(f()) - y, function (Δ) - g = Grads(ps) - if istracked(y) - scan(y) - back(g, tracker(y), Δ) - end - return g - end -end - -function forward(f, args...) - args = param.(args) - y, back = forward(() -> f(args...), Params(args)) - y, Δ -> getindex.(Ref(back(Δ)), args) -end - -function losscheck(x) - x isa Real || error("Function output is not scalar") - isinf(x) && error("Loss is infinite") - isnan(x) && error("Loss is NaN") -end - -function gradient_nested(f, args...) - y, back = forward(f, args...) - losscheck(y) - return back(1) -end - -gradient(f, xs...; nest = false) = - nest ? gradient_nested(f, xs...) : gradient_(f, xs...) - -# Jacobians and Hessians - -import ..Flux - -""" - J = jacobian(m,x) - -Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])` -""" -function jacobian(m,x) - xp = param(x) - y = m(xp) - k = length(y) - n = length(x) - J = Matrix{eltype(x)}(undef,k,n) - for i = 1:k - Flux.back!(y[i], once = false) # Populate gradient accumulator - J[i,:] = xp.grad - xp.grad .= 0 # Reset gradient accumulator - end - J -end - -hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x) diff --git a/src/tracker/forward.jl b/src/tracker/forward.jl deleted file mode 100644 index ccf75c70f9..0000000000 --- a/src/tracker/forward.jl +++ /dev/null @@ -1,53 +0,0 @@ -using ForwardDiff - -seed(x::Real, ::Val) = Dual(x, true) - -function seed(x, ::Val{N}, offset = 0) where N - map(x, reshape(1:length(x), size(x))) do x, i - Dual(x, ntuple(j -> j+offset == i, Val(N))) - end -end - -extract(x::ForwardDiff.Dual) = x.value, [x.partials...] - -function extract(xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N} - J = similar(xs, V, N, length(xs)) - for i = 1:length(xs), j = 1:N - J[j, i] = xs[i].partials.values[j] - end - return map(x -> x.value, xs), J -end - -function forward_jacobian(f, x, ::Val{N}) where N - y, _J = extract(f(seed(x, Val(N)))) - J = similar(_J, length(x), length(y)) - J[1:N,:] = _J - offset = 0 - while offset + N < length(x) - offset += N - _, _J = extract(f(seed(x, Val(N), offset))) - range = (1+offset):min(N+offset,length(x)) - J[range,:] = @view _J[range.-offset,:] - end - return y, J -end - -function forward_jacobian(f, x) - if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD - forward_jacobian(f, x, Val(length(x))) - else - forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD)) - end -end - -forwarddiff(f, x) = istracked(x) ? track(forwarddiff, f, x) : f(x) - -vec_scalar(x) = vec(x) -vec_scalar(x::Real) = [x] -reshape_scalar(x, y) = reshape(y, size(x)) -reshape_scalar(x::Real, y) = y[] - -@grad function forwarddiff(f, x) - y, J = forward_jacobian(f, data(x)) - return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ))) -end diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl deleted file mode 100644 index 372e262a0b..0000000000 --- a/src/tracker/idset.jl +++ /dev/null @@ -1,28 +0,0 @@ -struct IdSet{T} <: AbstractSet{T} - dict::IdDict{T,Nothing} - IdSet{T}() where T = new(IdDict{T,Nothing}()) -end - -Base.eltype(::IdSet{T}) where T = T - -IdSet() = IdSet{Any}() - -Base.push!(s::IdSet) = s -Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s) -Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s) -Base.in(x, s::IdSet) = haskey(s.dict, x) - -IdSet{T}(xs) where T = push!(IdSet{T}(), xs...) - -IdSet(xs) = IdSet{eltype(xs)}(xs) - -Base.collect(s::IdSet) = Base.collect(keys(s.dict)) -Base.similar(s::IdSet, T::Type) = IdSet{T}() - -@forward IdSet.dict Base.length - -function Base.iterate(v::IdSet, state...) - y = Base.iterate(keys(v.dict), state...) - y === nothing && return nothing - return (y[1], y[2]) -end diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl deleted file mode 100644 index e09e99bc3f..0000000000 --- a/src/tracker/lib/array.jl +++ /dev/null @@ -1,521 +0,0 @@ -import Base: * - -import LinearAlgebra -import LinearAlgebra: inv, det, logdet, logabsdet, \, / - -using Statistics -using LinearAlgebra: Transpose, Adjoint, diagm, diag - -struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} - tracker::Tracked{A} - data::A - grad::A - TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data) - TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) -end - -data(x::TrackedArray) = x.data -tracker(x::TrackedArray) = x.tracker - -TrackedVector{T,A} = TrackedArray{T,1,A} -TrackedMatrix{T,A} = TrackedArray{T,2,A} -TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} - -track(c::Call, x::AbstractArray) = TrackedArray(c, x) - -TrackedArray(c::Call, x::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x) - -TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ) - -TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x)) - -Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} - -Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x - -Base.convert(::Type{<:TrackedArray}, x::TrackedArray) = - error("Not implemented: convert $(typeof(x)) to $T") - -Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} = - TrackedArray(convert(A, x)) - -Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = - @isdefined(A) ? - print(io, "TrackedArray{…,$A}") : - invoke(show, Tuple{IO,DataType}, io, t) - -function Base.summary(io::IO, x::TrackedArray) - print(io, "Tracked ") - summary(io, data(x)) -end - -Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x)) - -function Base.show(io::IO, x::TrackedArray) - show(io, data(x)) - print(io, " (tracked)") -end - -Base.copy(x::TrackedArray) = x - -Base.setindex!(xs::TrackedArray, v, i...) = - error("Can't differentiate `setindex!`") - -back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") - -function update!(x::TrackedArray, Δ) - x.data .+= data(Δ) - tracker(x).grad .= 0 - return x -end - -function update!(x::AbstractArray, Δ) - x .+= data(Δ) - return x -end - -# Fallthrough methods - -for f in :[Base.size, Base.ndims, Base.collect].args - @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) -end - -Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) = - size(data(x), i, j, is...) - -Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = - similar(data(x), dims...) - -Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) - -for op in [:(==), :≈] - @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y) - @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y)) - @eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y)) -end - -# Array Stdlib - -Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) - -@grad function getindex(xs::AbstractArray, i...) - data(xs)[i...], function (Δ) - Δ′ = zero(xs) - Δ′[i...] = data(Δ) - (nobacksies(:getindex, Δ′), map(_->nothing, i)...) - end -end - -Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) - -@grad function view(x::AbstractArray, inds...) - view(data(x), inds...), function (Δ) - grad_output = zero(x) - subgrad = view(grad_output, inds...) - subgrad[:] = data(Δ) - (nobacksies(:view, grad_output), map(_->nothing, inds)...) - end -end - -Base.:-(xs::TrackedArray) = track(-, xs) - -@grad -(xs) = -data(xs), Δ -> (-Δ,) - -Base.transpose(xs::TrackedArray) = track(transpose, xs) -Base.adjoint(xs::TrackedArray) = track(adjoint, xs) - -@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),) -@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),) - -det(xs::TrackedArray) = track(det, xs) -@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),) - -logdet(xs::TrackedArray) = track(logdet, xs) -@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),) - -logabsdet(xs::TrackedArray) = track(logabsdet, xs) -@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),) - -Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) - -@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) - repeat(data(xs), inner = inner, outer = outer), function (Δ) - Δ′ = zero(xs) - S = size(xs) - - # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ - for (dest_idx, val) in pairs(IndexCartesian(), data(Δ)) - # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then - # wrap around based on original size S. - src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] - Δ′[src_idx...] += val - end - (nobacksies(:repeat, Δ′),) - end -end - -function combinations(xs, n) - n < 1 && return [[]] - cs = combinations(xs, n-1) - [[x, c...] for x in xs, c in cs] -end - -for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat] - cnames = map(_ -> gensym(), c) - @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) = - track($f, $(cnames...), x, xs...) -end - -for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat] - cnames = map(_ -> gensym(), c) - @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T = - track($f, $(cnames...), x, xs...) -end - -for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat] - cnames = map(_ -> gensym(), c) - @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T = - track($f, $(cnames...), x, xs...) -end - -@grad function vcat(xs...) - vcat(data.(xs)...), function (Δ) - start = 0 - Δs = [begin - i = map(_ -> :, size(xsi)) |> Base.tail - d = Δ[start+1:start+size(xsi,1), i...] - start += size(xsi, 1) - d - end for xsi in xs] - return (Δs...,) - end -end - -@grad function hcat(xs...) - hcat(data.(xs)...), function (Δ) - start = 0 - Δs = [begin - d = if ndims(xsi) == 1 - Δ[:, start+1] - else - i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail - Δ[:, start+1:start+size(xsi,2), i...] - end - start += size(xsi, 2) - d - end for xsi in xs] - return (Δs...,) - end -end - -for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i) - cnames = map(_ -> gensym(), c) - @eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) = - track(cat, $(cnames...), x, xs..., dims = dims) -end - -@grad function cat(Xs...; dims) - cat(data.(Xs)..., dims = dims), function (Δ) - start = ntuple(i -> 0, Val(ndims(Δ))) - Δs = [begin - dim_xs = 1:ndims(xs) - till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ))) - xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ))) - d = reshape(Δ[xs_in_Δ...],size(xs)) - start = start .+ till_xs - d - end for xs in Xs] - return (Δs...,) - end -end - -Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) -Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) -Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) - -@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing) - -Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm) -@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing) - -Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm) -@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing) - -function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) - m1, n1 = size(mat1) - mat1_rsh = reshape(mat1,(1,m1,1,n1)) - - m2, n2 = size(mat2) - mat2_rsh = reshape(mat2,(m2,1,n2,1)) - - return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2)) -end - -Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b) -Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b) -Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) - - -inv(A::TrackedArray) = Tracker.track(inv, A) -@grad function inv(A) - return inv(Tracker.data(A)), function (Δ) - Ainv = inv(A) - ∇A = - Ainv' * Δ * Ainv' - return (∇A, ) - end -end - -# (/) rdivide -A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B) -A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B) -A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B) -@grad function (A / B) - return Tracker.data(A) / Tracker.data(B), function (Δ) - Binv = inv(B) - ∇B = - Binv' * A' * Δ * Binv' - return (Δ * Binv', ∇B) - end -end - -# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity) -A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B) -A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B) -A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B) -@grad function (A \ B) - return Tracker.data(A) \ Tracker.data(B), function (Δ) - Ainv = inv(A) - ∇A = - Ainv' * Δ * B' * Ainv' - return (∇A, Ainv' * Δ) - end -end - - -# Reductions - -Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims) -Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) - -@grad sum(xs; dims = :) = sum(data(xs), dims = dims), - Δ -> (zero(xs) .+ Δ, ) - -Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) -Base.prod(xs::TrackedArray) = track(prod, xs) -Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) - -@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,) -@grad prod(xs, dim) = prod(data(xs), dims = dim), - Δ -> (nobacksies(:sum, - reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ), - nothing) - -Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) - -Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims) - -Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims) -Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims) - -import LinearAlgebra: dot - -dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) -dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) -dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) - -@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs) - -# Hacks to get std working -Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected) -_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected)) -_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected)) - -LinearAlgebra.norm(x::TrackedArray, p::Real = 2) = - sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 - -@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),) -_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs) -_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims) - -@grad function maximum(xs; dims = dims) - maximum(data(xs), dims = dims), function (Δ) - Δ′ = zero(xs) - _, i = findmax(data(xs), dims = dims) - Δ′[i] = data(Δ) - return (nobacksies(:maximum, Δ′),) - end -end - -@grad function minimum(xs; dims = dims) - minimum(data(xs), dims = dims), function (Δ) - Δ′ = zero(xs) - _, i = findmin(data(xs), dims = dims) - Δ′[i] = data(Δ) - return (nobacksies(:minimum, Δ′),) - end -end - -# BLAS - -LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...) -@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i)) - -x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) -x::AbstractMatrix * y::TrackedMatrix = track(*, x, y) -x::TrackedMatrix * y::TrackedMatrix = track(*, x, y) - -x::TrackedMatrix * y::AbstractVector = track(*, x, y) -x::AbstractMatrix * y::TrackedVector = track(*, x, y) -x::TrackedMatrix * y::TrackedVector = track(*, x, y) - -x::TrackedVector * y::AbstractVector = track(*, x, y) -x::AbstractVector * y::TrackedVector = track(*, x, y) -x::TrackedVector * y::TrackedVector = track(*, x, y) - -@grad a::AbstractMatrix * b::AbstractVecOrMat = - data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ) - -# NNlib - -using NNlib -import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool - -softmax(xs::TrackedArray) = track(softmax, xs) - -@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),) - -logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) - -@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),) - -depthwiseconv(x::TrackedArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...) -depthwiseconv(x::AbstractArray, w::TrackedArray; kw...) = track(depthwiseconv, x, w; kw...) -depthwiseconv(x::TrackedArray, w::AbstractArray; kw...) = track(depthwiseconv, x, w; kw...) - -@grad depthwiseconv(x, w; kw...) = - depthwiseconv(data(x), data(w); kw...), - Δ -> nobacksies(:depthwiseconv, - (NNlib.∇depthwiseconv_data(data.((Δ, x, w))...; kw...), - NNlib.∇depthwiseconv_filter(data.((Δ, x, w))...; kw...))) - -conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) -conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) -conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) - -@grad conv(x, w; kw...) = - conv(data(x), data(w); kw...), - Δ -> nobacksies(:conv, - (NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...), - NNlib.∇conv_filter(data.((Δ, x))...; size=size(w), kw...))) - -∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...) -∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...) -∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...) - -@grad ∇conv_data(x, w; kw...) = - ∇conv_data(data(x), data(w); kw...), - Δ -> nobacksies(:conv, - (NNlib.conv(data.((Δ, w))...; size=size(x), kw...), - NNlib.∇conv_filter(data.((x, Δ))...; size=size(w), kw...))) - -maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) - -@grad function maxpool(x, k; kw...) - y = maxpool(data(x), k; kw...) - y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing) -end - -meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) - -@grad function meanpool(x, k; kw...) - y = meanpool(data(x), k; kw...) - y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing) -end - -# Broadcasting - -using ForwardDiff: Dual, partials, value - -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) - -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) - -unbroadcast(x::Number, Δ) = sum(Δ) -unbroadcast(x::Base.RefValue, _) = nothing - -dual(x, p) = x -dual(x::Real, p) = Dual(x, p) - -function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N} - dargs = ntuple(j -> dual(args[j], i==j), Val(N)) - return Δ * f(dargs...).partials[1] -end - -@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N} - y = broadcast(f, data.(args)...) - eltype(y) <: Real || return y - eltype(y) == Bool && return y - function back(Δ) - Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N)) - dxs = map(unbroadcast, args, Δargs) - return dxs - end - # So we can return non-tracked arrays - track(Call(back, tracker.(args)), y) -end - -using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted - -struct TrackedStyle <: BroadcastStyle end - -Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle() -Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle() - -# We have to re-build the original broadcast struct to get the appropriate array -# style. We need this primarily to support CuArrays' broadcasting fixes. -broadcast_rebuild(xs) = data(xs) - -broadcast_rebuild(bc::Broadcasted) = - broadcasted(bc.f, broadcast_rebuild.(bc.args)...) - -preprocess(x) = x - -function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle}) - bc1 = Broadcast.flatten(bc) - bc2 = Broadcast.flatten(broadcast_rebuild(bc)) - ∇broadcast(bc2.f, bc1.args...) -end - -using Requires - -# https://github.com/FluxML/Flux.jl/issues/353 -if VERSION < v"1.1.0-DEV.548" - @init Requires.isprecompiling() || @eval Base.Broadcast begin - function flatten(bc::Broadcasted{Style}) where {Style} - isflat(bc) && return bc - args = cat_nested(bc) - let makeargs = make_makeargs(bc), f = bc.f - newf = @inline function(args::Vararg{Any,N}) where N - f(makeargs(args...)...) - end - return Broadcasted{Style}(newf, args, bc.axes) - end - end - @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}}) - bc = t[1] - let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f - let makeargs = make_makeargs(makeargs, bc.args) - headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args) - return @inline function(args::Vararg{Any,N}) where N - args1 = makeargs(args...) - a, b = headargs(args1...), tailargs(args1...) - (f(a...), b...) - end - end - end - end - end -end diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl deleted file mode 100644 index ec57f0d3d6..0000000000 --- a/src/tracker/lib/real.jl +++ /dev/null @@ -1,160 +0,0 @@ -mutable struct TrackedReal{T<:Real} <: Real - data::T - tracker::Tracked{T} -end - -TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x))) - -data(x::TrackedReal) = x.data -tracker(x::TrackedReal) = x.tracker - -track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x))) - -function back!(x::TrackedReal; once = true) - isinf(x) && error("Loss is Inf") - isnan(x) && error("Loss is NaN") - return back!(x, 1, once = once) -end - -function update!(x::TrackedReal, Δ) - x.data += data(Δ) - tracker(x).grad = 0 - return x -end - -function Base.show(io::IO, x::TrackedReal) - T = get(io, :typeinfo, Any) - show(io, data(x)) - T <: TrackedReal || print(io, " (tracked)") -end - -Base.decompose(x::TrackedReal) = Base.decompose(data(x)) - -Base.copy(x::TrackedReal) = x - -Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x - -Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) - -Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = - error("Not implemented: convert tracked $S to tracked $T") - -(T::Type{<:TrackedReal})(x::Real) = convert(T, x) - -for op in [:(==), :≈, :<, :(<=)] - @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y) - @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y)) - @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y)) -end - -Base.eps(x::TrackedReal) = eps(data(x)) -Base.eps(::Type{TrackedReal{T}}) where T = eps(T) - -for f in :[isinf, isnan, isfinite].args - @eval Base.$f(x::TrackedReal) = Base.$f(data(x)) -end - -Base.Printf.fix_dec(x::TrackedReal, n::Int, a...) = Base.Printf.fix_dec(data(x), n, a...) - -Base.float(x::TrackedReal) = x - -Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} = - TrackedReal{promote_type(S,T)} - -using Random - -for f in :[rand, randn, randexp].args - @eval Random.$f(rng::AbstractRNG,::Type{TrackedReal{T}}) where {T} = param(rand(rng,T)) -end - -using DiffRules, SpecialFunctions, NaNMath - -for (M, f, arity) in DiffRules.diffrules() - arity == 1 || continue - @eval begin - @grad $M.$f(a::Real) = - $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),) - $M.$f(a::TrackedReal) = track($M.$f, a) - end -end - -# Work around zero(π) not working, for some reason -_zero(::Irrational) = nothing -_zero(x) = zero(x) - -for (M, f, arity) in DiffRules.diffrules() - arity == 2 || continue - da, db = DiffRules.diffrule(M, f, :a, :b) - f = :($M.$f) - @eval begin - @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) - @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) - $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) - $f(a::TrackedReal, b::Real) = track($f, a, b) - $f(a::Real, b::TrackedReal) = track($f, a, b) - end -end - -# Eliminating ambiguity -import Base:^ - -^(a::TrackedReal, b::Integer) = track(^, a, b) - -# Hack for conversions - -using ForwardDiff: Dual - -(T::Type{<:Real})(x::Dual) = Dual(T(x.value), map(T, x.partials.values)) -(Dual{T,V,N})(x::Dual) where {T,V,N} = invoke(Dual{T,V,N}, Tuple{Number}, x) - -# Tuples - -struct TrackedTuple{T<:Tuple} - data::T - tracker::Tracked{T} -end - -data(xs::TrackedTuple) = xs.data -tracker(xs::TrackedTuple) = xs.tracker - -accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) -init_grad(x::Tuple) = init_grad.(x) -zero_grad!(x::Tuple) = zero_grad!.(x) - -track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs))) - -function Base.show(io::IO, xs::TrackedTuple) - show(io, data(xs)) - print(io, " (tracked)") -end - -Base.length(x::TrackedTuple) = length(data(x)) - -Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) - -@grad function getindex(xs::TrackedTuple, i) - data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing) -end - -# Array collection - -function collect(xs) - xs = Base.collect(xs) - track(Call(collect, (tracker.(xs),)), data.(xs)) -end - -function scan(c::Call{typeof(collect)}) - foreach(scan, c.args[1]) -end - -function back_(c::Call{typeof(collect)}, Δ, once) - foreach((x, d) -> back(x, d, once), c.args[1], data(Δ)) -end - -function back_(g::Grads, c::Call{typeof(collect)}, Δ) - foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ) -end - -collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs) -collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs) diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl deleted file mode 100644 index 112117ed1d..0000000000 --- a/src/tracker/numeric.jl +++ /dev/null @@ -1,18 +0,0 @@ -function ngradient(f, xs::AbstractArray...) - grads = zero.(xs) - for (x, Δ) in zip(xs, grads), i in 1:length(x) - δ = sqrt(eps()) - tmp = x[i] - x[i] = tmp - δ/2 - y1 = f(xs...) - x[i] = tmp + δ/2 - y2 = f(xs...) - x[i] = tmp - Δ[i] = (y2-y1)/δ - end - return grads -end - -gradcheck(f, xs...) = - all(isapprox.(ngradient(f, xs...), - data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5)) diff --git a/src/tracker/params.jl b/src/tracker/params.jl deleted file mode 100644 index 7a1db1e9f7..0000000000 --- a/src/tracker/params.jl +++ /dev/null @@ -1,46 +0,0 @@ -struct Params - order::Vector{Any} - params::IdSet{Any} - Params() = new([], IdSet()) -end - -@forward Params.order Base.iterate, Base.length - -function Base.push!(ps::Params, x) - if !(x in ps.params) - push!(ps.order, x) - push!(ps.params, x) - end - return ps -end - -Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) - -Params(xs) = push!(Params(), xs...) - -function Base.show(io::IO, ps::Params) - print(io, "Params([") - join(io, ps.order, ", ") - print(io, "])") -end - -struct Grads - grads::IdDict{Any,Any} -end - -Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") - -Grads() = Grads(IdDict()) - -@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate - -Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) - -Base.getindex(g::Grads, x::Tracked) = g.grads[x] - -function Base.getindex(g::Grads, x) - istracked(x) || error("Object not tracked: $x") - g[tracker(x)] -end - -accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ diff --git a/test/tracker.jl b/test/tracker.jl index 47ce7166d4..5f3a291f4d 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,347 +1,15 @@ -using Flux -using Flux.Tracker, Test, NNlib -using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff -using NNlib: conv, ∇conv_data, depthwiseconv -using Printf: @sprintf -using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet -using Statistics: mean, std -using Random -# using StatsBase +using Flux, Test +using Tracker: gradcheck gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) -@testset "Tracker" begin -@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) -@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) -@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2) -@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2) -@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10)) -@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5)) -@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5)) -@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3)) -@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3)) -@test gradtest(x -> sum(x), randn(Float64,2,3)) -@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5)) -@test gradtest(x -> prod(x), (3,4,5)) -@test gradtest(x -> softmax(x).*(1:3), 3) -@test gradtest(x -> softmax(x).*(1:3), (3,5)) -@test gradtest(x -> logsoftmax(x).*(1:3), 3) -@test gradtest(x -> logsoftmax(x).*(1:3), (3,5)) +@testset "Tracker" begin @test gradtest(Flux.mse, rand(5,5), rand(5, 5)) @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5)) -@test gradtest(x -> x', rand(5)) - -@test gradtest(det, (4, 4)) -@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1]) -@test gradtest((x) -> logabsdet(x)[1], (4, 4)) - -@testset "indexing & slicing" begin - gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) -end - -function promotiontest(f, A, B, C) - r0 = f(A, B, C) - r1 = f(param(A), B, C) - r2 = f(A, param(B), C) - r3 = f(A, B, param(C)) - r4 = f(param(A), param(B), param(C)) - - @test !isa(r0, TrackedArray) - @test all(isa.([r1,r2,r3,r4], TrackedArray)) - @test r1 == r2 == r3 == r4 - @test r0 == Flux.data(r4) -end - -@testset "concat" begin - cat1(x...) = cat(x..., dims = 1) - cat2(x...) = cat(x..., dims = 2) - - @testset for vcatf in [vcat, cat1] - @test gradtest(vcatf, rand(5), rand(3)) - @test gradtest(vcatf, rand(5), rand(3), rand(8)) - @test gradtest(vcatf, rand(5)', rand(5)') - @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) - @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) - @test gradtest(vcatf, rand(5), rand(3,1)) - @test gradtest(vcatf, rand(5)', rand(2,5)) - end - - - @testset for hcatf in [hcat, cat2] - @test gradtest(hcatf, rand(5), rand(5)) - @test gradtest(hcatf, rand(5)', rand(5)') - @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) - @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) - @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) - @test gradtest(hcatf, rand(5)', rand(1,3)) - @test gradtest(hcatf, rand(5), rand(5,2)) -end - - @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))] - @test gradtest(catf, rand(5)) - @test gradtest(catf, rand(5)') - @test gradtest(catf, rand(2,5)) - @test gradtest(catf, rand(2,5,3)) - end - - @test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4)) - - @testset "cat($dim, ...)" for dim in 3:5 - catdim = (x...) -> cat(x..., dims = dim) - @test gradtest(catdim, rand(5), rand(5), rand(5)) - @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) - @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3)) - end - - @test !isa(vcat(rand(2)), TrackedArray) - @test !isa(hcat(rand(2)), TrackedArray) - @test !isa(cat(rand(2), dims=1), TrackedArray) - - @test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1)) - - @testset "promotiontest" begin - @testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))] - promotiontest(fcat, rand(2), rand(2), rand(2)) - promotiontest(fcat, rand(2)', rand(2)', rand(2)') - promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2)) - promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2)) - end - - promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) - promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) - promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) - promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) - promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2)) - end - - @testset "scalars" begin - @test vcat(param([1, 2, 3]), 1) isa TrackedArray - @test vcat(1, param([1, 2, 3])) isa TrackedArray - @test hcat(1, param([1 2 3;])) isa TrackedArray - @test vcat(param(1), 2) isa TrackedArray - end - -end - -@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) -@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6)) - -@test gradtest(x -> repeat(x; inner=2), rand(5)) -@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5)) -@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) - -@test gradtest(kron, rand(5), rand(3)) -@test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron, rand(5,1), rand(3,1)) -@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) -@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) - -@test gradtest(x -> diagm(0 => x), rand(3)) - -@test gradtest(W -> inv(log.(W * W)), (5,5)) -@test gradtest((A, B) -> A / B , (1,5), (5,5)) -@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5)) -@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5)) - -@testset "mean" begin - @test gradtest(mean, rand(2, 3)) - - @test gradtest(x -> mean(x, dims=1), rand(2, 3)) - @test gradtest(x -> mean(x, dims=2), rand(2, 3)) - @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4)) - - @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4)) -end - -@testset "maximum" begin - @test gradtest(maximum, rand(2, 3)) - - @test gradtest(x -> maximum(x, dims=1), rand(2, 3)) - @test gradtest(x -> maximum(x, dims=2), rand(2, 3)) - @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4)) - - @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) -end - -@testset "minimum" begin - @test gradtest(minimum, rand(2, 3)) - - @test gradtest(x -> minimum(x, dims=1), rand(2, 3)) - @test gradtest(x -> minimum(x, dims=2), rand(2, 3)) - @test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4)) - - @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4)) -end - -@test gradtest(x -> std(x), rand(5,5)) -@test gradtest(x -> std(x, dims = 1), rand(5,5)) -@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5)) - @test gradtest(x -> Flux.normalise(x), rand(4,3)) @test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4)) -@test gradtest((x, y) -> x .* y, rand(5), rand(5)) -@test gradtest(dot, rand(5), rand(5)) - -@test gradtest(norm, rand(5)) - -@test gradtest(rand(5)) do x - y = x.^2 - 2y + x end - -@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2)) -@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2)) -@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2)) - -@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3)) -@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64,2, 2, 2, 3)) -@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 2, 3)) - -@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3)) - -@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3)) -@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3)) -@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3)) - -@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2)) -@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2)) - -@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2)) -@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2)) - -@test gradtest(x -> Float64.(x), 5) - -@testset "equality & order" begin - # TrackedReal - @test param(2)^2 == param(4) - @test param(2)^2 == 4 - @test 4 == param(2)^2 - - @test param(2)^2 ≈ param(4) - @test param(2)^2 ≈ 4 - @test 4 ≈ param(2)^2 - - @test (param([1,2,3]) .< 2) == [true, false, false] - @test (param([1,2,3]) .<= 2) == [true, true, false] - @test (2 .> param([1,2,3])) == [true, false, false] - @test (2 .>= param([1,2,3])) == [true, true, false] - - # TrackedArray - @test param([1,2,3]).^2 == param([1,4,9]) - @test [1,2,3].^2 == param([1,4,9]) - @test param([1,2,3]).^2 == [1,4,9] - - @test param([1,2,3]).^2 ≈ param([1,4,9]) - @test [1,2,3].^2 ≈ param([1,4,9]) - @test param([1,2,3]).^2 ≈ [1,4,9] -end - -@testset "reshape" begin - x = reshape(param(rand(2,2,2)), 4, 2) - @test x isa TrackedArray - @test size(x) == (4,2) - x = reshape(param([1]), (1,:)) - @test x isa TrackedArray - @test size(x) == (1,1) - x = reshape(param(rand(2)), (2,:)) - @test x isa TrackedArray - @test size(x) == (2,1) - x = reshape(param(rand(2,2)), (1,:,2)) - @test x isa TrackedArray - @test size(x) == (1,2,2) -end - -@testset "Intermediates" begin - x = param([1]) - l = sum((x .+ x).^2) - Flux.back!(l, once = false) - @test x.grad == [8] - x.grad .= 0 - Flux.back!(l, once = false) - @test x.grad == [8] -end - -@testset "Fallbacks" begin - xs = param([1 2; 3 4]) - @test similar(xs) isa Matrix{Float64} -end - -@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00" - -@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4)) - -b = param(rand()) -Tracker.back!(b) -@test Tracker.grad(b) == 1 - -@testset "collect" begin - x, y = param(2), param(3) - xy = Tracker.collect([x, y]) - @test xy isa TrackedArray{Float64} - z = xy[1]*xy[2] - back!(z) - @test grad.((x,y)) == (3, 2) - - @test gradient(2, 3) do x, y - xy = Tracker.collect([x, y]) - xy[1]*xy[2] - end == (3, 2) -end - -# Gradient Hooks -@testset "Hooks" begin - x = param(2) - y = Tracker.hook(-, x) - back!(y) - @test grad(x) == -1 -end - -@testset "Checkpointing" begin - count = 0 - function mul(a, b) - count += 1 - a * b - end - @test gradient(x -> mul(5, x), 3)[1] == 5 - @test count == 1 - @test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5 - @test count == 3 -end - -@testset "Updates" begin - xs = param([1, 2, 3]) - Tracker.update!(xs, param([4, 5, 6])) - @test xs == [5, 7, 9] - x = param(3) - Tracker.update!(x, param(4)) - @test x == 7 -end - -@testset "Params" begin - W = param(randn(5, 10)) - x = rand(10) - dW = gradient(W -> sum(W*x), W)[1] - gs = gradient(() -> sum(W*x), Tracker.Params([W])) - @test gs[W] == dW -end - -@testset "Forward" begin - @test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] == - reshape(ones(25), :, 1) - @test gradient([2, 3]) do x - forwarddiff(x) do x - x[1]*x[2] - end - end == ([3, 2],) -end - -@testset "Custom Sensitivities" begin - y, back = Tracker.forward(x -> [3x^2, 2x], 5) - @test back([1, 1]) == (32,) -end - -end #testset