From aa64d2157d6e7bb4a27dacda03582d6b686c854e Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 5 Feb 2019 11:38:27 +0200 Subject: [PATCH 1/4] Fixed issue #542. Added tracking of LinearAlgebra.det and its grad method. --- src/tracker/lib/array.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 3f47e89ce4..023f5fa2d6 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -1,7 +1,7 @@ import Base: * import LinearAlgebra -import LinearAlgebra: inv, \, / +import LinearAlgebra: inv, det, \, / using Statistics using LinearAlgebra: Transpose, Adjoint, diagm, diag @@ -124,6 +124,9 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),) @grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),) +det(xs::TrackedArray) = track(det, xs) +@grad det(xs) = det(data(xs)), Δ -> (Δ * transpose(adjoint(xs)),) + Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) From f790fff59ae5230eda2bd1fc081fdbb5b723bbe4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 5 Feb 2019 14:36:28 +0200 Subject: [PATCH 2/4] Use other definition for grad(det(A)). --- src/tracker/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 023f5fa2d6..e8239aadf3 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -125,7 +125,7 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),) det(xs::TrackedArray) = track(det, xs) -@grad det(xs) = det(data(xs)), Δ -> (Δ * transpose(adjoint(xs)),) +@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) From e00ac88016ebdfde5b6d01cc9dfb951c21c5b7e1 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 8 Feb 2019 09:55:33 +0200 Subject: [PATCH 3/4] Added tracking of `logdet` and `logabsdet`. Added gradtests. --- src/tracker/lib/array.jl | 8 +++++++- test/tracker.jl | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index e8239aadf3..8e931acd2a 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -1,7 +1,7 @@ import Base: * import LinearAlgebra -import LinearAlgebra: inv, det, \, / +import LinearAlgebra: inv, det, logdet, logabsdet, \, / using Statistics using LinearAlgebra: Transpose, Adjoint, diagm, diag @@ -127,6 +127,12 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) det(xs::TrackedArray) = track(det, xs) @grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),) +logdet(xs::TrackedArray) = track(logdet, xs) +@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),) + +logabsdet(xs::TrackedArray) = track(logabsdet, xs) +@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),) + Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) diff --git a/test/tracker.jl b/test/tracker.jl index bb64f01a7e..ad9795c832 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff using NNlib: conv, depthwiseconv using Printf: @sprintf -using LinearAlgebra: diagm, dot, LowerTriangular, norm +using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet using Statistics: mean, std using Random # using StatsBase @@ -34,6 +34,10 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @test gradtest(x -> x', rand(5)) +@test gradtest(det, (4, 4)) +@test gradtest(logdet, (4, 4)) +@test gradtest((x) -> logabsdet(x)[1], (4, 4)) + @testset "indexing & slicing" begin gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) end From 647179081914ab34f291ed88a526300a8fbf3ac4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 8 Feb 2019 12:22:08 +0200 Subject: [PATCH 4/4] Pass symmetric matrix to `logdet` gradtest --- test/tracker.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tracker.jl b/test/tracker.jl index ad9795c832..edc81d0210 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -35,7 +35,7 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @test gradtest(x -> x', rand(5)) @test gradtest(det, (4, 4)) -@test gradtest(logdet, (4, 4)) +@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1]) @test gradtest((x) -> logabsdet(x)[1], (4, 4)) @testset "indexing & slicing" begin