From a385dccf67d4e8dd746b678d99d48ecae9f0e90b Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Thu, 4 Mar 2021 18:17:58 -0500 Subject: [PATCH] Early-convert BenchmarkGroup names to JSON-safe (#193) force strings early in definition, but allow some other JSON-types --- src/groups.jl | 49 ++++++++++++++++++------------ src/serialization.jl | 62 ++++++++++++++++++++++++++------------ test/GroupsTests.jl | 2 +- test/SerializationTests.jl | 14 ++++----- 4 files changed, 80 insertions(+), 47 deletions(-) diff --git a/src/groups.jl b/src/groups.jl index 6c3da19..99b1e5e 100644 --- a/src/groups.jl +++ b/src/groups.jl @@ -2,12 +2,19 @@ # BenchmarkGroup # ################## +const KeyTypes = Union{String,Int,Float64} +makekey(v::KeyTypes) = v +makekey(v::Real) = (v2 = Float64(v); v2 == v ? v2 : string(v)) +makekey(v::Integer) = typemin(Int) <= v <= typemax(Int) ? Int(v) : string(v) +makekey(v::Tuple) = (Any[i isa Tuple ? string(i) : makekey(i) for i in v]...,)::Tuple{Vararg{KeyTypes}} +makekey(v::Any) = string(v)::String + struct BenchmarkGroup tags::Vector{Any} data::Dict{Any,Any} end -BenchmarkGroup(tags::Vector, args::Pair...) = BenchmarkGroup(tags, Dict(args...)) +BenchmarkGroup(tags::Vector, args::Pair...) = BenchmarkGroup(tags, Dict{Any,Any}((makekey(k) => v for (k, v) in args))) BenchmarkGroup(args::Pair...) = BenchmarkGroup([], args...) function addgroup!(suite::BenchmarkGroup, id, args...) @@ -24,10 +31,14 @@ Base.copy(group::BenchmarkGroup) = BenchmarkGroup(copy(group.tags), copy(group.d Base.similar(group::BenchmarkGroup) = BenchmarkGroup(copy(group.tags), empty(group.data)) Base.isempty(group::BenchmarkGroup) = isempty(group.data) Base.length(group::BenchmarkGroup) = length(group.data) -Base.getindex(group::BenchmarkGroup, i...) = getindex(group.data, i...) -Base.setindex!(group::BenchmarkGroup, i...) = setindex!(group.data, i...) -Base.delete!(group::BenchmarkGroup, k...) = delete!(group.data, k...) -Base.haskey(group::BenchmarkGroup, k) = haskey(group.data, k) +Base.getindex(group::BenchmarkGroup, k) = getindex(group.data, makekey(k)) +Base.getindex(group::BenchmarkGroup, k...) = getindex(group.data, makekey(k)) +Base.setindex!(group::BenchmarkGroup, v, k) = setindex!(group.data, v, makekey(k)) +Base.setindex!(group::BenchmarkGroup, v, k...) = setindex!(group.data, v, makekey(k)) +Base.delete!(group::BenchmarkGroup, k) = delete!(group.data, makekey(k)) +Base.delete!(group::BenchmarkGroup, k...) = delete!(group.data, makekey(k)) +Base.haskey(group::BenchmarkGroup, k) = haskey(group.data, makekey(k)) +Base.haskey(group::BenchmarkGroup, k...) = haskey(group.data, makekey(k)) Base.keys(group::BenchmarkGroup) = keys(group.data) Base.values(group::BenchmarkGroup) = values(group.data) Base.iterate(group::BenchmarkGroup, i=1) = iterate(group.data, i) @@ -119,11 +130,11 @@ end # leaf iteration/indexing # #-------------------------# -leaves(group::BenchmarkGroup) = leaves!(Any[], Any[], group) +leaves(group::BenchmarkGroup) = leaves!([], [], group) function leaves!(results, parents, group::BenchmarkGroup) for (k, v) in group - keys = vcat(parents, k) + keys = Base.typed_vcat(Any, parents, k) if isa(v, BenchmarkGroup) leaves!(results, keys, v) else @@ -155,44 +166,44 @@ end # tagging # #---------# -struct TagFilter{P} - predicate::P +struct TagFilter + predicate end macro tagged(expr) - return esc(:(BenchmarkTools.TagFilter(tags -> $(tagpredicate!(expr))))) + return :(BenchmarkTools.TagFilter(tags -> $(tagpredicate!(expr)))) end -tagpredicate!(tag) = :(in($tag, tags)) +tagpredicate!(@nospecialize tag) = :(in(makekey($(esc(tag))), tags)) function tagpredicate!(sym::Symbol) - sym == :! && return sym sym == :ALL && return true - return :(in($sym, tags)) + return :(in(makekey($(esc(sym))), tags)) end # build the body of the tag predicate in place function tagpredicate!(expr::Expr) - expr.head == :quote && return :(in($expr, tags)) - for i in eachindex(expr.args) - expr.args[i] = tagpredicate!(expr.args[i]) + expr.head == :quote && return :(in(makekey($(esc(expr))), tags)) + for i in 1:length(expr.args) + f = (i == 1 && expr.head === :call ? esc : tagpredicate!) + expr.args[i] = f(expr.args[i]) end return expr end function Base.getindex(src::BenchmarkGroup, f::TagFilter) dest = similar(src) - loadtagged!(f, dest, src, src, Any[], src.tags) + loadtagged!(f, dest, src, src, [], src.tags) return dest end # normal union doesn't have the behavior we want # (e.g. union(["1"], "2") === ["1", '2']) -keyunion(args...) = unique(vcat(args...)) +keyunion(args...) = unique(Base.typed_vcat(Any, args...)) function tagunion(args...) unflattened = keyunion(args...) - result = Any[] + result = [] for i in unflattened if isa(i, Tuple) for j in i diff --git a/src/serialization.jl b/src/serialization.jl index 7affd72..15533c5 100644 --- a/src/serialization.jl +++ b/src/serialization.jl @@ -2,44 +2,66 @@ const VERSIONS = Dict("Julia" => string(VERSION), "BenchmarkTools" => string(BENCHMARKTOOLS_VERSION)) # TODO: Add any new types as they're added -const SUPPORTED_TYPES = [Benchmark, BenchmarkGroup, Parameters, TagFilter, Trial, - TrialEstimate, TrialJudgement, TrialRatio] +const SUPPORTED_TYPES = Dict{Symbol,Type}(Base.typename(x).name => x for x in [ + BenchmarkGroup, Parameters, TagFilter, Trial, + TrialEstimate, TrialJudgement, TrialRatio]) +# n.b. Benchmark type not included here, since it is gensym'd -for T in SUPPORTED_TYPES - @eval function JSON.lower(x::$T) - d = Dict{String,Any}() - for i = 1:nfields(x) - name = String(fieldname($T, i)) - field = getfield(x, i) - value = typeof(field) in SUPPORTED_TYPES ? JSON.lower(field) : field - push!(d, name => value) - end - [string(typeof(x)), d] +function JSON.lower(x::Union{values(SUPPORTED_TYPES)...}) + d = Dict{String,Any}() + T = typeof(x) + for i = 1:nfields(x) + name = String(fieldname(T, i)) + field = getfield(x, i) + ft = typeof(field) + value = ft <: get(SUPPORTED_TYPES, ft.name.name, Union{}) ? JSON.lower(field) : field + d[name] = value end + [string(typeof(x).name.name), d] end +# a minimal 'eval' function, mirroring KeyTypes, but being slightly more lenient +safeeval(@nospecialize x) = x +safeeval(x::QuoteNode) = x.value +function safeeval(x::Expr) + x.head === :quote && return x.args[1] + x.head === :inert && return x.args[1] + x.head === :tuple && return ((safeeval(a) for a in x.args)...,) + x +end function recover(x::Vector) length(x) == 2 || throw(ArgumentError("Expecting a vector of length 2")) typename = x[1]::String fields = x[2]::Dict - T = Core.eval(@__MODULE__, Meta.parse(typename))::Type + startswith(typename, "BenchmarkTools.") && (typename = typename[sizeof("BenchmarkTools.")+1:end]) + T = SUPPORTED_TYPES[Symbol(typename)] fc = fieldcount(T) xs = Vector{Any}(undef, fc) for i = 1:fc ft = fieldtype(T, i) fn = String(fieldname(T, i)) - xs[i] = if ft in SUPPORTED_TYPES - recover(fields[fn]) + if ft <: get(SUPPORTED_TYPES, ft.name.name, Union{}) + xsi = recover(fields[fn]) else - convert(ft, fields[fn]) + xsi = convert(ft, fields[fn]) end - if T == BenchmarkGroup && xs[i] isa Dict - for (k, v) in xs[i] + if T == BenchmarkGroup && xsi isa Dict + for (k, v) in copy(xsi) + k = k::String + if startswith(k, "(") || startswith(k, ":") + kt = Meta.parse(k, raise=false) + if !(kt isa Expr && kt.head === :error) + delete!(xsi, k) + k = safeeval(kt) + xsi[k] = v + end + end if v isa Vector && length(v) == 2 && v[1] isa String - xs[i][k] = recover(v) + xsi[k] = recover(v) end end end + xs[i] = xsi end T(xs...) end @@ -73,7 +95,7 @@ function save(io::IO, args...) "The name will be ignored and the object will be serialized " * "in the order it appears in the input.") continue - elseif !any(T->arg isa T, SUPPORTED_TYPES) + elseif !(arg isa get(SUPPORTED_TYPES, typeof(arg).name.name, Union{})) throw(ArgumentError("Only BenchmarkTools types can be serialized.")) end push!(goodargs, arg) diff --git a/test/GroupsTests.jl b/test/GroupsTests.jl index cc8fa00..6f10962 100644 --- a/test/GroupsTests.jl +++ b/test/GroupsTests.jl @@ -216,7 +216,7 @@ gnest = BenchmarkGroup(["1"], 10 => BenchmarkGroup(["3"]), 11 => BenchmarkGroup())) -@test sort(leaves(gnest), by=string) == +@test sort(leaves(gnest), by=string) == Any[(Any["2",1],1), (Any["a","a"],:a), (Any["a",(11,"b")],:b), (Any[4,5],6), (Any[7],8)] @test gnest[@tagged 11 || 10] == BenchmarkGroup(["1"], diff --git a/test/SerializationTests.jl b/test/SerializationTests.jl index 63eb1e2..8790076 100644 --- a/test/SerializationTests.jl +++ b/test/SerializationTests.jl @@ -3,7 +3,7 @@ module SerializationTests using BenchmarkTools using Test -eq(x::T, y::T) where {T<:Union{BenchmarkTools.SUPPORTED_TYPES...}} = +eq(x::T, y::T) where {T<:Union{values(BenchmarkTools.SUPPORTED_TYPES)...}} = all(i->eq(getfield(x, i), getfield(y, i)), 1:fieldcount(T)) eq(x::T, y::T) where {T} = isapprox(x, y) @@ -25,13 +25,13 @@ end withtempdir() do tmp = joinpath(pwd(), "tmp.json") - BenchmarkTools.save(tmp, b, bb) + BenchmarkTools.save(tmp, b.params, bb) @test isfile(tmp) results = BenchmarkTools.load(tmp) @test results isa Vector{Any} @test length(results) == 2 - @test eq(results[1], b) + @test eq(results[1], b.params) @test eq(results[2], bb) end @@ -56,18 +56,18 @@ end tune!(b) bb = run(b) - @test_throws ArgumentError BenchmarkTools.save("x.jld", b) - @test_throws ArgumentError BenchmarkTools.save("x.txt", b) + @test_throws ArgumentError BenchmarkTools.save("x.jld", b.params) + @test_throws ArgumentError BenchmarkTools.save("x.txt", b.params) @test_throws ArgumentError BenchmarkTools.save("x.json") @test_throws ArgumentError BenchmarkTools.save("x.json", 1) withtempdir() do tmp = joinpath(pwd(), "tmp.json") - @test_logs (:warn, r"Naming variables") BenchmarkTools.save(tmp, "b", b) + @test_logs (:warn, r"Naming variables") BenchmarkTools.save(tmp, "b", b.params) @test isfile(tmp) results = BenchmarkTools.load(tmp) @test length(results) == 1 - @test eq(results[1], b) + @test eq(results[1], b.params) end @test_throws ArgumentError BenchmarkTools.load("x.jld")