From fd50e67e7b5b99356264adbb578290fb3dc67e09 Mon Sep 17 00:00:00 2001 From: pabloferz Date: Tue, 21 Jun 2016 19:03:31 +0200 Subject: [PATCH] Add default IsScalar trait --- base/array.jl | 9 +++++++++ base/dict.jl | 2 ++ base/essentials.jl | 1 + base/generator.jl | 5 ++++- base/iterator.jl | 17 +++++++++-------- base/reflection.jl | 1 + base/strings/basic.jl | 2 ++ base/strings/utf8proc.jl | 3 ++- base/tuple.jl | 2 ++ 9 files changed, 32 insertions(+), 10 deletions(-) diff --git a/base/array.jl b/base/array.jl index 0c568126ed17a..9417793959c46 100644 --- a/base/array.jl +++ b/base/array.jl @@ -206,6 +206,7 @@ promote_rule{T,n,S}(::Type{Array{T,n}}, ::Type{Array{S,n}}) = Array{promote_type # make a collection similar to `c` and appropriate for collecting `itr` _similar_for(c::AbstractArray, T, itr, ::SizeUnknown) = similar(c, T, 0) +_similar_for(c::AbstractArray, T, itr, ::IsScalar) = similar(c, T, ()) _similar_for(c::AbstractArray, T, itr, ::HasLength) = similar(c, T, Int(length(itr)::Integer)) _similar_for(c::AbstractArray, T, itr, ::HasShape) = similar(c, T, convert(Dims,size(itr))) _similar_for(c, T, itr, isz) = similar(c, T) @@ -217,6 +218,8 @@ Return an array of type `Array{element_type,1}` of all items in a collection. """ collect{T}(::Type{T}, itr) = collect(Generator(T, itr)) +_collect{T}(::Type{T}, itr, ::IsScalar) = (a = Array{T}(); a[] = itr; a) + """ collect(collection) @@ -226,6 +229,12 @@ collect(itr) = _collect(1:1 #= Array =#, itr, iteratoreltype(itr), iteratorsize( collect_similar(cont, itr) = _collect(cont, itr, iteratoreltype(itr), iteratorsize(itr)) +function _collect(cont, itr, ::IteratorEltype, isz::IsScalar) + a = _similar_for(cont, typeof(itr), itr, isz) + a[start(eachindex(a))] = itr + return a +end + _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = copy!(_similar_for(cont, eltype(itr), itr, isz), itr) diff --git a/base/dict.jl b/base/dict.jl index 0d5dbe321b642..a02deaefb2173 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -186,6 +186,7 @@ _tt1{A,B}(::Type{Pair{A,B}}) = A _tt2{A,B}(::Type{Pair{A,B}}) = B eltype{D}(::Type{KeyIterator{D}}) = _tt1(eltype(D)) eltype{D}(::Type{ValueIterator{D}}) = _tt2(eltype(D)) +iteratorsize{T<:Union{KeyIterator,ValueIterator}}(::Type{T}) = HasLength() start(v::Union{KeyIterator,ValueIterator}) = start(v.dict) done(v::Union{KeyIterator,ValueIterator}, state) = done(v.dict, state) @@ -272,6 +273,7 @@ function filter(f, d::Associative) end eltype{K,V}(::Type{Associative{K,V}}) = Pair{K,V} +iteratorsize{T<:Union{Associative,AbstractSet}}(::Type{T}) = HasLength() function isequal(l::Associative, r::Associative) l === r && return true diff --git a/base/essentials.jl b/base/essentials.jl index bdc95b133b7f4..5bc0924d7d518 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -173,6 +173,7 @@ done(v::SimpleVector,i) = (i > v.length) isempty(v::SimpleVector) = (v.length == 0) indices(v::SimpleVector, d) = d == 1 ? (1:length(v)) : (1:1) linearindices(v::SimpleVector) = indices(v, 1) +iteratorsize(::Type{SimpleVector}) = HasLength() function ==(v1::SimpleVector, v2::SimpleVector) length(v1)==length(v2) || return false diff --git a/base/generator.jl b/base/generator.jl index 48e00c749b84e..4df644266a7d6 100644 --- a/base/generator.jl +++ b/base/generator.jl @@ -29,16 +29,19 @@ end abstract IteratorSize immutable SizeUnknown <: IteratorSize end +immutable IsScalar <: IteratorSize end immutable HasLength <: IteratorSize end immutable HasShape <: IteratorSize end immutable IsInfinite <: IteratorSize end iteratorsize(x) = iteratorsize(typeof(x)) -iteratorsize(::Type) = HasLength() # HasLength is the default +iteratorsize(::Type) = IsScalar() # IsScalar is the default and_iteratorsize{T}(isz::T, ::T) = isz and_iteratorsize(::HasLength, ::HasShape) = HasLength() and_iteratorsize(::HasShape, ::HasLength) = HasLength() +and_iteratorsize(::IsScalar, ::Union{HasLength,HasShape}) = IsScalar() +and_iteratorsize(::Union{HasLength,HasShape}, ::IsScalar) = IsScalar() and_iteratorsize(a, b) = SizeUnknown() abstract IteratorEltype diff --git a/base/iterator.jl b/base/iterator.jl index 05359a4a14093..a24c0cf297089 100644 --- a/base/iterator.jl +++ b/base/iterator.jl @@ -38,6 +38,7 @@ iteratoreltype{I}(::Type{Enumerate{I}}) = iteratoreltype(I) abstract AbstractZipIterator zip_iteratorsize(a, b) = and_iteratorsize(a,b) # as `and_iteratorsize` but inherit `Union{HasLength,IsInfinite}` of the shorter iterator +zip_iteratorsize(::IsScalar, ::IsInfinite) = HasLength() zip_iteratorsize(::HasLength, ::IsInfinite) = HasLength() zip_iteratorsize(::HasShape, ::IsInfinite) = HasLength() zip_iteratorsize(a::IsInfinite, b) = zip_iteratorsize(b,a) @@ -308,20 +309,16 @@ iteratoreltype{O}(::Type{Repeated{O}}) = HasEltype() abstract AbstractProdIterator length(p::AbstractProdIterator) = prod(size(p)) -size(p::AbstractProdIterator) = _prod_size(p.a, p.b, iteratorsize(p.a), iteratorsize(p.b)) +size(p::AbstractProdIterator) = _prod_size(p.a, p.b) ndims(p::AbstractProdIterator) = length(size(p)) # generic methods to handle size of Prod* types +_prod_size(a, ::IsScalar) = (1,) _prod_size(a, ::HasShape) = size(a) _prod_size(a, ::HasLength) = (length(a), ) -_prod_size(a, A) = +_prod_size(a, ::IteratorSize) = throw(ArgumentError("Cannot compute size for object of type $(typeof(a))")) -_prod_size(a, b, ::HasLength, ::HasLength) = (length(a), length(b)) -_prod_size(a, b, ::HasLength, ::HasShape) = (length(a), size(b)...) -_prod_size(a, b, ::HasShape, ::HasLength) = (size(a)..., length(b)) -_prod_size(a, b, ::HasShape, ::HasShape) = (size(a)..., size(b)...) -_prod_size(a, b, A, B) = - throw(ArgumentError("Cannot construct size for objects of types $(typeof(a)) and $(typeof(b))")) +_prod_size(a, b) = (_prod_size(a, iteratorsize(a))..., _prod_size(b, iteratorsize(b))...) # one iterator immutable Prod1{I} <: AbstractProdIterator @@ -413,6 +410,9 @@ iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),it ((x[1][1],x[1][2]...), x[2]) end +prod_iteratorsize(::IsScalar, ::IsScalar) = IsScalar() +prod_iteratorsize(::IsScalar, isz::Union{HasLength,HasShape}) = isz +prod_iteratorsize(isz::Union{HasLength,HasShape}, ::IsScalar) = isz prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape() # products can have an infinite iterator prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite() @@ -534,6 +534,7 @@ type PartitionIterator{T} end eltype{T}(::Type{PartitionIterator{T}}) = Vector{eltype(T)} +iteratorsize{T<:PartitionIterator}(::Type{T}) = HasLength() function length(itr::PartitionIterator) l = length(itr.c) diff --git a/base/reflection.jl b/base/reflection.jl index a2f5889b42bca..a7f219815ba94 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -231,6 +231,7 @@ isempty(m::MethodList) = isempty(m.ms) start(m::MethodList) = start(m.ms) done(m::MethodList, s) = done(m.ms, s) next(m::MethodList, s) = next(m.ms, s) +iteratorsize(::Type{MethodList}) = HasLength() function MethodList(mt::MethodTable) ms = Method[] diff --git a/base/strings/basic.jl b/base/strings/basic.jl index 7308123929a07..4607b88592f46 100644 --- a/base/strings/basic.jl +++ b/base/strings/basic.jl @@ -2,6 +2,8 @@ ## core string functions ## +iteratorsize{T<:AbstractString}(::Type{T}) = HasLength() + endof(s::AbstractString) = error("you must implement endof(", typeof(s), ")") next(s::AbstractString, i::Int) = error("you must implement next(", typeof(s), ",Int)") next(s::DirectIndexString, i::Int) = (s[i],i+1) diff --git a/base/strings/utf8proc.jl b/base/strings/utf8proc.jl index 5b3ff7c430f17..1d0f1875b27b8 100644 --- a/base/strings/utf8proc.jl +++ b/base/strings/utf8proc.jl @@ -3,7 +3,7 @@ # Various Unicode functionality from the utf8proc library module UTF8proc -import Base: show, ==, hash, string, Symbol, isless, length, eltype, start, next, done, convert, isvalid, lowercase, uppercase +import Base: show, ==, hash, string, Symbol, isless, length, eltype, iteratorsize, start, next, done, convert, isvalid, lowercase, uppercase export isgraphemebreak @@ -190,6 +190,7 @@ end graphemes(s::AbstractString) = GraphemeIterator{typeof(s)}(s) eltype{S}(::Type{GraphemeIterator{S}}) = SubString{S} +iteratorsize{T<:GraphemeIterator}(::Type{T}) = Base.HasLength() function length(g::GraphemeIterator) c0 = Char(0x00ad) # soft hyphen (grapheme break always allowed after this) diff --git a/base/tuple.jl b/base/tuple.jl index 2e69bf217bff2..6af1ce617822f 100644 --- a/base/tuple.jl +++ b/base/tuple.jl @@ -12,6 +12,8 @@ getindex(t::Tuple, b::AbstractArray{Bool}) = getindex(t,find(b)) ## iterating ## +iteratorsize{T<:Tuple}(::Type{T}) = HasLength() + start(t::Tuple) = 1 done(t::Tuple, i::Int) = (length(t) < i) next(t::Tuple, i::Int) = (t[i], i+1)