diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 35261abec2..48228b3f45 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -81,20 +81,22 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b) -Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...) -Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b) -Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b) - -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...) -Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b) -Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) - -Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) -Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) -Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) +# For vcat and hcat, we have quite a few methods that need to be defined for +# all the different possible combinations of types. We generate those here, +# with a correspondence between tracked types and vanilla types: +for f in [:vcat, :hcat] + for (tt, at) in [(TrackedVector, AbstractVector), + (TrackedVecOrMat, AbstractVecOrMat), + (TrackedMatrix, AbstractMatrix), + (TrackedArray, AbstractArray)] + @eval begin + import Base.$f + Base.$f(a::$tt, b::$at...) = track($f, a, b...) + Base.$f(a::$at, b::$tt...) = track($f, a, b...) + Base.$f(a::$tt, b::$tt...) = track($f, a, b...) + end + end +end function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) @@ -117,6 +119,20 @@ function back(::typeof(vcat), Δ, xs...) end end +function back(::typeof(hcat), Δ, xs...) + start = 0 + for xsi in xs + slice_idxs = Any[Colon() for dim in 1:ndims(Δ)] + slice_idxs[2] = start+1:start+size(xsi,2) + Δ_slice = Δ[slice_idxs...] + if size(xsi,2) == 1 && ndims(Δ_slice) > 1 + Δ_slice = squeeze(Δ_slice, 2) + end + @back(xsi, Δ_slice) + start += size(xsi, 2) + end +end + Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = track(reshape, xs, dims...) diff --git a/test/tracker.jl b/test/tracker.jl index dee5c59417..0f9bbf409d 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -32,14 +32,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3)) +@test gradtest(hcat, rand(5), rand(5), rand(5,2)) +@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5)) +@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3)) @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) -@test gradtest(kron,rand(5), rand(3)) +@test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron,rand(5,1), rand(3,1)) +@test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))