Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompatHelper: add new compat entry for BenchmarkTools at version 1, (keep existing compat) #2

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@ uuid = "21e319bb-9f6a-4cc8-b556-663507cd964d"
authors = ["Ian Limarta"]
version = "1.0.0-DEV"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Gen = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
BenchmarkTools = "1"
julia = "1"

[extras]
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# GenSerialization

[![Build Status](https://github.com/limarta/GenSerialization.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/limarta/GenSerialization.jl/actions/workflows/CI.yml?query=branch%3Amain)

WIP. Influenced by JLD2.jl. Some snippets derived from its repository. Refer to https://docs.hdfgroup.org/hdf5/develop/_f_m_t3.html
33 changes: 33 additions & 0 deletions example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Gen
using GenSerialization
using BenchmarkTools

@gen function slow(n::Int)
sleep(n)
subchoice ~ normal(0,0.5)
return 1
end

@gen function model(n::Int)
x ~ bernoulli(0.5)
y ~ slow(n)
{:k=>1} ~ bernoulli(0.5)
return (x,y)
end

tr = simulate(model, (1,))
serialize("test.gen", tr)
@time realized_tr = realize("test.gen", model)
@time deserialized_tr = deserialize("test.gen")

@gen function model(n::Int)
for i=1:n
{:k=>i} ~ mvnormal([0.0,0.0], [1.0 0.0; 0.0 1.0])
end
n
end
@btime simulate($model,(1000,))
tr = simulate(model, (1000,))
@btime coarse_serialize("test.gen", $tr)

@btime coarse_deserialize("test.gen")
17 changes: 16 additions & 1 deletion src/GenSerialization.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
module GenSerialization
using Gen

# Write your package code here.
function io_size(io::IO)
cuptr = position(io)
seekend(io)
size = position(io)
seek(io, cuptr)
return size
end

include("gen_file.jl")
include("write_session.jl")
include("dsl/dsl.jl")
include("lazy/lazy.jl")
include("file_header.jl")
include("serialization.jl")
include("deserialization.jl")

end
3 changes: 3 additions & 0 deletions src/custom_serialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
writeas(T::Type) = T

# Refer to hdf5 serialization
30 changes: 30 additions & 0 deletions src/datatypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
const Plain = Union{Int16,Int32,Int64,Int128,UInt16,UInt32,UInt64,UInt128,Float16,Float32,
Float64}
const PlainType = Union{Type{Int16},Type{Int32},Type{Int64},Type{Int128},Type{UInt16},
Type{UInt32},Type{UInt64},Type{UInt128},Type{Float16},
Type{Float32},Type{Float64}}

# WriteDataspace() = WriteDataspace(DS_NULL, (), ())
# WriteDataspace(::JLDFile, ::Any, odr::Nothing) = WriteDataspace()
# WriteDataspace(::JLDFile, ::Any, ::Any) = WriteDataspace(DS_SCALAR, (), ())

# # Ghost type array
# WriteDataspace(f::JLDFile, x::Array{T}, ::Nothing) where {T} =
# WriteDataspace(DS_NULL, (),
# (WrittenAttribute(f, :dimensions, collect(Int64, reverse(size(x)))),))

# # Reference array
# WriteDataspace(f::JLDFile, x::Array{T,N}, ::Type{RelOffset}) where {T,N} =
# WriteDataspace(DS_SIMPLE, convert(Tuple{Vararg{Length}}, reverse(size(x))),
# (WrittenAttribute(f, :julia_type, write_ref(f, T, f.datatype_wsession)),))

# # isbitstype array
# WriteDataspace(f::JLDFile, x::Array, ::Any) =
# WriteDataspace(DS_SIMPLE, convert(Tuple{Vararg{Length}}, reverse(size(x))), ())

# # Zero-dimensional arrays need an empty dimensions attribute
# WriteDataspace(f::JLDFile, x::Array{T,0}, ::Nothing) where {T} =
# WriteDataspace(DS_NULL, (Length(1),),
# (WrittenAttribute(f, :dimensions, EMPTY_DIMENSIONS)))
# WriteDataspace(f::JLDFile, x::Array{T,0}, ::Type{RelOffset}) where {T} =
# WriteDataspace(DS_SIMPLE, (Length(1),),
32 changes: 32 additions & 0 deletions src/deserialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
function deserialize_trace(io::IO)
cuptr = position(io)
trace_type = Serialization.deserialize(io)
seek(io, cuptr)
deserialize_trace(io, trace_type)
end

function deserialize(fname::AbstractString)
genopen(fname, "r") do f
header_length = verify_file_header(f)
f.header_length = header_length
tr = deserialize_trace(f.io)
end
end

function realize(fname::AbstractString, gen_fn::GenerativeFunction)
genopen(fname, "r") do f
header_length = verify_file_header(f)
f.header_length = header_length
tr = realize_trace(f.io, gen_fn)
end
end

function coarse_deserialize(fname::AbstractString)
genopen(fname, "r") do f
header_length = verify_file_header(f)
f.header_length = header_length
tr = Serialization.deserialize(f.io)
end
end

export deserialize, realize, coarse_deserialize
Empty file.
36 changes: 36 additions & 0 deletions src/dsl/combinators/map/realization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
mutable struct MapDeserializeState{T,U}
score::Float64
noise::Float64
subtraces::Vector{U}
retval::Vector{T}
num_nonempty::Int
end

function realize_trace(io::IO, gen_fn::Map{T,U}) where {T,U}
trace_type = Serialization.deserialize(io)
!(trace_type <: Gen.VectorTrace) && error("Expected VectorTrace, got $trace_type")
retval = Serialization.deserialize(io)
args = Serialization.deserialize(io)
len = read(io, Int)
num_nonempty = read(io, Int)
score = read(io, Float64)
noise = read(io, Float64)
len = length(args[1])
base_ptr = position(io)
state = MapDeserializeState{T,U}(0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0)
for key=1:len
seek(io, base_ptr + (key-1)*sizeof(Int))
tr_ptr = read(io, Int)
seek(io, tr_ptr)
subtrace = realize_trace(io, gen_fn.kernel)
state.subtraces[key] = subtrace
retval = get_retval(subtrace)
state.retval[key] = retval
end
state.noise = noise
state.num_nonempty = num_nonempty
state.score = score
Gen.VectorTrace{Gen.MapType,T,U}(gen_fn,
Gen.PersistentVector{U}(state.subtraces), Gen.PersistentVector{T}(state.retval),
args, state.score, state.noise, len, state.num_nonempty)
end
Empty file.
Empty file.
Empty file.
19 changes: 19 additions & 0 deletions src/dsl/combinators/switch/realization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# mutable struct SwitchDeserializeState{T}
# score::Float64
# noise::Float64
# index::Int
# subtrace::Trace
# retval::T
# SwitchDeserializeState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise)
# end

# function realize_trace(io::IO, gen_fn::Switch{C, N, K, T}) where {C, N, K, T}
# trace_type = Serialization.deserialize(io)
# !(trace_type <: Gen.VectorTrace) && error("Expected VectorTrace, got $trace_type")
# retval = Serialization.deserialize(io)
# args = Serialization.deserialize(io)
# len = read(io, Int)
# num_nonempty = read(io, Int)
# score = read(io, Float64)
# noise = read(io, Float64)
# end
40 changes: 40 additions & 0 deletions src/dsl/combinators/unfold/realization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
mutable struct UnfoldDeserializeState{T,U}
score::Float64
noise::Float64
subtraces::Vector{U}
retval::Vector{T}
num_nonempty::Int
state::T
end

function realize_trace(io::IO, gen_fn::Unfold{T,U}) where {T,U}
trace_type = Serialization.deserialize(io)
!(trace_type <: Gen.VectorTrace) && error("Expected VectorTrace, got $trace_type")
retval = Serialization.deserialize(io)
args = Serialization.deserialize(io)
len = read(io, Int)
num_nonempty = read(io, Int)
score = read(io, Float64)
noise = read(io, Float64)

init_state = args[2]
params = args[3:end]
state = UnfoldDeserializeState{T,U}(0., 0.,
Vector{U}(undef,len), Vector{T}(undef,len), 0, init_state)
base_ptr = position(io)
for key=1:len
seek(io, base_ptr + (key-1)*sizeof(Int))
tr_ptr = read(io, Int)
seek(io, tr_ptr)
subtrace = realize_trace(io, gen_fn.kernel)
state.subtraces[key] = subtrace
retval = get_retval(subtrace)
state.retval[key] = retval
end
state.noise = noise
state.num_nonempty = num_nonempty
state.score = score
Gen.VectorTrace{Gen.MapType,T,U}(gen_fn,
Gen.PersistentVector{U}(state.subtraces), Gen.PersistentVector{T}(state.retval),
args, state.score, state.noise, len, state.num_nonempty)
end
43 changes: 43 additions & 0 deletions src/dsl/combinators/vector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
function serialize_vector(io::IO, vector, ws::WriteSession)
# Pointers to subtraces
base_ptr = position(io)
for i=1:length(vector)
genwrite(io, 0, ws)
end
# println("ws before serializing $(length(ws))")
for (i, tr) in enumerate(vector)
cuptr = position(io)
blob_len = serialize_trace(io, tr)

# Manually update ws with blob_len
ws.max = max(ws.max, cuptr+blob_len)
seek(io, base_ptr + (i-1)*sizeof(Int))
genwrite(io, cuptr, ws)
seek(io, cuptr+blob_len)
end
end

function serialize_trace(io::IO, tr::Gen.VectorTrace{C, U, V}) where {C,U,V}
cuptr = position(io)
ws = WriteSession(cuptr)
genwrite(io, typeof(tr), ws, Val{:serialized}())
genwrite(io, tr.retval, ws, Val{:serialized}())
genwrite(io, tr.args, ws, Val{:serialized}())
genwrite(io, tr.len, ws)
genwrite(io, tr.num_nonempty, ws)
genwrite(io, tr.score, ws)
genwrite(io, tr.noise, ws)
serialize_vector(io, tr.subtraces, ws)
seek(io, cuptr)

return length(ws)
end

include("map/realization.jl")
include("map/deserialization.jl")
include("unfold/realization.jl")
# include("unfold/deserialization.jl")
# include("switch/realization.jl")
# include("switch/deserialization.jl")
# include("recurse/realization.jl")
# include("recurse/deserialization.jl")
3 changes: 3 additions & 0 deletions src/dsl/dsl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

include("dynamic/dynamic.jl")
include("combinators/vector.jl")
Loading