DataLoader with NamedTuple #1221

Merged 6 commits on Jun 16, 2020
NEWS.md: 3 changes (2 additions, 1 deletion)

@@ -1,5 +1,6 @@
# v0.11
-* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
+* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
+* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].

# v0.10.5
src/data/dataloader.jl: 24 changes (14 additions, 10 deletions)

@@ -16,8 +16,8 @@ end
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).

-Takes as input a data tensors or a tuple of one or more such tensors.
-The last dimension in each tensor is considered to be the observation dimension.
+Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
+The last dimension in each tensor is considered to be the observation dimension.

If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
@@ -57,6 +57,13 @@ Usage example:
# train for 10 epochs
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)

# can use NamedTuple to name tensors
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
for datum in train_loader
@assert size(datum.images) == (10, 2)
@assert size(datum.labels) == (2,)
end
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
@@ -88,19 +95,16 @@ end

_nobs(data::AbstractArray) = size(data)[end]

-function _nobs(data::Tuple)
+function _nobs(data::Union{Tuple, NamedTuple})
    length(data) > 0 || throw(ArgumentError("Need at least one data input"))
    n = _nobs(data[1])
-    if !all(x -> _nobs(x) == n, data[2:end])
+    if !all(x -> _nobs(x) == n, Base.tail(data))
        throw(DimensionMismatch("All data should contain same number of observations"))
    end
    return n
end
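For illustration, a hedged sketch of what the `_nobs` check above buys at construction time, assuming `DataLoader` is in scope via `using Flux.Data` (the module path in this version of Flux):

    using Flux.Data: DataLoader

    X = rand(10, 4)   # 4 observations along the last dimension
    Y = rand(3)       # 3 observations, inconsistent with X
    try
        DataLoader((x=X, y=Y), batchsize=2)
    catch err
        @assert err isa DimensionMismatch  # thrown by the _nobs check above
    end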

-function _getobs(data::AbstractArray{T,N}, i) where {T,N}
-    getindex(data, ntuple(i->Colon(), N-1)..., i)
-end
-
-_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
+_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
Member: Use the N from the type and drop the Val?

Contributor (author): Thanks for the suggestion, but why is that better?

Member: It's largely for it to be cleaner; doing it like this doesn't seem to add any benefit, and it increases the complexity of the code when reading it.
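For concreteness, a minimal sketch of the form the reviewer is suggesting, dispatching on the array's type parameter `N` rather than computing `Val(ndims(data) - 1)`. It mirrors the method the PR removes and is an illustration, not a committed change:

    # hypothetical alternative: N is known from the array's type, so no Val is needed
    _getobs(data::AbstractArray{T,N}, i) where {T,N} =
        data[ntuple(_ -> Colon(), N - 1)..., i]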

+_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
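A side note on the line above: `Base.Fix2(f, x)` constructs a callable equivalent to `y -> f(y, x)`, so mapping it over a tuple or named tuple slices every tensor with the same batch indices `i` while preserving the container's keys. A tiny standalone illustration:

    getsecond = Base.Fix2(getindex, 2)   # behaves like x -> x[2]
    @assert map(getsecond, (a=[1, 2], b=[3, 4])) == (a=2, b=4)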

-Base.eltype(d::DataLoader{D}) where D = D
+Base.eltype(::DataLoader{D}) where D = D
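The `eltype` tweak above only drops an unused argument name; the definition itself is what makes batch types statically known, which the new `@inferred first(d)` checks below exercise (`@inferred`, from the `Test` stdlib, fails if a call's return type cannot be inferred). A hedged sketch of the behavior, with an assumed import path:

    using Flux.Data: DataLoader

    d = DataLoader(rand(2, 6), batchsize=2)
    @assert eltype(d) == Matrix{Float64}   # batches have the same type as the data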
test/data.jl: 20 changes (20 additions, 0 deletions)

@@ -3,6 +3,7 @@
Y = [1:5;]

d = DataLoader(X, batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3
@@ -11,20 +12,23 @@
@test batches[3] == X[:,5:5]

d = DataLoader(X, batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]

d = DataLoader((X,), batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@test batches[1] == (X[:,1:2],)
@test batches[2] == (X[:,3:4],)

d = DataLoader((X, Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3
@@ -38,6 +42,22 @@
@test batches[3][1] == X[:,5:5]
@test batches[3][2] == Y[5:5]

# test with NamedTuple
d = DataLoader((x=X, y=Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
@test length(batches[3]) == 2
@test batches[1][1] == batches[1].x == X[:,1:2]
@test batches[1][2] == batches[1].y == Y[1:2]
@test batches[2][1] == batches[2].x == X[:,3:4]
@test batches[2][2] == batches[2].y == Y[3:4]
@test batches[3][1] == batches[3].x == X[:,5:5]
@test batches[3][2] == batches[3].y == Y[5:5]

# test interaction with `train!`
θ = ones(2)
X = zeros(2, 10)