From 6fb2a3f2eb511f46cc92936cd0472aa67fdc9f6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Wed, 11 Sep 2024 13:46:20 +0200 Subject: [PATCH] avoid type piracy in reduce with vcat --- src/abstractdataframe/iteration.jl | 17 ++++++++++++++--- test/cat.jl | 15 +++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/abstractdataframe/iteration.jl b/src/abstractdataframe/iteration.jl index 07c6c75fbb..7954d011ea 100644 --- a/src/abstractdataframe/iteration.jl +++ b/src/abstractdataframe/iteration.jl @@ -598,6 +598,12 @@ function mapcols!(f::Union{Function,Type}, df::DataFrame; cols=All()) return df end +############################################################################## +## +## Reduction +## +############################################################################## + """ reduce(::typeof(vcat), dfs::Union{AbstractVector{<:AbstractDataFrame}, @@ -686,7 +692,10 @@ julia> reduce(vcat, [df1, df2, df3], cols=:union, source=:source) ``` """ function Base.reduce(::typeof(vcat), - dfs::Union{AbstractVector{<:AbstractDataFrame}, + dfs::Union{AbstractVector{AbstractDataFrame}, + AbstractVector{DataFrame}, + AbstractVector{SubDataFrame}, + AbstractVector{Union{DataFrame,SubDataFrame}}, Tuple{AbstractDataFrame,Vararg{AbstractDataFrame}}}; cols::Union{Symbol,AbstractVector{Symbol}, AbstractVector{<:AbstractString}}=:setequal, @@ -741,8 +750,10 @@ end # definition needed to avoid dispatch ambiguity Base.reduce(::typeof(vcat), - dfs::SentinelArrays.ChainedVector{T,A} where {T<:AbstractDataFrame, - A<:AbstractVector{T}}; + dfs::Union{SentinelArrays.ChainedVector{AbstractDataFrame,<:AbstractVector{AbstractDataFrame}}, + SentinelArrays.ChainedVector{DataFrame,<:AbstractVector{DataFrame}}, + SentinelArrays.ChainedVector{SubDataFrame,<:AbstractVector{SubDataFrame}}, + SentinelArrays.ChainedVector{Union{DataFrame,SubDataFrame},<:AbstractVector{Union{DataFrame,SubDataFrame}}}}; cols::Union{Symbol,AbstractVector{Symbol}, AbstractVector{<:AbstractString}}=:setequal, source::Union{Nothing,SymbolOrString, diff --git a/test/cat.jl b/test/cat.jl index b5aa1cfd9b..a63dcc2ccb 100644 --- a/test/cat.jl +++ b/test/cat.jl @@ -477,4 +477,19 @@ end @test reduce(vcat, (df1, df2)) == DataFrame(a=[1, 1], b=[2, 2]) end +@testset "vcat type piracy" begin + x = Int[] + @test reduce(vcat, Union{}[], init=x) === x + + @test reduce(vcat, AbstractDataFrame[DataFrame(a=1), DataFrame(a=2)]) == + DataFrame(a=[1, 2]) + @test reduce(vcat, Union{DataFrame, SubDataFrame}[DataFrame(a=1), DataFrame(a=2)]) == + DataFrame(a=[1, 2]) + @test reduce(vcat, AbstractDataFrame[DataFrame(a=1), DataFrame(a=2)]; source=:source) == + DataFrame(a=[1, 2], source=[1, 2]) + @test reduce(vcat, Union{DataFrame,SubDataFrame}[DataFrame(a=1), DataFrame(a=2)]; source=:source) == + DataFrame(a=[1, 2], source=[1, 2]) +end + + end # module