Skip to content

Commit

Permalink
Add hcat() for tracked arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
staticfloat committed Mar 15, 2018
1 parent e931552 commit fa2717f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
44 changes: 30 additions & 14 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,22 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)

Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)

Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)

Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
# For vcat and hcat, we have quite a few methods that need to be defined for
# all the different possible combinations of types. We generate those here,
# with a correspondence between tracked types and vanilla types:
for f in [:vcat, :hcat]
for (tt, at) in [(TrackedVector, AbstractVector),
(TrackedVecOrMat, AbstractVecOrMat),
(TrackedMatrix, AbstractMatrix),
(TrackedArray, AbstractArray)]
@eval begin
import Base.$f
Base.$f(a::$tt, b::$at...) = track($f, a, b...)
Base.$f(a::$at, b::$tt...) = track($f, a, b...)
Base.$f(a::$tt, b::$tt...) = track($f, a, b...)
end
end
end

function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
Expand All @@ -117,6 +119,20 @@ function back(::typeof(vcat), Δ, xs...)
end
end

function back(::typeof(hcat), Δ, xs...)
start = 0
for xsi in xs
slice_idxs = Any[Colon() for dim in 1:ndims(Δ)]
slice_idxs[2] = start+1:start+size(xsi,2)
Δ_slice = Δ[slice_idxs...]
if size(xsi,2) == 1 && ndims(Δ_slice) > 1
Δ_slice = squeeze(Δ_slice, 2)
end
@back(xsi, Δ_slice)
start += size(xsi, 2)
end
end

Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
track(reshape, xs, dims...)

Expand Down
8 changes: 6 additions & 2 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(5), rand(3), rand(8))
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(hcat, rand(5), rand(5), rand(5,2))
@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5))
@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))

@test gradtest(kron,rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron,rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

Expand Down

0 comments on commit fa2717f

Please sign in to comment.