Skip to content

Commit

Permalink
Use Val(x) and f(::Val{x})
Browse files Browse the repository at this point in the history
Replaces `f(Val{x})` on call sites with `f(Val(x))`, using a new
`@pure` function `Val(x) = Val{x}()`. This simplifies the method
definitions from `f(::Type{Val{x}}) where x` to `f(::Val{x}) where x`.

This form also has the advantage that multiple singleton instances
can be put in a tuple and inference will work (similarly with
multiple-return functions).
  • Loading branch information
andyferris authored and Andy Ferris committed Jul 5, 2017
1 parent bf83397 commit f0a91f7
Show file tree
Hide file tree
Showing 44 changed files with 261 additions and 216 deletions.
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ Library improvements
* `@test isequal(x, y)` and `@test isapprox(x, y)` now prints an evaluated expression when
the test fails ([#22296]).

* Uses of `Val{c}` in `Base` has been replaced with `Val{c}()`, which is now easily
accessible via the `@pure` constructor `Val(c)`. Functions are defined as
`f(::Val{c}) = ...` and called by `f(Val(c))`. Notable affected functions include:
`ntuple`, `Base.literal_pow`, `sqrtm`, `lufact`, `lufact!`, `qrfact`, `qrfact!`,
`cholfact`, `cholfact!`, `_broadcast!`, `reshape`, `cat` and `cat_t`.

Compiler/Runtime improvements
-----------------------------

Expand Down
46 changes: 23 additions & 23 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ julia> size(A,3,2)
"""
size(t::AbstractArray{T,N}, d) where {T,N} = d <= N ? size(t)[d] : 1
size(x, d1::Integer, d2::Integer, dx::Vararg{Integer, N}) where {N} =
(size(x, d1), size(x, d2), ntuple(k->size(x, dx[k]), Val{N})...)
(size(x, d1), size(x, d2), ntuple(k->size(x, dx[k]), Val(N))...)

"""
indices(A, d)
Expand Down Expand Up @@ -954,13 +954,13 @@ function _getindex(::IndexCartesian, A::AbstractArray{T,N}, I::Vararg{Int, N}) w
getindex(A, I...)
end
_to_subscript_indices(A::AbstractArray, i::Int) = (@_inline_meta; _unsafe_ind2sub(A, i))
_to_subscript_indices(A::AbstractArray{T,N}) where {T,N} = (@_inline_meta; fill_to_length((), 1, Val{N})) # TODO: DEPRECATE FOR #14770
_to_subscript_indices(A::AbstractArray{T,N}) where {T,N} = (@_inline_meta; fill_to_length((), 1, Val(N))) # TODO: DEPRECATE FOR #14770
_to_subscript_indices(A::AbstractArray{T,0}) where {T} = () # TODO: REMOVE FOR #14770
_to_subscript_indices(A::AbstractArray{T,0}, i::Int) where {T} = () # TODO: REMOVE FOR #14770
_to_subscript_indices(A::AbstractArray{T,0}, I::Int...) where {T} = () # TODO: DEPRECATE FOR #14770
function _to_subscript_indices(A::AbstractArray{T,N}, I::Int...) where {T,N} # TODO: DEPRECATE FOR #14770
@_inline_meta
J, Jrem = IteratorsMD.split(I, Val{N})
J, Jrem = IteratorsMD.split(I, Val(N))
_to_subscript_indices(A, J, Jrem)
end
_to_subscript_indices(A::AbstractArray, J::Tuple, Jrem::Tuple{}) =
Expand Down Expand Up @@ -1203,7 +1203,7 @@ cat_shape(dims, shape::Tuple) = shape

_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape
_cshp(ndim::Int, dims, ::Tuple{}, ::Tuple{}) = ntuple(b -> 1, Val{length(dims)})
_cshp(ndim::Int, dims, ::Tuple{}, ::Tuple{}) = ntuple(b -> 1, Val(length(dims)))
@inline _cshp(ndim::Int, dims, shape, ::Tuple{}) =
(shape[1] + dims[1], _cshp(ndim + 1, tail(dims), tail(shape), ())...)
@inline _cshp(ndim::Int, dims, ::Tuple{}, nshape) =
Expand All @@ -1226,7 +1226,7 @@ end
_cs(d, a, b) = (a == b ? a : throw(DimensionMismatch(
"mismatch in dimension $d (expected $a got $b)")))

dims2cat(::Type{Val{n}}) where {n} = ntuple(i -> (i == n), Val{n})
dims2cat(::Val{n}) where {n} = ntuple(i -> (i == n), Val(n))
dims2cat(dims) = ntuple(i -> (i in dims), maximum(dims))

cat(dims, X...) = cat_t(dims, promote_eltypeof(X...), X...)
Expand Down Expand Up @@ -1290,7 +1290,7 @@ julia> vcat(c...)
4 5 6
```
"""
vcat(X...) = cat(Val{1}, X...)
vcat(X...) = cat(Val(1), X...)
"""
hcat(A...)
Expand Down Expand Up @@ -1331,28 +1331,28 @@ julia> hcat(c...)
3 6
```
"""
hcat(X...) = cat(Val{2}, X...)
hcat(X...) = cat(Val(2), X...)

typed_vcat(T::Type, X...) = cat_t(Val{1}, T, X...)
typed_hcat(T::Type, X...) = cat_t(Val{2}, T, X...)
typed_vcat(T::Type, X...) = cat_t(Val(1), T, X...)
typed_hcat(T::Type, X...) = cat_t(Val(2), T, X...)

cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(catdims, T, A...)

# The specializations for 1 and 2 inputs are important
# especially when running with --inline=no, see #11158
vcat(A::AbstractArray) = cat(Val{1}, A)
vcat(A::AbstractArray, B::AbstractArray) = cat(Val{1}, A, B)
vcat(A::AbstractArray...) = cat(Val{1}, A...)
hcat(A::AbstractArray) = cat(Val{2}, A)
hcat(A::AbstractArray, B::AbstractArray) = cat(Val{2}, A, B)
hcat(A::AbstractArray...) = cat(Val{2}, A...)

typed_vcat(T::Type, A::AbstractArray) = cat_t(Val{1}, T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val{1}, T, A, B)
typed_vcat(T::Type, A::AbstractArray...) = cat_t(Val{1}, T, A...)
typed_hcat(T::Type, A::AbstractArray) = cat_t(Val{2}, T, A)
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val{2}, T, A, B)
typed_hcat(T::Type, A::AbstractArray...) = cat_t(Val{2}, T, A...)
vcat(A::AbstractArray) = cat(Val(1), A)
vcat(A::AbstractArray, B::AbstractArray) = cat(Val(1), A, B)
vcat(A::AbstractArray...) = cat(Val(1), A...)
hcat(A::AbstractArray) = cat(Val(2), A)
hcat(A::AbstractArray, B::AbstractArray) = cat(Val(2), A, B)
hcat(A::AbstractArray...) = cat(Val(2), A...)

typed_vcat(T::Type, A::AbstractArray) = cat_t(Val(1), T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val(1), T, A, B)
typed_vcat(T::Type, A::AbstractArray...) = cat_t(Val(1), T, A...)
typed_hcat(T::Type, A::AbstractArray) = cat_t(Val(2), T, A)
typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(Val(2), T, A, B)
typed_hcat(T::Type, A::AbstractArray...) = cat_t(Val(2), T, A...)

# 2d horizontal and vertical concatenation

Expand Down Expand Up @@ -1721,7 +1721,7 @@ _sub2ind_vec(i) = ()

function ind2sub(inds::Union{DimsInteger{N},Indices{N}}, ind::AbstractVector{<:Integer}) where N
M = length(ind)
t = ntuple(n->similar(ind),Val{N})
t = ntuple(n->similar(ind),Val(N))
for (i,idx) in enumerate(IndexLinear(), ind)
sub = ind2sub(inds, idx)
for j = 1:N
Expand Down
4 changes: 2 additions & 2 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ julia> repeat([1 2; 3 4], inner=(2, 1), outer=(1, 3))
```
"""
function repeat(A::AbstractArray;
inner=ntuple(n->1, Val{ndims(A)}),
outer=ntuple(n->1, Val{ndims(A)}))
inner=ntuple(n->1, Val(ndims(A))),
outer=ntuple(n->1, Val(ndims(A))))
return _repeat(A, rep_kw2tup(inner), rep_kw2tup(outer))
end

Expand Down
2 changes: 1 addition & 1 deletion base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
size(a::Array, d) = arraysize(a, d)
size(a::Vector) = (arraysize(a,1),)
size(a::Matrix) = (arraysize(a,1), arraysize(a,2))
size(a::Array{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val{N}))
size(a::Array{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val(N)))

asize_from(a::Array, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...)

Expand Down
2 changes: 1 addition & 1 deletion base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ reshape(B::BitArray, dims::Tuple{Vararg{Int}}) = _bitreshape(B, dims)
function _bitreshape(B::BitArray, dims::NTuple{N,Int}) where N
prod(dims) == length(B) ||
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(length(B))"))
Br = BitArray{N}(ntuple(i->0,Val{N})...)
Br = BitArray{N}(ntuple(i->0,Val(N))...)
Br.chunks = B.chunks
Br.len = prod(dims)
N != 1 && (Br.dims = dims)
Expand Down
20 changes: 10 additions & 10 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
## Broadcasting core
# nargs encodes the number of As arguments (which matches the number
# of keeps). The first two type parameters are to ensure specialization.
@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Type{Val{N}}, iter) where {K,ID,AT,BT,N}
@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Val{N}, iter) where {K,ID,AT,BT,N}
nargs = N + 1
quote
$(Expr(:meta, :inline))
Expand All @@ -157,7 +157,7 @@ end

# For BitArray outputs, we cache the result in a "small" Vector{Bool},
# and then copy in chunks into the output
@generated function _broadcast!(f, B::BitArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Type{Val{N}}, iter) where {K,ID,AT,BT,N}
@generated function _broadcast!(f, B::BitArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Val{N}, iter) where {K,ID,AT,BT,N}
nargs = N + 1
quote
$(Expr(:meta, :inline))
Expand Down Expand Up @@ -207,12 +207,12 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
@boundscheck check_broadcast_indices(shape, A, Bs...)
keeps, Idefaults = map_newindexer(shape, A, Bs)
iter = CartesianRange(shape)
_broadcast!(f, C, keeps, Idefaults, A, Bs, Val{N}, iter)
_broadcast!(f, C, keeps, Idefaults, A, Bs, Val(N), iter)
return C
end

# broadcast with computed element type
@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter, st, count) where {K,ID,AT,nargs}
@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Val{nargs}, iter, st, count) where {K,ID,AT,nargs}
quote
$(Expr(:meta, :noinline))
# destructure the keeps and As tuples
Expand All @@ -238,7 +238,7 @@ end
new[II] = B[II]
end
new[I] = V
return _broadcast!(f, new, keeps, Idefaults, As, Val{nargs}, iter, st, count+1)
return _broadcast!(f, new, keeps, Idefaults, As, Val(nargs), iter, st, count+1)
end
count += 1
end
Expand All @@ -259,12 +259,12 @@ function broadcast_t(f, ::Type{Any}, shape, iter, As...)
B = similar(Array{typeof(val)}, shape)
end
B[I] = val
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1)
return _broadcast!(f, B, keeps, Idefaults, As, Val(nargs), iter, st, 1)
end
@inline function broadcast_t(f, T, shape, iter, A, Bs::Vararg{Any,N}) where N
C = similar(Array{T}, shape)
keeps, Idefaults = map_newindexer(shape, A, Bs)
_broadcast!(f, C, keeps, Idefaults, A, Bs, Val{N}, iter)
_broadcast!(f, C, keeps, Idefaults, A, Bs, Val(N), iter)
return C
end

Expand All @@ -275,7 +275,7 @@ end
@inline function broadcast_t(f, ::Type{Bool}, shape, iter, A, Bs::Vararg{Any,N}) where N
C = similar(BitArray, shape)
keeps, Idefaults = map_newindexer(shape, A, Bs)
_broadcast!(f, C, keeps, Idefaults, A, Bs, Val{N}, iter)
_broadcast!(f, C, keeps, Idefaults, A, Bs, Val(N), iter)
return C
end

Expand Down Expand Up @@ -335,9 +335,9 @@ end
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val{N})
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N))
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val{N})
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N))
first_tuple(A::Tuple, Bs...) = A
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
tuplebroadcast_getargs(::Tuple{}, k) = ()
Expand Down
32 changes: 28 additions & 4 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1354,12 +1354,12 @@ end
@deprecate versioninfo(io::IO, verbose::Bool) versioninfo(io, verbose=verbose)

# PR #22188
@deprecate cholfact!(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}}) cholfact!(Hermitian(A, uplo), Val{false})
@deprecate cholfact!(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}}) cholfact!(Hermitian(A, uplo), Val(false))
@deprecate cholfact!(A::StridedMatrix, uplo::Symbol) cholfact!(Hermitian(A, uplo))
@deprecate cholfact(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}}) cholfact(Hermitian(A, uplo), Val{false})
@deprecate cholfact(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}}) cholfact(Hermitian(A, uplo), Val(false))
@deprecate cholfact(A::StridedMatrix, uplo::Symbol) cholfact(Hermitian(A, uplo))
@deprecate cholfact!(A::StridedMatrix, uplo::Symbol, ::Type{Val{true}}; tol = 0.0) cholfact!(Hermitian(A, uplo), Val{true}, tol = tol)
@deprecate cholfact(A::StridedMatrix, uplo::Symbol, ::Type{Val{true}}; tol = 0.0) cholfact(Hermitian(A, uplo), Val{true}, tol = tol)
@deprecate cholfact!(A::StridedMatrix, uplo::Symbol, ::Type{Val{true}}; tol = 0.0) cholfact!(Hermitian(A, uplo), Val(true), tol = tol)
@deprecate cholfact(A::StridedMatrix, uplo::Symbol, ::Type{Val{true}}; tol = 0.0) cholfact(Hermitian(A, uplo), Val(true), tol = tol)

# PR #22245
@deprecate isposdef(A::AbstractMatrix, UL::Symbol) isposdef(Hermitian(A, UL))
Expand Down Expand Up @@ -1519,6 +1519,30 @@ function replace(s::AbstractString, pat, f, n::Integer)
end
end

# PR #22475
@deprecate ntuple{N}(f, ::Type{Val{N}}) ntuple(f, Val(N))
@deprecate fill_to_length{N}(t, val, ::Type{Val{N}}) fill_to_length(t, val, Val(N)) false
@deprecate literal_pow{N}(a, b, ::Type{Val{N}}) literal_pow(a, b, Val(N)) false
@eval IteratorsMD @deprecate split{n}(t, V::Type{Val{n}}) split(t, Val(n)) false
@deprecate sqrtm{T,realmatrix}(A::UpperTriangular{T},::Type{Val{realmatrix}}) sqrtm(A, Val(realmatrix))
@deprecate lufact(A::AbstractMatrix, ::Type{Val{false}}) lufact(A, Val(false))
@deprecate lufact(A::AbstractMatrix, ::Type{Val{true}}) lufact(A, Val(true))
@deprecate lufact!(A::AbstractMatrix, ::Type{Val{false}}) lufact!(A, Val(false))
@deprecate lufact!(A::AbstractMatrix, ::Type{Val{true}}) lufact!(A, Val(true))
@deprecate qrfact(A::AbstractMatrix, ::Type{Val{false}}) qrfact(A, Val(false))
@deprecate qrfact(A::AbstractMatrix, ::Type{Val{true}}) qrfact(A, Val(true))
@deprecate qrfact!(A::AbstractMatrix, ::Type{Val{false}}) qrfact!(A, Val(false))
@deprecate qrfact!(A::AbstractMatrix, ::Type{Val{true}}) qrfact!(A, Val(true))
@deprecate cholfact(A::AbstractMatrix, ::Type{Val{false}}) cholfact(A, Val(false))
@deprecate cholfact(A::AbstractMatrix, ::Type{Val{true}}; tol = 0.0) cholfact(A, Val(true); tol = tol)
@deprecate cholfact!(A::AbstractMatrix, ::Type{Val{false}}) cholfact!(A, Val(false))
@deprecate cholfact!(A::AbstractMatrix, ::Type{Val{true}}; tol = 0.0) cholfact!(A, Val(true); tol = tol)
@deprecate cat{N}(::Type{Val{N}}, A::AbstractArray...) cat(Val(N), A...)
@deprecate cat{N}(::Type{Val{N}}, A::SparseArrays._SparseConcatGroup...) cat(Val(N), A...)
@deprecate cat{N}(::Type{Val{N}}, A::SparseArrays._DenseConcatGroup...) cat(Val(N), A...)
@deprecate cat_t{N,T}(::Type{Val{N}}, ::Type{T}, A, B) cat_t(Val(N), T, A, B) false
@deprecate reshape{N}(A::AbstractArray, ::Type{Val{N}}) reshape(A, Val(N))

# END 0.7 deprecations

# BEGIN 1.0 deprecations
Expand Down
29 changes: 22 additions & 7 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,18 +313,33 @@ struct Colon
end
const (:) = Colon()

# For passing constants through type inference
"""
Val{c}
Val(c)
Create a "value type" out of `c`, which must be an `isbits` value. The intent of this
construct is to be able to dispatch on constants, e.g., `f(Val{false})` allows you to
dispatch directly (at compile-time) to an implementation `f(::Type{Val{false}})`, without
having to test the boolean value at runtime.
Return `Val{c}()`, which contains no run-time data. Types like this can be used to
pass the information between functions through the value `c`, which must be an `isbits`
value. The intent of this construct is to be able to dispatch on constants directly (at
compile time) without having to test the value of the constant at run time.
# Examples
```jldoctest
julia> f(::Val{true}) = "Good"
f (generic function with 1 method)
julia> f(::Val{false}) = "Bad"
f (generic function with 2 methods)
julia> f(Val(true))
"Good"
```
"""
struct Val{T}
struct Val{x}
end

Val(x) = (@_pure_meta; Val{x}())

show(io::IO, ::Val{x}) where {x} = print(io, "Val($x)")

# used by interpolating quote and some other things in the front end
function vector_any(xs::ANY...)
n = length(xs)
Expand Down
12 changes: 6 additions & 6 deletions base/intfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ end
^(x::Number, p::Integer) = power_by_squaring(x,p)
^(x, p::Integer) = power_by_squaring(x,p)

# x^p for any literal integer p is lowered to Base.literal_pow(^, x, Val{p})
# x^p for any literal integer p is lowered to Base.literal_pow(^, x, Val(p))
# to enable compile-time optimizations specialized to p.
# However, we still need a fallback that calls the function ^ which may either
# mean Base.^ or something else, depending on context.
# We mark these @inline since if the target is marked @inline,
# we want to make sure that gets propagated,
# even if it is over the inlining threshold.
@inline literal_pow(f, x, ::Type{Val{p}}) where {p} = f(x,p)
@inline literal_pow(f, x, ::Val{p}) where {p} = f(x,p)

# Restrict inlining to hardware-supported arithmetic types, which
# are fast enough to benefit from inlining.
Expand All @@ -216,10 +216,10 @@ const HWNumber = Union{HWReal, Complex{<:HWReal}, Rational{<:HWReal}}
# numeric types. In terms of Val we can do it much more simply.
# (The first argument prevents unexpected behavior if a function ^
# is defined that is not equal to Base.^)
@inline literal_pow(::typeof(^), x::HWNumber, ::Type{Val{0}}) = one(x)
@inline literal_pow(::typeof(^), x::HWNumber, ::Type{Val{1}}) = x
@inline literal_pow(::typeof(^), x::HWNumber, ::Type{Val{2}}) = x*x
@inline literal_pow(::typeof(^), x::HWNumber, ::Type{Val{3}}) = x*x*x
@inline literal_pow(::typeof(^), x::HWNumber, ::Val{0}) = one(x)
@inline literal_pow(::typeof(^), x::HWNumber, ::Val{1}) = x
@inline literal_pow(::typeof(^), x::HWNumber, ::Val{2}) = x*x
@inline literal_pow(::typeof(^), x::HWNumber, ::Val{3}) = x*x*x

# b^p mod m

Expand Down
2 changes: 1 addition & 1 deletion base/linalg/arnoldi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ function A_mul_B!(y::StridedVector{T}, A::AtA_or_AAt{T}, x::StridedVector{T}) wh
return A_mul_B!(y, A.A, A.buffer)
end
end
size(A::AtA_or_AAt) = ntuple(i -> min(size(A.A)...), Val{2})
size(A::AtA_or_AAt) = ntuple(i -> min(size(A.A)...), Val(2))
ishermitian(s::AtA_or_AAt) = true


Expand Down
Loading

0 comments on commit f0a91f7

Please sign in to comment.