Skip to content

Commit

Permalink
Fix implementation of Tables row interface (#279)
Browse files Browse the repository at this point in the history
* Create valid rows iterator

* Test that matrices via rows and cols are the same

* Remove now-unnecessary optional overloads

* Use egal for maybe better const prop

* Better denote sections with comments

* Increment version number

* Organize tests into sections

* Add row interface tests
  • Loading branch information
sethaxen authored Mar 16, 2021
1 parent 756b7ff commit d60fb10
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 103 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "4.7.1"
version = "4.7.2"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
47 changes: 27 additions & 20 deletions src/tables.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Tables and TableTraits interface

## Chains
####
#### Chains
####

function _check_columnnames(chn::Chains)
for name in names(chn)
Expand All @@ -11,8 +13,12 @@ function _check_columnnames(chn::Chains)
end
end

#### Tables interface

Tables.istable(::Type{<:Chains}) = true

# AbstractColumns interface

Tables.columnaccess(::Type{<:Chains}) = true

function Tables.columns(chn::Chains)
Expand All @@ -26,11 +32,11 @@ function Tables.getcolumn(chn::Chains, i::Int)
return Tables.getcolumn(chn, Tables.columnnames(chn)[i])
end
function Tables.getcolumn(chn::Chains, nm::Symbol)
if nm == :iteration
if nm === :iteration
iterations = range(chn)
nchains = size(chn, 3)
return repeat(iterations, nchains)
elseif nm == :chain
elseif nm === :chain
chainids = chains(chn)
niter = size(chn, 1)
return repeat(chainids; inner = niter)
Expand All @@ -39,18 +45,13 @@ function Tables.getcolumn(chn::Chains, nm::Symbol)
end
end

Tables.rowaccess(::Type{<:Chains}) = true
# row access

function Tables.rows(chn::Chains)
_check_columnnames(chn)
return chn
end
Tables.rowaccess(::Type{<:Chains}) = true

Tables.rowtable(chn::Chains) = Tables.rowtable(Tables.columntable(chn))
Tables.rows(chn::Chains) = Tables.rows(Tables.columntable(chn))

function Tables.namedtupleiterator(chn::Chains)
return Tables.namedtupleiterator(Tables.columntable(chn))
end
# optional Tables overloads

function Tables.schema(chn::Chains)
_check_columnnames(chn)
Expand All @@ -60,17 +61,25 @@ function Tables.schema(chn::Chains)
return Tables.Schema(nms, types)
end

#### TableTraits interface

IteratorInterfaceExtensions.isiterable(::Chains) = true
function IteratorInterfaceExtensions.getiterator(chn::Chains)
return Tables.datavaluerows(Tables.columntable(chn))
end

TableTraits.isiterabletable(::Chains) = true

## ChainDataFrame
####
#### ChainDataFrame
####

#### Tables interface

Tables.istable(::Type{<:ChainDataFrame}) = true

# AbstractColumns interface

Tables.columnaccess(::Type{<:ChainDataFrame}) = true

Tables.columns(cdf::ChainDataFrame) = cdf
Expand All @@ -80,21 +89,19 @@ Tables.columnnames(::ChainDataFrame{<:NamedTuple{names}}) where {names} = names
Tables.getcolumn(cdf::ChainDataFrame, i::Int) = cdf.nt[i]
Tables.getcolumn(cdf::ChainDataFrame, nm::Symbol) = cdf.nt[nm]

Tables.rowaccess(::Type{<:ChainDataFrame}) = true

Tables.rows(cdf::ChainDataFrame) = cdf
# row access

Tables.rowtable(cdf::ChainDataFrame) = Tables.rowtable(Tables.columntable(cdf))
Tables.rowaccess(::Type{<:ChainDataFrame}) = true

function Tables.namedtupleiterator(cdf::ChainDataFrame)
return Tables.namedtupleiterator(Tables.columntable(cdf))
end
Tables.rows(cdf::ChainDataFrame) = Tables.rows(Tables.columntable(cdf))

function Tables.schema(::ChainDataFrame{NamedTuple{names,T}}) where {names,T}
types = ntuple(i -> eltype(fieldtype(T, i)), fieldcount(T))
return Tables.Schema(names, types)
end

#### TableTraits interface

IteratorInterfaceExtensions.isiterable(::ChainDataFrame) = true
function IteratorInterfaceExtensions.getiterator(cdf::ChainDataFrame)
return Tables.datavaluerows(Tables.columntable(cdf))
Expand Down
221 changes: 139 additions & 82 deletions test/tables_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,63 +13,95 @@ using DataFrames

@testset "Tables interface" begin
@test Tables.istable(typeof(chn))
@test Tables.columnaccess(typeof(chn))
@test Tables.columns(chn) === chn
@test Tables.columnnames(chn) ==
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
@test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000]
@test Tables.getcolumn(chn, :chain) ==
[fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)]
@test Tables.getcolumn(chn, :a) == [
vec(chn[:, :a, 1])
vec(chn[:, :a, 2])
vec(chn[:, :a, 3])
vec(chn[:, :a, 4])
]
@test_throws Exception Tables.getcolumn(chn, :j)
@test Tables.getcolumn(chn, 1) == Tables.getcolumn(chn, :iteration)
@test Tables.getcolumn(chn, 2) == Tables.getcolumn(chn, :chain)
@test Tables.getcolumn(chn, 3) == Tables.getcolumn(chn, :a)
@test_throws Exception Tables.getcolumn(chn, :i)
@test_throws Exception Tables.getcolumn(chn, 11)
@test Tables.rowaccess(typeof(chn))
@test Tables.rows(chn) === chn
@test length(Tables.rowtable(chn)) == 4000
nt = Tables.rowtable(chn)[1]
@test nt ==
(; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1]
nt = Tables.rowtable(chn)[2]
@test nt ==
(; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2]
@test Tables.schema(chn) isa Tables.Schema
@test Tables.schema(chn).names ===
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
@test Tables.schema(chn).types === (
Int,
Int,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
)
@test Tables.matrix(chn[:, :, 1])[:, 3:end] chn[:, :, 1].value
@test Tables.matrix(chn[:, :, 2])[:, 3:end] chn[:, :, 2].value

val = rand(1000, 2, 4)
chn2 = Chains(val, ["iteration", "a"])
@test_throws Exception Tables.columns(chn2)
@test_throws Exception Tables.rows(chn2)
@test_throws Exception Tables.schema(chn2)
chn3 = Chains(val, ["chain", "a"])
@test_throws Exception Tables.columns(chn3)
@test_throws Exception Tables.rows(chn3)
@test_throws Exception Tables.schema(chn3)

@testset "column access" begin
@test Tables.columnaccess(typeof(chn))
@test Tables.columns(chn) === chn
@test Tables.columnnames(chn) ==
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
@test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000]
@test Tables.getcolumn(chn, :chain) ==
[fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)]
@test Tables.getcolumn(chn, :a) == [
vec(chn[:, :a, 1])
vec(chn[:, :a, 2])
vec(chn[:, :a, 3])
vec(chn[:, :a, 4])
]
@test_throws Exception Tables.getcolumn(chn, :j)
@test Tables.getcolumn(chn, 1) == Tables.getcolumn(chn, :iteration)
@test Tables.getcolumn(chn, 2) == Tables.getcolumn(chn, :chain)
@test Tables.getcolumn(chn, 3) == Tables.getcolumn(chn, :a)
@test_throws Exception Tables.getcolumn(chn, :i)
@test_throws Exception Tables.getcolumn(chn, 11)
end

@testset "row access" begin
@test Tables.rowaccess(typeof(chn))
@test Tables.rows(chn) isa Tables.RowIterator
@test eltype(Tables.rows(chn)) <: Tables.AbstractRow
rows = collect(Tables.rows(chn))
@test eltype(rows) <: Tables.AbstractRow
@test size(rows) === (4000,)
for chainid in 1:4, iterid in 1:1000
row = rows[(chainid - 1) * 1000 + iterid]
@test Tables.columnnames(row) ==
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
@test Tables.getcolumn(row, 1) == iterid
@test Tables.getcolumn(row, 2) == chainid
@test Tables.getcolumn(row, 3) == chn[iterid, :a, chainid]
@test Tables.getcolumn(row, 10) == chn[iterid, :h, chainid]
@test Tables.getcolumn(row, :iteration) == iterid
@test Tables.getcolumn(row, :chain) == chainid
@test Tables.getcolumn(row, :a) == chn[iterid, :a, chainid]
@test Tables.getcolumn(row, :h) == chn[iterid, :h, chainid]
end
end

@testset "integration tests" begin
@test length(Tables.rowtable(chn)) == 4000
nt = Tables.rowtable(chn)[1]
@test nt ==
(; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1]
nt = Tables.rowtable(chn)[2]
@test nt ==
(; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2]
@test Tables.matrix(chn[:, :, 1])[:, 3:end] chn[:, :, 1].value
@test Tables.matrix(chn[:, :, 2])[:, 3:end] chn[:, :, 2].value
@test Tables.matrix(Tables.rowtable(chn)) == Tables.matrix(Tables.columntable(chn))
end

@testset "schema" begin
@test Tables.schema(chn) isa Tables.Schema
@test Tables.schema(chn).names ===
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
@test Tables.schema(chn).types === (
Int,
Int,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
Float64,
)
end

@testset "exceptions raised if reserved colname used" begin
val2 = rand(1000, 2, 4)
chn2 = Chains(val2, ["iteration", "a"])
@test_throws Exception Tables.columns(chn2)
@test_throws Exception Tables.rows(chn2)
@test_throws Exception Tables.schema(chn2)
chn3 = Chains(val2, ["chain", "a"])
@test_throws Exception Tables.columns(chn3)
@test_throws Exception Tables.rows(chn3)
@test_throws Exception Tables.schema(chn3)
end
end

@testset "TableTraits interface" begin
Expand All @@ -82,10 +114,10 @@ using DataFrames
@test nt ==
(; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...)

val = rand(1000, 2, 4)
chn2 = Chains(val, ["iteration", "a"])
val2 = rand(1000, 2, 4)
chn2 = Chains(val2, ["iteration", "a"])
@test_throws Exception IteratorInterfaceExtensions.getiterator(chn2)
chn3 = Chains(val, ["chain", "a"])
chn3 = Chains(val2, ["chain", "a"])
@test_throws Exception IteratorInterfaceExtensions.getiterator(chn3)
end

Expand All @@ -106,29 +138,54 @@ using DataFrames

@testset "Tables interface" begin
@test Tables.istable(typeof(cdf))
@test Tables.columnaccess(typeof(cdf))
@test Tables.columns(cdf) === cdf
@test Tables.columnnames(cdf) == keys(cdf.nt)
for (k, v) in pairs(cdf.nt)
@test Tables.getcolumn(cdf, k) == v

@testset "column access" begin
@test Tables.columnaccess(typeof(cdf))
@test Tables.columns(cdf) === cdf
@test Tables.columnnames(cdf) == keys(cdf.nt)
for (k, v) in pairs(cdf.nt)
@test Tables.getcolumn(cdf, k) == v
end
@test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1])
@test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2])
@test_throws Exception Tables.getcolumn(cdf, :blah)
@test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1)
end

@testset "row access" begin
@test Tables.rowaccess(typeof(cdf))
@test Tables.rows(cdf) isa Tables.RowIterator
@test eltype(Tables.rows(cdf)) <: Tables.AbstractRow
rows = collect(Tables.rows(cdf))
@test eltype(rows) <: Tables.AbstractRow
@test size(rows) === (2,)
@testset for i in 1:2
row = rows[i]
@test Tables.columnnames(row) == keys(cdf.nt)
for j in length(cdf.nt)
@test Tables.getcolumn(row, j) == cdf.nt[j][i]
@test Tables.getcolumn(row, keys(cdf.nt)[j]) == cdf.nt[j][i]
end
end
end

@testset "integration tests" begin
@test length(Tables.rowtable(cdf)) == length(cdf.nt[1])
@test Tables.columntable(cdf) == cdf.nt
nt = Tables.rowtable(cdf)[1]
@test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]
nt = Tables.rowtable(cdf)[2]
@test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]
@test Tables.matrix(Tables.rowtable(cdf)) == Tables.matrix(Tables.columntable(cdf))
end

@testset "schema" begin
@test Tables.schema(cdf) isa Tables.Schema
@test Tables.schema(cdf).names === keys(cdf.nt)
@test Tables.schema(cdf).types === eltype.(values(cdf.nt))
end
@test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1])
@test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2])
@test_throws Exception Tables.getcolumn(cdf, :blah)
@test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1)
@test Tables.rowaccess(typeof(cdf))
@test Tables.rows(cdf) === cdf
@test length(Tables.rowtable(cdf)) == length(cdf.nt[1])
@test Tables.columntable(cdf) == cdf.nt
nt = Tables.rowtable(cdf)[1]
@test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]
nt = Tables.rowtable(cdf)[2]
@test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...)
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]
@test Tables.schema(cdf) isa Tables.Schema
@test Tables.schema(cdf).names === keys(cdf.nt)
@test Tables.schema(cdf).types === eltype.(values(cdf.nt))
end

@testset "TableTraits interface" begin
Expand Down

4 comments on commit d60fb10

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Register Failed
@sethaxen, it looks like you are not a publicly listed member/owner in the parent organization (TuringLang).
If you are a member/owner, you will need to change your membership to public. See GitHub Help

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32044

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.7.2 -m "<description of version>" d60fb100f5d74fc35b1b5feb81492da467e61765
git push origin v4.7.2

Please sign in to comment.