From b29e0fea88cb9b410722850352714e01eba9d40f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 18 Aug 2023 10:20:42 -0400 Subject: [PATCH] Elegance through simplicity --- base/scopedvalues.jl | 64 +++++++++++++------------------------------- test/scopedvalues.jl | 7 +++++ 2 files changed, 25 insertions(+), 46 deletions(-) diff --git a/base/scopedvalues.jl b/base/scopedvalues.jl index 3160c44e00637..fb28652937e1e 100644 --- a/base/scopedvalues.jl +++ b/base/scopedvalues.jl @@ -55,64 +55,36 @@ import Base: ImmutableDict # the value portion is read-only/write-once, but the cache version # would need a lock which makes ImmutableDict incredibly attractive. # We could also use task-local-storage, but that added about 12ns. -# - Values are GC'd when scopes become unreachable, one could use +# - Values are GC'd when scopes become unreachable, one could use # a WeakKeyDict to also ensure that values get GC'd when ScopedValues # become unreachable. +# - Scopes are an inline implementation of an ImmutableDict, if we wanted +# be really fance we could use a CTrie or HAMT. -mutable struct Scope +mutable struct Scope{T} const parent::Union{Nothing, Scope} - @atomic values::ImmutableDict{ScopedValue, Any} + const key::ScopedValue{T} + const value::T end -Scope(parent) = Scope(parent, ImmutableDict{ScopedValue, Any}()) +Scope(parent, key::ScopedValue{T}, value) where T = + Scope(parent, key, convert(T, value)) + current_scope() = current_task().scope::Union{Nothing, Scope} function Base.show(io::IO, ::Scope) print(io, Scope) end -# VC: I find it rather useful to have one function to use for both -# haskey and get. -@inline function get(dict::ImmutableDict, key, ::Type{T}) where T - while isdefined(dict, :parent) - isequal(dict.key, key) && return Some(dict.value::T) - dict = dict.parent - end - return nothing -end - function Base.getindex(var::ScopedValue{T})::T where T scope = current_scope() - if scope === nothing - return var.initial_value - end - cs = scope - - val = var.initial_value while scope !== nothing - values = @atomic :acquire scope.values - _val = get(values, var, T) - if _val !== nothing - val = something(_val) - break + if scope.key === var + return scope.value::T end scope = scope.parent end - - if cs != scope - # found the value in an upper scope, copy it down to the cache. - # We are using the same dict for both cache and values. - # One can split these and potentially use `ImmutableDict` only for values - # and a Dict with SpinLock for the cache. - success = false - old = @atomic :acquire cs.values - while !success - new = ImmutableDict(old, var => val) - old, success = @atomicreplace :acquire_release :acquire cs.values old => new - end - end - - return val -end + return var.initial_value +en function Base.show(io::IO, var::ScopedValue) print(io, ScopedValue) @@ -129,13 +101,13 @@ Execute `f` in a new scope with `var` set to `val`. """ function scoped(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...) @nospecialize - values = ImmutableDict{ScopedValue, Any}(pair...) - for pair in rest - values = ImmutableDict{ScopedValue, Any}(values, pair...) - end ct = Base.current_task() current_scope = ct.scope::Union{Nothing, Scope} - ct.scope = Scope(current_scope, values) + scope = Scope(current_scope, pair...) + for pair in rest + scope = Scope(scope, pair...) + end + ct.scope = scope try return f() finally diff --git a/test/scopedvalues.jl b/test/scopedvalues.jl index bbb07286d30dc..dde9300173e37 100644 --- a/test/scopedvalues.jl +++ b/test/scopedvalues.jl @@ -35,6 +35,13 @@ const svar_float = ScopedValue(1.0) end end +emptyf() = nothing + +@testset "conversion" begin + scoped(emptyf, gvar_float=>2) + @test_throws MethodError scoped(emptyf, gvar_float=>"hello") +end + import Base.Threads: @spawn @testset "tasks" begin @test fetch(@spawn begin