diff --git a/test/tracker.jl b/test/tracker.jl index 2b0e04d7bd..f39546eaa8 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -33,6 +33,10 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(vcat, rand(5,2,3), rand(3,2,3), rand(8,2,3)) +@test gradtest(hcat, rand(5), rand(5), rand(5,2)) +@test gradtest(hcat, rand(5,2), rand(5,3), rand(5,5)) +@test gradtest(hcat, rand(5,2,3), rand(5,3,3), rand(5,5,3)) @test gradtest((i...) -> cat(1,i...), rand(5), rand(3)) @test gradtest((i...) -> cat(1,i...), rand(5), rand(8)) @test gradtest((i...) -> cat(1,i...), rand(5,2),rand(3,2), rand(8,2)) @@ -45,9 +49,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) -@test gradtest(kron,rand(5), rand(3)) +@test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron,rand(5,1), rand(3,1)) +@test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))