diff --git a/src/array_interface.jl b/src/array_interface.jl index 8359b5fa..3b7c0342 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -20,7 +20,7 @@ function Base.cat(inputs::ComponentArray...; dims::Int) combined_data = cat(getdata.(inputs)...; dims=dims) axes_to_merge = [(getaxes(i)..., FlatAxis())[dims] for i in inputs] rest_axes = [getaxes(i)[1:end .!= dims] for i in inputs] - no_duplicate_keys = (length(inputs) == 1 || isempty(intersect(keys.(axes_to_merge)...))) + no_duplicate_keys = (length(inputs) == 1 || allunique(vcat(collect.(keys.(axes_to_merge))...))) if no_duplicate_keys && length(Set(rest_axes)) == 1 offsets = (0, cumsum(size.(inputs, dims))[1:(end - 1)]...) merged_axis = Axis(merge(indexmap.(reindex.(axes_to_merge, offsets))...)) diff --git a/test/runtests.jl b/test/runtests.jl index 8560eae3..729fdc57 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -482,6 +482,7 @@ end @test ldiv!(tempmat, lu(cmat + I), cmat) isa ComponentMatrix @test ldiv!(getdata(tempmat), lu(cmat + I), cmat) isa AbstractMatrix + @test !(vcat(ca, ca2, ca) isa ComponentVector) for n in 1:3 # Issue 168 cats (on more than one) ComponentArrays vca2 = vcat(repeat([ca2'], n)...) hca2 = hcat(repeat([ca2], n)...)