Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add predicate support for names and more tests #2417

Merged
merged 12 commits into from
Jan 31, 2021
17 changes: 12 additions & 5 deletions src/abstractdataframe/abstractdataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,23 @@ abstract type AbstractDataFrame end
"""
names(df::AbstractDataFrame)
names(df::AbstractDataFrame, cols)
names(df::AbstractDataFrame, T::Type, unionmissing::Bool=true)

Return a freshly allocated `Vector{String}` of names of columns contained in `df`.

If `cols` is passed then restrict returned column names to those matching the
selector (this is useful in particular with regular expressions, `Not`, and `Between`).
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR)
or a `Type`, in which case columns whose `eltype` is a subtype of `cols` are returned.
`cols` can be:
* any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR)
* a `Type`, in which case names of columns whose `eltype` is a subtype of `T`
are returned if `unionmissing=false` or if `unionmissing=true` then `eltype`
must be subtype of `Union{T, Missing}` other than `Missing` except if `T` is `Missing`
* a `Function` predicate, in which case names of columns for which the predicate, taking a
`String` containg column name, returns `true`
bkamins marked this conversation as resolved.
Show resolved Hide resolved

See also [`propertynames`](@ref) which returns a `Vector{Symbol}`.
"""
Base.names(df::AbstractDataFrame) = names(index(df))
Base.names(df::AbstractDataFrame, cols::Colon=:) = names(index(df))

function Base.names(df::AbstractDataFrame, cols)
nms = _names(index(df))
Expand All @@ -82,8 +88,9 @@ function Base.names(df::AbstractDataFrame, cols)
return [String(nms[i]) for i in idxs]
end

Base.names(df::AbstractDataFrame, T::Type) =
[String(n) for (n, c) in pairs(eachcol(df)) if eltype(c) <: T]
Base.names(df::AbstractDataFrame, T::Type; unionmissing::Bool=true) =
[String(n) for (n, c) in pairs(eachcol(df)) if testtype(T, eltype(c), unionmissing)]
Base.names(df::AbstractDataFrame, fun::Function) = filter(fun, names(df))
bkamins marked this conversation as resolved.
Show resolved Hide resolved

# _names returns Vector{Symbol} without copying
_names(df::AbstractDataFrame) = _names(index(df))
Expand Down
2 changes: 2 additions & 0 deletions src/abstractdataframe/iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ Base.findall(f::Function, itr::DataFrameColumns) =
Base.parent(itr::Union{DataFrameRows, DataFrameColumns}) = getfield(itr, :df)
Base.names(itr::Union{DataFrameRows, DataFrameColumns}) = names(parent(itr))
Base.names(itr::Union{DataFrameRows, DataFrameColumns}, cols) = names(parent(itr), cols)
Base.names(itr::Union{DataFrameRows, DataFrameColumns}, cols::Type; unionmissing::Bool=true) =
names(parent(itr), cols, unionmissing=unionmissing)

function Base.show(io::IO, dfrs::DataFrameRows;
allrows::Bool = !get(io, :limit, false),
Expand Down
6 changes: 5 additions & 1 deletion src/dataframerow/dataframerow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ Base.@propagate_inbounds Base.setindex!(r::DataFrameRow, value, idx) =

index(r::DataFrameRow) = getfield(r, :colindex)

Base.names(r::DataFrameRow) = names(index(r))
Base.names(r::DataFrameRow, cols::Colon=:) = names(index(r))

function Base.names(r::DataFrameRow, cols)
nms = _names(index(r))
Expand All @@ -272,6 +272,10 @@ function Base.names(r::DataFrameRow, cols)
return [string(nms[i]) for i in idxs]
end

Base.names(r::DataFrameRow, T::Type; unionmissing::Bool=true) =
[String(n) for n in _names(r) if testtype(T, eltype(parent(r)[!, n]), unionmissing)]
Base.names(r::DataFrameRow, fun::Function) = filter(fun, names(r))
bkamins marked this conversation as resolved.
Show resolved Hide resolved
bkamins marked this conversation as resolved.
Show resolved Hide resolved

_names(r::DataFrameRow) = view(_names(parent(r)), parentcols(index(r), :))

Base.haskey(r::DataFrameRow, key::Bool) =
Expand Down
20 changes: 11 additions & 9 deletions src/groupeddataframe/groupeddataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ function Base.isequal(gd1::GroupedDataFrame, gd2::GroupedDataFrame)
all(x -> isequal(x...), zip(gd1, gd2))
end

Base.names(gd::GroupedDataFrame) = names(gd.parent)
Base.names(gd::GroupedDataFrame, cols) = names(gd.parent, cols)
_names(gd::GroupedDataFrame) = _names(gd.parent)
Base.names(gd::GroupedDataFrame) = names(parent(gd))
Base.names(gd::GroupedDataFrame, cols) = names(parent(gd), cols)
Base.names(gd::GroupedDataFrame, cols::Type; unionmissing::Bool=true) =
names(parent(gd), cols, unionmissing=unionmissing)
_names(gd::GroupedDataFrame) = _names(parent(gd))

function DataFrame(gd::GroupedDataFrame; copycols::Bool=true, keepkeys::Bool=true)
if !copycols
Expand Down Expand Up @@ -254,7 +256,7 @@ meant to be constructed directly.

Indexing fields of `GroupKey` is allowed using an integer, a `Symbol`, or a string.
It is also possible to access the data in a `GroupKey` using the `getproperty`
function. A `GroupKey` can be converted to a `Tuple`, `NamedTuple`, a `Vector`, or
function. A `GroupKey` can be converted to a `Tuple`, `NamedTuple`, a `Vector`, or
a `Dict`. When converted to a `Dict`, the keys of the `Dict` are `Symbol`s.

See [`keys(::GroupedDataFrame)`](@ref) for more information.
Expand Down Expand Up @@ -402,12 +404,12 @@ end
function _dict_to_tuple(key::AbstractDict{Symbol}, gd::GroupedDataFrame)
if length(key) != length(gd.cols)
throw(KeyError(key))
end
end

return ntuple(i -> key[gd.cols[i]], length(gd.cols))
end

Base.to_index(gd::GroupedDataFrame, key::Union{AbstractDict{Symbol},AbstractDict{<:AbstractString}}) =
Base.to_index(gd::GroupedDataFrame, key::Union{AbstractDict{Symbol},AbstractDict{<:AbstractString}}) =
Base.to_index(gd, _dict_to_tuple(key, gd))

# Array of (possibly non-standard) indices
Expand Down Expand Up @@ -599,7 +601,7 @@ function Base.haskey(gd::GroupedDataFrame, key::NamedTuple{N}) where {N}
return haskey(gd, Tuple(key))
end

Base.haskey(gd::GroupedDataFrame, key::AbstractDict{<:Union{Symbol, <:AbstractString}}) =
Base.haskey(gd::GroupedDataFrame, key::AbstractDict{<:Union{Symbol, <:AbstractString}}) =
haskey(gd, _dict_to_tuple(key, gd))

Base.haskey(gd::GroupedDataFrame, key::Union{Signed,Unsigned}) =
Expand All @@ -611,8 +613,8 @@ Base.haskey(gd::GroupedDataFrame, key::Union{Signed,Unsigned}) =
Get a group based on the values of the grouping columns.

`key` may be a `GroupKey`, `NamedTuple` or `Tuple` of grouping column values (in the same
order as the `cols` argument to `groupby`). It may also be an `AbstractDict`, in which case the
order of the arguments does not matter.
order as the `cols` argument to `groupby`). It may also be an `AbstractDict`, in which case the
order of the arguments does not matter.

# Examples

Expand Down
12 changes: 12 additions & 0 deletions src/other/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,15 @@ else
end

funname(c::ComposedFunction) = Symbol(funname(c.f), :_, funname(c.g))

testtype(t::Type, ct::Type, unionmissing::Bool) =
bkamins marked this conversation as resolved.
Show resolved Hide resolved
if t == Missing
bkamins marked this conversation as resolved.
Show resolved Hide resolved
return ct === Missing
else
if unionmissing
ct === Missing && return t === Union{} || Missing <: t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Member Author

@bkamins bkamins Nov 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is good we have left some time to sleep over this PR.

The core question is (do not look at the implementation as it is just a consequence). If eltype of the column is Missing and user passes unionmissing equal to true what should happen then? The implementation was doing special casing of this case. Now I think if CT === Missing and unionmissing===true then we should always return true. Do you agree? (if yes - the above implementation will change)

EDIT: Now that I have read my comment above I am again not so sure what is best ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The simplest and most obvious approach would be to check ct <: Union{t, Missing} when unionmissing=true. Otherwise things get too complex IMHO.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only problem is that if someone writes names(df, Float64) and has a column with eltype equal to Missing then one would get that column although it does not allow Float64 values at all.

return ct <: Union{t, Missing}
else
return ct <: t
end
end
54 changes: 44 additions & 10 deletions test/dataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1881,16 +1881,50 @@ end
@test_throws ArgumentError push!(df, "a")
end

@testset "names for Type" begin
df = DataFrame(a1 = 1:3, a2 = [1, missing, 3],
b1 = 1.0:3.0, b2 = [1.0, missing, 3.0],
c1 = '1':'3', c2 = ['1', missing, '3'])
@test names(df, Int) == ["a1"]
@test names(df, Union{Missing, Int}) == ["a1", "a2"]
@test names(df, Real) == ["a1", "b1"]
@test names(df, Union{Missing, Real}) == ["a1", "a2", "b1", "b2"]
@test names(df, Any) == names(df)
@test names(df, Union{Char, Float64, Missing}) == ["b1", "b2", "c1", "c2"]
@testset "names for Type, predicate + standard tests of cols" begin
@test DataFrames.testtype(Int, Int, true)
@test DataFrames.testtype(Int, Int, false)
@test DataFrames.testtype(Real, Int, true)
@test DataFrames.testtype(Real, Int, false)
@test !DataFrames.testtype(Union{}, Missing, false)
@test DataFrames.testtype(Union{}, Missing, true)
@test DataFrames.testtype(Int, Union{Missing, Int}, true)
@test !DataFrames.testtype(Int, Union{Missing, Int}, false)
@test DataFrames.testtype(Real, Union{Missing, Int}, true)
@test !DataFrames.testtype(Real, Union{Missing, Int}, false)
@test DataFrames.testtype(Missing, Missing, false)
@test DataFrames.testtype(Missing, Missing, true)
@test !DataFrames.testtype(Int, Missing, false)
@test !DataFrames.testtype(Int, Missing, true)
@test DataFrames.testtype(Union{Missing, Int}, Missing, false)
@test DataFrames.testtype(Union{Missing, Int}, Missing, true)

df_long = DataFrame(a1 = 1:3, a2 = [1, missing, 3],
b1 = 1.0:3.0, b2 = [1.0, missing, 3.0],
c1 = '1':'3', c2 = ['1', missing, '3'], x=1:3)
for x in (df_long[:, Not(end)], @view(df_long[:, Not(end)]),
groupby(df_long[:, Not(end)], :a1), groupby(@view(df_long[:, Not(end)]), :a1),
eachrow(df_long[:, Not(end)]), eachrow(@view(df_long[:, Not(end)])),
eachcol(df_long[:, Not(end)]), eachcol(@view(df_long[:, Not(end)])),
df_long[1, Not(end)])
@test names(x, 1) == ["a1"]
@test names(x, "a1") == ["a1"]
@test names(x, :a1) == ["a1"]
bkamins marked this conversation as resolved.
Show resolved Hide resolved
@test names(x, [2, 1]) == ["a2", "a1"]
@test names(x, ["a2", "a1"]) == ["a2", "a1"]
@test names(x, [:a2, :a1]) == ["a2", "a1"]
@test names(x, Int) == ["a1", "a2"]
@test names(x, Int, unionmissing=false) == ["a1"]
@test names(x, Union{Missing, Int}) == ["a1", "a2"]
bkamins marked this conversation as resolved.
Show resolved Hide resolved
@test names(x, Real) == ["a1", "a2", "b1", "b2"]
@test names(x, Real, unionmissing=false) == ["a1", "b1"]
@test names(x, Union{Missing, Real}) == ["a1", "a2", "b1", "b2"]
@test names(x, Any) == names(x)
@test names(x, Union{Char, Float64, Missing}) == ["b1", "b2", "c1", "c2"]
@test names(x, startswith("a")) == ["a1", "a2"]
@test names(x, :) == names(x)
@test names(x, <("a2")) == ["a1"]
end
end

end # module