From 51e7e1b40fa35949054fefa70f1ca9592821a362 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:51:04 +0200 Subject: [PATCH 01/12] cat tests #184 Co-authored-by: pevnak --- test/tracker.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/tracker.jl b/test/tracker.jl index 12ed02e5bd..2b0e04d7bd 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -32,6 +32,14 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) + +@test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) +@test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) +@test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) +@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,1)) +@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,4)) +@test gradtest((i...) -> cat(2,i...), rand(5,2),rand(5,4), rand(5,8)) + @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) From 59324c0f91a51b5de7660f4116ee5738f989cde5 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:22:59 +0200 Subject: [PATCH 02/12] hcat tests #194 Co-authored-by: Elliot Saba --- test/tracker.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/tracker.jl b/test/tracker.jl index 2b0e04d7bd..f39546eaa8 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -33,6 +33,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3)) +@test gradtest(hcat, rand(5), rand(5), rand(5,2)) +@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5)) +@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3)) @test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) @test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) @test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) @@ -45,9 +49,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) -@test gradtest(kron,rand(5), rand(3)) +@test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron,rand(5,1), rand(3,1)) +@test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) From 13daaec1cbafe0b74669aee345497556f7a56623 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:54:40 +0200 Subject: [PATCH 03/12] Refactored tests --- test/tracker.jl | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/test/tracker.jl b/test/tracker.jl index f39546eaa8..27d395b179 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,20 +29,25 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) -@test gradtest(vcat, rand(5), rand(3)) -@test gradtest(vcat, rand(5), rand(3), rand(8)) -@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) - -@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3)) -@test gradtest(hcat, rand(5), rand(5), rand(5,2)) -@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5)) -@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3)) -@test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) -@test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) -@test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) -@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,1)) -@test gradtest((i...) -> cat(2,i...), rand(5,1), rand(5,4)) -@test gradtest((i...) -> cat(2,i...), rand(5,2),rand(5,4), rand(5,8)) +@testset "concat" begin + @testset "vcat $i" for (i,vcatf) in enumerate((vcat, (x...) -> cat(1, x...))) + @test gradtest(vcatf, rand(5), rand(3)) + @test gradtest(vcatf, rand(5), rand(3), rand(8)) + @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) + @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) + end + @testset "hcat $i" for (i,hcatf) in enumerate((hcat, (x...) -> cat(2, x...))) + @test gradtest(hcatf, rand(5), rand(5)) + @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) + @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + end + @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) + @testset "cat($dim, ...)" for dim in 1:5 + catdim = (x...) -> cat(dim, x...) + @test gradtest(catdim, rand(5), rand(5)) + @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) + end +end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From bcef5c4ab512fd84ef44a1e97c465a24f1001977 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:56:08 +0200 Subject: [PATCH 04/12] Support hcat and cat --- src/tracker/array.jl | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 4dfb2c6d32..0bfabf3682 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -96,6 +96,14 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) +Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) +Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) +Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b) + +Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...) +Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b) +Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b) + function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -117,6 +125,34 @@ function back(::typeof(vcat), Δ, xs...) end end +function back(::typeof(hcat), Δ, xs...) + i = fill(:, ndims(Δ)-2) + start = 0 + for xsi in xs + if ndims(xsi) == 1 + @back(xsi, Δ[:, start+1]) + else + @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) + end + start += size(xsi, 2) + end +end + +function back(::typeof(cat), Δ, dim, xs...) + i = fill(:, dim-1) + j = fill(:, ndims(Δ)-dim) + start = 0 + for xsi in xs + if ndims(xsi) < dim + a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)] + @back(xsi, Δ[a..., start+1]) + else + @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) + end + start += size(xsi, dim) + end +end + Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) From eaaf5fd34c1b41b780e13d00f9ff8186e2cb4035 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Sat, 21 Apr 2018 01:10:34 +0200 Subject: [PATCH 05/12] vcat arrays with ndims>2 --- src/tracker/array.jl | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 0bfabf3682..61c2d5cede 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,20 +81,9 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b) -Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...) -Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b) -Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b) - -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...) -Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b) -Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) - -Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) -Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) -Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) +Base.vcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(vcat, a, b...) +Base.vcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(vcat, a, b) +Base.vcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(vcat, a, b) Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) @@ -117,7 +106,7 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) end function back(::typeof(vcat), Δ, xs...) - i = Base.tail(map(_ -> :, size(Δ))) + i = fill(:, ndims(Δ)-1) start = 0 for xsi in xs @back(xsi, Δ[start+1:start+size(xsi,1), i...]) From 509a2e59f6e6e8cbfcf2f3283884fe005a33fce4 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 08:30:11 +0200 Subject: [PATCH 06/12] cat promotions and mixed ranks --- src/tracker/array.jl | 32 +++++++++++++++++--------------- test/tracker.jl | 27 +++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 61c2d5cede..89fce39e9e 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,17 +81,18 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(vcat, a, b...) -Base.vcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(vcat, a, b) -Base.vcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(vcat, a, b) +for f in [:vcat, :hcat] + @eval begin + Base.$f(a::TrackedArray...) = track($f, a...) + Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) -Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) -Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) -Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b) + # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + end +end -Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...) -Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b) -Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b) +Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) +Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) @@ -106,21 +107,21 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) end function back(::typeof(vcat), Δ, xs...) - i = fill(:, ndims(Δ)-1) start = 0 for xsi in xs + i = map(_ -> :, size(xsi)) |> Base.tail @back(xsi, Δ[start+1:start+size(xsi,1), i...]) start += size(xsi, 1) end end function back(::typeof(hcat), Δ, xs...) - i = fill(:, ndims(Δ)-2) start = 0 for xsi in xs if ndims(xsi) == 1 @back(xsi, Δ[:, start+1]) else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) end start += size(xsi, 2) @@ -128,14 +129,15 @@ function back(::typeof(hcat), Δ, xs...) end function back(::typeof(cat), Δ, dim, xs...) - i = fill(:, dim-1) - j = fill(:, ndims(Δ)-dim) start = 0 for xsi in xs if ndims(xsi) < dim - a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)] - @back(xsi, Δ[a..., start+1]) + i = map(_ -> :, size(xsi)) + j = ones(Int, dim-ndims(xsi)-1) + @back(xsi, Δ[i..., j..., start+1]) else + i = fill(:, dim-1) + j = fill(:, ndims(xsi)-dim) @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) end start += size(xsi, dim) diff --git a/test/tracker.jl b/test/tracker.jl index 27d395b179..d0b2375d08 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,19 +29,42 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) +function simplepromotioncheck(f, A, B) + r0 = f(A, B) + r1 = f(param(A), B) + r2 = f(A, param(B)) + r3 = f(param(A), param(B)) + + r1 == r2 == r3 && r0 == Flux.data(r1) +end + @testset "concat" begin - @testset "vcat $i" for (i,vcatf) in enumerate((vcat, (x...) -> cat(1, x...))) + cat1(x...) = cat(1, x...) + cat2(x...) = cat(2, x...) + + @testset for vcatf in [vcat, cat1] @test gradtest(vcatf, rand(5), rand(3)) @test gradtest(vcatf, rand(5), rand(3), rand(8)) @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) + @test gradtest(vcatf, rand(5), rand(3,1)) + @test gradtest(vcatf, rand(5)', rand(2,5)) end - @testset "hcat $i" for (i,hcatf) in enumerate((hcat, (x...) -> cat(2, x...))) + + @test simplepromotioncheck(vcat, rand(5), rand(5)) + + @testset for hcatf in [hcat, cat2] @test gradtest(hcatf, rand(5), rand(5)) @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + @test gradtest(hcatf, rand(5)', rand(1,3)) + @test gradtest(hcatf, rand(5), rand(5,2)) end + + @test simplepromotioncheck(hcat, rand(5), rand(5)) + @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) + @testset "cat($dim, ...)" for dim in 1:5 catdim = (x...) -> cat(dim, x...) @test gradtest(catdim, rand(5), rand(5)) From fb685291693cf0c9dbc466557af6d3ab7f078e88 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 08:37:30 +0200 Subject: [PATCH 07/12] define back function right after forward function --- src/tracker/array.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 89fce39e9e..71a2d5304c 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,19 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -for f in [:vcat, :hcat] - @eval begin - Base.$f(a::TrackedArray...) = track($f, a...) - Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - - # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector - Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) - end -end - -Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) -Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) - function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -106,6 +93,16 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) back(xs, Δ′) end +for f in [:vcat, :hcat] + @eval begin + Base.$f(a::TrackedArray...) = track($f, a...) + Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) + + # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + end +end + function back(::typeof(vcat), Δ, xs...) start = 0 for xsi in xs @@ -128,6 +125,9 @@ function back(::typeof(hcat), Δ, xs...) end end +Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) +Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) + function back(::typeof(cat), Δ, dim, xs...) start = 0 for xsi in xs From 1c189c62ed9acc2a3f6784e90dee9293f5c3b965 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 09:03:54 +0200 Subject: [PATCH 08/12] cat with multiple dims #156 Co-authored-by: americast --- src/tracker/array.jl | 31 +++++++++++++++---------------- test/tracker.jl | 2 ++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 71a2d5304c..4650e916ec 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -98,7 +98,7 @@ for f in [:vcat, :hcat] Base.$f(a::TrackedArray...) = track($f, a...) Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - # assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector + # assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) end end @@ -125,22 +125,21 @@ function back(::typeof(hcat), Δ, xs...) end end -Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...) -Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...) +Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...) +Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...) +Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...) -function back(::typeof(cat), Δ, dim, xs...) - start = 0 - for xsi in xs - if ndims(xsi) < dim - i = map(_ -> :, size(xsi)) - j = ones(Int, dim-ndims(xsi)-1) - @back(xsi, Δ[i..., j..., start+1]) - else - i = fill(:, dim-1) - j = fill(:, ndims(xsi)-dim) - @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) - end - start += size(xsi, dim) +function back(::typeof(cat), Δ, dims, Xs...) + start = ntuple(i -> 0, Val{ndims(Δ)}) + for xs in Xs + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) + + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) + + @back(xs, reshape(Δ[xs_in_Δ...],size(xs))) + + start = start .+ till_xs end end diff --git a/test/tracker.jl b/test/tracker.jl index d0b2375d08..01aa03de32 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -70,6 +70,8 @@ end @test gradtest(catdim, rand(5), rand(5)) @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) end + + @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From cfdb16e609438719e480d053d90b1be88c5a48aa Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:46:01 +0200 Subject: [PATCH 09/12] vcat test #213 Co-authored-by: improbable22 --- test/tracker.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/tracker.jl b/test/tracker.jl index 01aa03de32..7f84fdf9d3 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -72,6 +72,11 @@ end end @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) + + @testset "issue #213" begin + A, B, C = rand(2,2), rand(2,2), rand(2,2) + @test vcat(A, B, C |> param) == vcat(param.((A,B,C))...) + end end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From 94bb064a0f3ddc70774d54440ee0df312c941dea Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:47:30 +0200 Subject: [PATCH 10/12] more tests of array promotion for concatenation # Conflicts: # test/tracker.jl --- src/tracker/array.jl | 5 ++--- test/tracker.jl | 33 ++++++++++++++++++++------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 4650e916ec..1139a903dd 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -96,10 +96,9 @@ end for f in [:vcat, :hcat] @eval begin Base.$f(a::TrackedArray...) = track($f, a...) - Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...) - # assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector - Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...) + # assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)` + Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...) end end diff --git a/test/tracker.jl b/test/tracker.jl index 7f84fdf9d3..3185406a27 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -29,13 +29,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) -function simplepromotioncheck(f, A, B) - r0 = f(A, B) - r1 = f(param(A), B) - r2 = f(A, param(B)) - r3 = f(param(A), param(B)) +function promotiontest(f, A, B, C) + r0 = f(A, B, C) + r1 = f(param(A), B, C) + if ndims(A) <= 2 + r2 = f(A, param(B), C) + r3 = f(A, B, param(C)) + else + r2 = r3 = f(A, param(B), param(C)) + end + r4 = f(param(A), param(B), param(C)) - r1 == r2 == r3 && r0 == Flux.data(r1) + r1 == r2 == r3 == r4 && r0 == Flux.data(r4) end @testset "concat" begin @@ -51,18 +56,15 @@ end @test gradtest(vcatf, rand(5)', rand(2,5)) end - @test simplepromotioncheck(vcat, rand(5), rand(5)) - @testset for hcatf in [hcat, cat2] @test gradtest(hcatf, rand(5), rand(5)) @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) @test gradtest(hcatf, rand(5)', rand(1,3)) @test gradtest(hcatf, rand(5), rand(5,2)) end - @test simplepromotioncheck(hcat, rand(5), rand(5)) - @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) @testset "cat($dim, ...)" for dim in 1:5 @@ -73,9 +75,14 @@ end @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) - @testset "issue #213" begin - A, B, C = rand(2,2), rand(2,2), rand(2,2) - @test vcat(A, B, C |> param) == vcat(param.((A,B,C))...) + @testset "promotiontest" begin + @test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) + @test promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) + @test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) + @test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) + @testset "cat($dim, ...)" for dim in 1:5 + @test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5)) + end end end From 5fc61909563186e7db847ed67d84cf232e61c64b Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 14:57:32 +0200 Subject: [PATCH 11/12] RowVector tests --- src/tracker/array.jl | 20 +++++++++++++------ test/tracker.jl | 47 +++++++++++++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 1139a903dd..967ce8dd07 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -95,10 +95,19 @@ end for f in [:vcat, :hcat] @eval begin - Base.$f(a::TrackedArray...) = track($f, a...) + # This section is a bit of a hack since julia doesn't have a standardised promotion mechanism for concatenation yet https://github.com/JuliaLang/julia/pull/20815 - # assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)` - Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...) + # It should support tracked concatenation with rank ∈ (1,2) with a TrackedArray anywhere among the arguments + # This works as long as base has other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. + Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) + + # It should support tracked concatenation with rank>2 if the TrackedArray is first + Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) + Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row + + # It should support tracked concatenation with rank>2 if the TrackedArray is second + Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) + Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, c::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b, c...) # resolves ambiguity introduced by previous row end end @@ -124,9 +133,8 @@ function back(::typeof(hcat), Δ, xs...) end end -Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...) -Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...) -Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...) +Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) +Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) function back(::typeof(cat), Δ, dims, Xs...) start = ntuple(i -> 0, Val{ndims(Δ)}) diff --git a/test/tracker.jl b/test/tracker.jl index 3185406a27..434148f047 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -32,15 +32,19 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) function promotiontest(f, A, B, C) r0 = f(A, B, C) r1 = f(param(A), B, C) - if ndims(A) <= 2 - r2 = f(A, param(B), C) + r2 = f(A, param(B), C) + if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] r3 = f(A, B, param(C)) else - r2 = r3 = f(A, param(B), param(C)) + @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved + r3 = r2 end r4 = f(param(A), param(B), param(C)) - r1 == r2 == r3 == r4 && r0 == Flux.data(r4) + @test !isa(r0, TrackedArray) + @test all(isa.([r1,r2,r3,r4], TrackedArray)) + @test r1 == r2 == r3 == r4 + @test r0 == Flux.data(r4) end @testset "concat" begin @@ -50,6 +54,7 @@ end @testset for vcatf in [vcat, cat1] @test gradtest(vcatf, rand(5), rand(3)) @test gradtest(vcatf, rand(5), rand(3), rand(8)) + @test gradtest(vcatf, rand(5)', rand(5)') @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) @test gradtest(vcatf, rand(5), rand(3,1)) @@ -58,31 +63,49 @@ end @testset for hcatf in [hcat, cat2] @test gradtest(hcatf, rand(5), rand(5)) + @test gradtest(hcatf, rand(5)', rand(5)') @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) @test gradtest(hcatf, rand(5)', rand(1,3)) @test gradtest(hcatf, rand(5), rand(5,2)) +end + + @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + @test gradtest(catf, rand(5)) + @test gradtest(catf, rand(5)') + @test gradtest(catf, rand(2,5)) + @test gradtest(catf, rand(2,5,3)) end @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) - @testset "cat($dim, ...)" for dim in 1:5 + @testset "cat($dim, ...)" for dim in 3:5 catdim = (x...) -> cat(dim, x...) - @test gradtest(catdim, rand(5), rand(5)) + @test gradtest(catdim, rand(5), rand(5), rand(5)) @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3)) end + @test !isa(vcat(rand(2)), TrackedArray) + @test !isa(hcat(rand(2)), TrackedArray) + @test !isa(cat(1,rand(2)), TrackedArray) + @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) @testset "promotiontest" begin - @test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) - @test promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) - @test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) - @test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) - @testset "cat($dim, ...)" for dim in 1:5 - @test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5)) + @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + promotiontest(fcat, rand(2), rand(2), rand(2)) + promotiontest(fcat, rand(2)', rand(2)', rand(2)') + promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2)) + promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2)) end + + promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) + promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) + promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) + promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) + promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2)) end end From ef9077d9fabc7a972fddd6afcbcd454cf2efae79 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Mon, 7 May 2018 13:03:52 +0100 Subject: [PATCH 12/12] style --- src/tracker/array.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 967ce8dd07..e11296abf4 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -95,19 +95,26 @@ end for f in [:vcat, :hcat] @eval begin - # This section is a bit of a hack since julia doesn't have a standardised promotion mechanism for concatenation yet https://github.com/JuliaLang/julia/pull/20815 + # This section is a bit of a hack since julia doesn't have a standardised + # promotion mechanism for concatenation yet + # https://github.com/JuliaLang/julia/pull/20815 - # It should support tracked concatenation with rank ∈ (1,2) with a TrackedArray anywhere among the arguments - # This works as long as base has other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. + # It should support tracked concatenation with rank ∈ (1,2) with a + # TrackedArray anywhere among the arguments This works as long as base has + # other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) - # It should support tracked concatenation with rank>2 if the TrackedArray is first + # It should support tracked concatenation with rank>2 if the TrackedArray is + # first Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row - # It should support tracked concatenation with rank>2 if the TrackedArray is second + # It should support tracked concatenation with rank>2 if the TrackedArray is + # second Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) - Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, c::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b, c...) # resolves ambiguity introduced by previous row + Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, + c::Union{TrackedArray,Vector,RowVector,Matrix}...) = + track($f, a, b, c...) # resolves ambiguity introduced by previous row end end