Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #33466, serialization of IdDict #33473

Merged
merged 1 commit into from
Oct 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Serializer(io::IO) = Serializer{typeof(io)}(io)

const n_int_literals = 33
const n_reserved_slots = 24
const n_reserved_tags = 12
const n_reserved_tags = 11

const TAGS = Any[
Symbol, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128,
Expand All @@ -56,6 +56,7 @@ const TAGS = Any[
Symbol, # REF_OBJECT_TAG
Symbol, # FULL_GLOBALREF_TAG
Symbol, # HEADER_TAG
Symbol, # IDDICT_TAG
fill(Symbol, n_reserved_tags)...,

(), Bool, Any, Bottom, Core.TypeofBottom, Type, svec(), Tuple{}, false, true, nothing,
Expand All @@ -75,7 +76,7 @@ const TAGS = Any[

@assert length(TAGS) == 255

const ser_version = 8 # do not make changes without bumping the version #!
const ser_version = 9 # do not make changes without bumping the version #!

const NTAGS = length(TAGS)

Expand Down Expand Up @@ -133,6 +134,7 @@ const OBJECT_TAG = Int32(o0+12)
const REF_OBJECT_TAG = Int32(o0+13)
const FULL_GLOBALREF_TAG = Int32(o0+14)
const HEADER_TAG = Int32(o0+15)
const IDDICT_TAG = Int32(o0+16)

writetag(s::IO, tag) = (write(s, UInt8(tag)); nothing)

Expand Down Expand Up @@ -327,15 +329,26 @@ function serialize(s::AbstractSerializer, ex::Expr)
end
end

function serialize(s::AbstractSerializer, d::Dict)
serialize_cycle_header(s, d) && return
function serialize_dict_data(s::AbstractSerializer, d::AbstractDict)
write(s.io, Int32(length(d)))
for (k,v) in d
serialize(s, k)
serialize(s, v)
end
end

function serialize(s::AbstractSerializer, d::Dict)
serialize_cycle_header(s, d) && return
serialize_dict_data(s, d)
end

function serialize(s::AbstractSerializer, d::IdDict)
serialize_cycle(s, d) && return
writetag(s.io, IDDICT_TAG)
serialize_type_data(s, typeof(d))
serialize_dict_data(s, d)
end

function serialize_mod_names(s::AbstractSerializer, m::Module)
p = parentmodule(m)
if p === m || m === Base
Expand Down Expand Up @@ -851,6 +864,11 @@ function handle_deserialize(s::AbstractSerializer, b::Int32)
return read(s.io, Float64)
elseif b == INT8_TAG+13
return read(s.io, Char)
elseif b == IDDICT_TAG
slot = s.counter; s.counter += 1
push!(s.pending_refs, slot)
t = deserialize(s)
return deserialize_dict(s, t)
end
t = desertag(b)::DataType
if t.mutable && length(t.types) > 0 # manual specialization of fieldcount
Expand Down Expand Up @@ -1303,7 +1321,7 @@ function deserialize(s::AbstractSerializer, t::DataType)
end
end

function deserialize(s::AbstractSerializer, T::Type{Dict{K,V}}) where {K,V}
function deserialize_dict(s::AbstractSerializer, T::Type{<:AbstractDict})
n = read(s.io, Int32)
t = T(); sizehint!(t, n)
deserialize_cycle(s, t)
Expand All @@ -1315,6 +1333,10 @@ function deserialize(s::AbstractSerializer, T::Type{Dict{K,V}}) where {K,V}
return t
end

function deserialize(s::AbstractSerializer, T::Type{Dict{K,V}}) where {K,V}
return deserialize_dict(s, T)
end

deserialize(s::AbstractSerializer, ::Type{BigInt}) = parse(BigInt, deserialize(s), base = 62)

function deserialize(s::AbstractSerializer, t::Type{Regex})
Expand Down
9 changes: 9 additions & 0 deletions stdlib/Serialization/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,12 @@ let f_data
f = deserialize(IOBuffer(base64decode(f_data)))
@test f(10,3) == 23
end

# issue #33466, IdDict
let d = IdDict([1] => 2, [3] => 4), io = IOBuffer()
serialize(io, d)
seekstart(io)
ds = deserialize(io)
@test Dict(d) == Dict(ds)
@test all([k in keys(ds) for k in keys(ds)])
end