Skip to content

Commit

Permalink
Early-convert BenchmarkGroup names to JSON-safe (#193)
Browse files Browse the repository at this point in the history
force strings early in definition, but allow some other JSON-types
  • Loading branch information
vtjnash authored Mar 4, 2021
1 parent 2e2e6ef commit a385dcc
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 47 deletions.
49 changes: 30 additions & 19 deletions src/groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
# BenchmarkGroup #
##################

const KeyTypes = Union{String,Int,Float64}
makekey(v::KeyTypes) = v
makekey(v::Real) = (v2 = Float64(v); v2 == v ? v2 : string(v))
makekey(v::Integer) = typemin(Int) <= v <= typemax(Int) ? Int(v) : string(v)
makekey(v::Tuple) = (Any[i isa Tuple ? string(i) : makekey(i) for i in v]...,)::Tuple{Vararg{KeyTypes}}
makekey(v::Any) = string(v)::String

struct BenchmarkGroup
tags::Vector{Any}
data::Dict{Any,Any}
end

BenchmarkGroup(tags::Vector, args::Pair...) = BenchmarkGroup(tags, Dict(args...))
BenchmarkGroup(tags::Vector, args::Pair...) = BenchmarkGroup(tags, Dict{Any,Any}((makekey(k) => v for (k, v) in args)))
BenchmarkGroup(args::Pair...) = BenchmarkGroup([], args...)

function addgroup!(suite::BenchmarkGroup, id, args...)
Expand All @@ -24,10 +31,14 @@ Base.copy(group::BenchmarkGroup) = BenchmarkGroup(copy(group.tags), copy(group.d
Base.similar(group::BenchmarkGroup) = BenchmarkGroup(copy(group.tags), empty(group.data))
Base.isempty(group::BenchmarkGroup) = isempty(group.data)
Base.length(group::BenchmarkGroup) = length(group.data)
Base.getindex(group::BenchmarkGroup, i...) = getindex(group.data, i...)
Base.setindex!(group::BenchmarkGroup, i...) = setindex!(group.data, i...)
Base.delete!(group::BenchmarkGroup, k...) = delete!(group.data, k...)
Base.haskey(group::BenchmarkGroup, k) = haskey(group.data, k)
Base.getindex(group::BenchmarkGroup, k) = getindex(group.data, makekey(k))
Base.getindex(group::BenchmarkGroup, k...) = getindex(group.data, makekey(k))
Base.setindex!(group::BenchmarkGroup, v, k) = setindex!(group.data, v, makekey(k))
Base.setindex!(group::BenchmarkGroup, v, k...) = setindex!(group.data, v, makekey(k))
Base.delete!(group::BenchmarkGroup, k) = delete!(group.data, makekey(k))
Base.delete!(group::BenchmarkGroup, k...) = delete!(group.data, makekey(k))
Base.haskey(group::BenchmarkGroup, k) = haskey(group.data, makekey(k))
Base.haskey(group::BenchmarkGroup, k...) = haskey(group.data, makekey(k))
Base.keys(group::BenchmarkGroup) = keys(group.data)
Base.values(group::BenchmarkGroup) = values(group.data)
Base.iterate(group::BenchmarkGroup, i=1) = iterate(group.data, i)
Expand Down Expand Up @@ -119,11 +130,11 @@ end
# leaf iteration/indexing #
#-------------------------#

leaves(group::BenchmarkGroup) = leaves!(Any[], Any[], group)
leaves(group::BenchmarkGroup) = leaves!([], [], group)

function leaves!(results, parents, group::BenchmarkGroup)
for (k, v) in group
keys = vcat(parents, k)
keys = Base.typed_vcat(Any, parents, k)
if isa(v, BenchmarkGroup)
leaves!(results, keys, v)
else
Expand Down Expand Up @@ -155,44 +166,44 @@ end
# tagging #
#---------#

struct TagFilter{P}
predicate::P
struct TagFilter
predicate
end

macro tagged(expr)
return esc(:(BenchmarkTools.TagFilter(tags -> $(tagpredicate!(expr)))))
return :(BenchmarkTools.TagFilter(tags -> $(tagpredicate!(expr))))
end

tagpredicate!(tag) = :(in($tag, tags))
tagpredicate!(@nospecialize tag) = :(in(makekey($(esc(tag))), tags))

function tagpredicate!(sym::Symbol)
sym == :! && return sym
sym == :ALL && return true
return :(in($sym, tags))
return :(in(makekey($(esc(sym))), tags))
end

# build the body of the tag predicate in place
function tagpredicate!(expr::Expr)
expr.head == :quote && return :(in($expr, tags))
for i in eachindex(expr.args)
expr.args[i] = tagpredicate!(expr.args[i])
expr.head == :quote && return :(in(makekey($(esc(expr))), tags))
for i in 1:length(expr.args)
f = (i == 1 && expr.head === :call ? esc : tagpredicate!)
expr.args[i] = f(expr.args[i])
end
return expr
end

function Base.getindex(src::BenchmarkGroup, f::TagFilter)
dest = similar(src)
loadtagged!(f, dest, src, src, Any[], src.tags)
loadtagged!(f, dest, src, src, [], src.tags)
return dest
end

# normal union doesn't have the behavior we want
# (e.g. union(["1"], "2") === ["1", '2'])
keyunion(args...) = unique(vcat(args...))
keyunion(args...) = unique(Base.typed_vcat(Any, args...))

function tagunion(args...)
unflattened = keyunion(args...)
result = Any[]
result = []
for i in unflattened
if isa(i, Tuple)
for j in i
Expand Down
62 changes: 42 additions & 20 deletions src/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,66 @@ const VERSIONS = Dict("Julia" => string(VERSION),
"BenchmarkTools" => string(BENCHMARKTOOLS_VERSION))

# TODO: Add any new types as they're added
const SUPPORTED_TYPES = [Benchmark, BenchmarkGroup, Parameters, TagFilter, Trial,
TrialEstimate, TrialJudgement, TrialRatio]
const SUPPORTED_TYPES = Dict{Symbol,Type}(Base.typename(x).name => x for x in [
BenchmarkGroup, Parameters, TagFilter, Trial,
TrialEstimate, TrialJudgement, TrialRatio])
# n.b. Benchmark type not included here, since it is gensym'd

for T in SUPPORTED_TYPES
@eval function JSON.lower(x::$T)
d = Dict{String,Any}()
for i = 1:nfields(x)
name = String(fieldname($T, i))
field = getfield(x, i)
value = typeof(field) in SUPPORTED_TYPES ? JSON.lower(field) : field
push!(d, name => value)
end
[string(typeof(x)), d]
function JSON.lower(x::Union{values(SUPPORTED_TYPES)...})
d = Dict{String,Any}()
T = typeof(x)
for i = 1:nfields(x)
name = String(fieldname(T, i))
field = getfield(x, i)
ft = typeof(field)
value = ft <: get(SUPPORTED_TYPES, ft.name.name, Union{}) ? JSON.lower(field) : field
d[name] = value
end
[string(typeof(x).name.name), d]
end

# a minimal 'eval' function, mirroring KeyTypes, but being slightly more lenient
safeeval(@nospecialize x) = x
safeeval(x::QuoteNode) = x.value
function safeeval(x::Expr)
x.head === :quote && return x.args[1]
x.head === :inert && return x.args[1]
x.head === :tuple && return ((safeeval(a) for a in x.args)...,)
x
end
function recover(x::Vector)
length(x) == 2 || throw(ArgumentError("Expecting a vector of length 2"))
typename = x[1]::String
fields = x[2]::Dict
T = Core.eval(@__MODULE__, Meta.parse(typename))::Type
startswith(typename, "BenchmarkTools.") && (typename = typename[sizeof("BenchmarkTools.")+1:end])
T = SUPPORTED_TYPES[Symbol(typename)]
fc = fieldcount(T)
xs = Vector{Any}(undef, fc)
for i = 1:fc
ft = fieldtype(T, i)
fn = String(fieldname(T, i))
xs[i] = if ft in SUPPORTED_TYPES
recover(fields[fn])
if ft <: get(SUPPORTED_TYPES, ft.name.name, Union{})
xsi = recover(fields[fn])
else
convert(ft, fields[fn])
xsi = convert(ft, fields[fn])
end
if T == BenchmarkGroup && xs[i] isa Dict
for (k, v) in xs[i]
if T == BenchmarkGroup && xsi isa Dict
for (k, v) in copy(xsi)
k = k::String
if startswith(k, "(") || startswith(k, ":")
kt = Meta.parse(k, raise=false)
if !(kt isa Expr && kt.head === :error)
delete!(xsi, k)
k = safeeval(kt)
xsi[k] = v
end
end
if v isa Vector && length(v) == 2 && v[1] isa String
xs[i][k] = recover(v)
xsi[k] = recover(v)
end
end
end
xs[i] = xsi
end
T(xs...)
end
Expand Down Expand Up @@ -73,7 +95,7 @@ function save(io::IO, args...)
"The name will be ignored and the object will be serialized " *
"in the order it appears in the input.")
continue
elseif !any(T->arg isa T, SUPPORTED_TYPES)
elseif !(arg isa get(SUPPORTED_TYPES, typeof(arg).name.name, Union{}))
throw(ArgumentError("Only BenchmarkTools types can be serialized."))
end
push!(goodargs, arg)
Expand Down
2 changes: 1 addition & 1 deletion test/GroupsTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ gnest = BenchmarkGroup(["1"],
10 => BenchmarkGroup(["3"]),
11 => BenchmarkGroup()))

@test sort(leaves(gnest), by=string) ==
@test sort(leaves(gnest), by=string) ==
Any[(Any["2",1],1), (Any["a","a"],:a), (Any["a",(11,"b")],:b), (Any[4,5],6), (Any[7],8)]

@test gnest[@tagged 11 || 10] == BenchmarkGroup(["1"],
Expand Down
14 changes: 7 additions & 7 deletions test/SerializationTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module SerializationTests
using BenchmarkTools
using Test

eq(x::T, y::T) where {T<:Union{BenchmarkTools.SUPPORTED_TYPES...}} =
eq(x::T, y::T) where {T<:Union{values(BenchmarkTools.SUPPORTED_TYPES)...}} =
all(i->eq(getfield(x, i), getfield(y, i)), 1:fieldcount(T))
eq(x::T, y::T) where {T} = isapprox(x, y)

Expand All @@ -25,13 +25,13 @@ end
withtempdir() do
tmp = joinpath(pwd(), "tmp.json")

BenchmarkTools.save(tmp, b, bb)
BenchmarkTools.save(tmp, b.params, bb)
@test isfile(tmp)

results = BenchmarkTools.load(tmp)
@test results isa Vector{Any}
@test length(results) == 2
@test eq(results[1], b)
@test eq(results[1], b.params)
@test eq(results[2], bb)
end

Expand All @@ -56,18 +56,18 @@ end
tune!(b)
bb = run(b)

@test_throws ArgumentError BenchmarkTools.save("x.jld", b)
@test_throws ArgumentError BenchmarkTools.save("x.txt", b)
@test_throws ArgumentError BenchmarkTools.save("x.jld", b.params)
@test_throws ArgumentError BenchmarkTools.save("x.txt", b.params)
@test_throws ArgumentError BenchmarkTools.save("x.json")
@test_throws ArgumentError BenchmarkTools.save("x.json", 1)

withtempdir() do
tmp = joinpath(pwd(), "tmp.json")
@test_logs (:warn, r"Naming variables") BenchmarkTools.save(tmp, "b", b)
@test_logs (:warn, r"Naming variables") BenchmarkTools.save(tmp, "b", b.params)
@test isfile(tmp)
results = BenchmarkTools.load(tmp)
@test length(results) == 1
@test eq(results[1], b)
@test eq(results[1], b.params)
end

@test_throws ArgumentError BenchmarkTools.load("x.jld")
Expand Down

0 comments on commit a385dcc

Please sign in to comment.