Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New tracer based on JuliaInterpreter #40

Merged
merged 6 commits into from
Aug 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ git-tree-sha1 = "1d00c35118babf85c4a9a72c3d3550d498b64a03"
repo-rev = "45a4f68c304c5ac39543338f7281812a689adf35"
repo-url = "https://github.com/jrevels/Cassette.jl.git"
uuid = "7057c7e9-c182-5462-911a-8362d720325c"
version = "0.2.2+"
version = "0.2.5+"

[[CodeTracking]]
deps = ["InteractiveUtils", "Test", "UUIDs"]
git-tree-sha1 = "9b21a2dfe51ba71fdc5688039075819196595367"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "0.5.7"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Espresso]]
Expand All @@ -24,9 +30,15 @@ uuid = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
version = "0.6.0"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "ed46097f465a091f6b126966015048193791743a"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.6.0"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -51,6 +63,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

Expand All @@ -68,3 +83,7 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ print(tape)
# %3 = broadcast(%2, %1)::Array{Float64,1}
# %4 = sum(%3)::Float64
```
`trace` uses [Cassette.jl](https://github.com/jrevels/Cassette.jl/) to collect function calls during execution. Functions are divided into 2 groups:
`trace` uses [JuliaInterpreter.jl](https://github.com/JuliaDebug/JuliaInterpreter.jl) to collect function calls during execution. Functions are divided into 2 groups:

* primitive, which are recorded to the tape;
* non-primitive, which are traced-through down to primitive ones.
Expand All @@ -112,6 +112,8 @@ compile!(tape)
# 492.063 ns (2 allocations: 144 bytes)
```

Note that `trace()` is an alias to `itrace()` - JuliaInterpreter-based tracer. Older versions of Yota used another implementation with identical interface and capabilities, but based on [Cassette.jl](https://github.com/jrevels/Cassette.jl). This implementation is still available by name `ctrace()`.

## CuArrays support (experimental)

Yota should work with CuArrays out of the box, although integration is not well tested yet.
Expand Down
7 changes: 1 addition & 6 deletions src/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ end

function to_exnode(op::Call)
arg_names = map(make_name, op.args)
# fns = maybe_to_symbol(op.fn)
if isempty(op.kwargs)
ex = Expr(:call, op.fn, arg_names...)
else
ex = Expr(:call, op.fn, Espresso.make_kw_params(op.kwargs), arg_names...)
end
ex = Expr(:call, op.fn, arg_names...)
return ExNode{:call}(make_name(op), ex; val=op.val)
end

Expand Down
2 changes: 1 addition & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include("helpers.jl")
include("devices.jl")
include("tape.jl")
include("tapeutils.jl")
include("trace.jl")
include("trace/trace.jl")
include("diffrules.jl")
include("grad.jl")
include("compile.jl")
Expand Down
12 changes: 5 additions & 7 deletions src/diffrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ end
@diffrule tuple(x,y,z,t) z ds[3]
@diffrule tuple(x,y,z,t) t ds[4]

# __tuple__ (current tracer implemetation uses it instead of normal tuple)
# __tuple__ (cassette tracer implemetation uses it instead of normal tuple)
@diffrule __tuple__(x) x ds[1]
@diffrule __tuple__(x,y) x ds[1]
@diffrule __tuple__(x,y) y ds[2]
Expand Down Expand Up @@ -313,28 +313,26 @@ end

# binary substraction
@diffrule -(x::Real, y::Real) x ds
# @diffrule -(x::Real, y::AbstractArray) x sum(ds)
# @diffrule -(x::AbstractArray, y::Real) x ones(size(x)) .* ds
@diffrule -(x::AbstractArray, y::AbstractArray) x ds
@diffrule -(x::Real , y::Real) y -ds
# @diffrule -(x::Real, y::AbstractArray) y -ones(size(y)) .* ds
# @diffrule -(x::AbstractArray, y::Real) y -sum(ds)
@diffrule -(x::AbstractArray, y::AbstractArray) y -ds


# sum() and mean()
# @diffrule sum(x::Real) x ds
@diffrule sum(x::AbstractArray) x sum_grad(x, ds)
@diffrule Base._sum(x::AbstractArray, y::Int) x sum_grad(x, ds)
@diffrule Base._sum(x::AbstractArray, y::Int) y zero(eltype(x))
@diffrule Core.kwfunc(sum)(_dims, _, x::AbstractArray) x sum_grad(x, ds)

# special sums
@diffrule sum(_fn::typeof(log), x::AbstractArray) x sum_grad(x, ds) ./ x

# @diffrule Statistics.mean(x::Real) x ds
@diffrule Statistics.mean(x::AbstractArray) x mean_grad(x, ds)
@diffrule Statistics._mean(x::AbstractArray, y::Int) x mean_grad(x, ds)
@diffrule Statistics._mean(x::AbstractArray, y::Int) y zero(eltype(x))
@diffrule Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) x mean_grad(x, ds)
@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _dims
@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _

# diag
@diffrule diag(x::AbstractMatrix) x Diagonal(ds)
Expand Down
3 changes: 1 addition & 2 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ setderiv!(tape::Tape, op::AbstractOp, grad_op::AbstractOp) = (tape.derivs[op.id]


Espresso.to_expr(tape::Tape, op::Call) = begin
@assert isempty(op.kwargs) "Oops, functions with kwargs aren't supported just yet"
Expr(:call, op.fn, [Symbol("%$i") for i in op.args]...)
end

Expand Down Expand Up @@ -181,7 +180,7 @@ function back!(tape::Tape)
end
end
elseif op.fn == __getfield__
# unstructuring if tuples is lowere into pretty weird code sequence
# unstructuring of tuples is lowered into pretty weird code sequence
# ending with __getfield__; similar to getproperty(), we find source var
# for the corresponding tuple argument and backprop to it
if haskey(tape.derivs, op.id)
Expand Down
2 changes: 2 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,5 @@ unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y([x], y,
untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T = transpose(ds)
untranspose_vec(ds::Adjoint{T, <:AbstractVector{T}}) where T = adjoint(ds)
untranspose_vec(ds::AbstractMatrix) = dropdims(transpose(ds); dims=2)

namedtuple(names, values) = NamedTuple{names}(values)
14 changes: 5 additions & 9 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,16 @@ mutable struct Call <: AbstractOp
val::Any
fn::Union{Function, Type}
args::Vector{Int}
kwargs::Dict # currently not used
end

Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}; kwargs=Dict()) =
Call(id, val, fn, args, kwargs)
# Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}) =
# Call(id, val, fn, args)
Base.getproperty(op::Input, f::Call) = f == :typ ? typeof(op.val) : getfield(op, f)


function Base.show(io::IO, op::Call)
arg_str = join(["%$(id)" for id in op.args], ", ")
kwarg_str = (isempty(op.kwargs) ? "" : "; " *
join(["$k=$v" for (k, v) in op.kwargs], ", "))
print(io, "%$(op.id) = $(op.fn)($arg_str$kwarg_str)::$(op.typ)")
print(io, "%$(op.id) = $(op.fn)($arg_str)::$(op.typ)")
end


Expand Down Expand Up @@ -159,7 +156,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false)
arg_id = record!(tape, Constant, x)
new_op_args[i] = arg_id
end
end
end
fn = ex.args[1]
fn = device_function(tape.device, fn)
if bcast
Expand Down Expand Up @@ -203,8 +200,7 @@ function replace_in_args!(tape::Tape, st::Dict)
for (i, op) in enumerate(tape)
if op isa Call
new_args = [get(st, x, x) for x in op.args]
new_kwargs = Dict(k => get(st, x, x) for (k, x) in op.kwargs)
tape[i] = copy_with(op, args=new_args, kwargs=new_kwargs)
tape[i] = copy_with(op, args=new_args)
end
end
end
Expand Down
27 changes: 1 addition & 26 deletions src/trace.jl → src/trace/cassette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,6 @@ Cassette.hastagging(::Type{<:TraceCtx}) = true
# CUSTOM PASS #
########################################################################

function __new__(T, args...)
# @show T
# @show args
# note: we also add __new__() to the list of primitives so it's not overdubbed recursively
if T <: NamedTuple
return T(args)
else
return T(args...)
end
end


__tuple__(args...) = tuple(args...)
__getfield__(args...) = getfield(args...)




is_gref_call(a, fn_name) = a isa GlobalRef && a.name == fn_name

Expand All @@ -55,14 +38,6 @@ end
# TRACE #
########################################################################

const PRIMITIVES = Set([
*, /, +, -, sin, cos, sum, Base._sum,
println,
Base.getproperty, Base.getfield, Base.indexed_iterate, # Core.kwfunc,
broadcast, Broadcast.materialize, Broadcast.broadcasted,
__new__, __tuple__, __getfield__])


struct TapeBox
tape::Tape
primitives::Set{Any}
Expand All @@ -78,7 +53,7 @@ foo(x) = 2.0x + 1.0
val, tape = trace(foo, 4.0)
```
"""
function trace(f, args...; primitives=PRIMITIVES, optimize=true)
function ctrace(f, args...; primitives=PRIMITIVES, optimize=true)
# create tape
tape = Tape(guess_device(args))
box = TapeBox(tape, primitives)
Expand Down
132 changes: 132 additions & 0 deletions src/trace/interp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import JuliaInterpreter
import JuliaInterpreter: enter_call, step_expr!, @lookup, Frame, SSAValue, SlotNumber


getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc]
current_expr(fr::Frame) = getexpr(fr, fr.pc)


"""
Split JuliaInterpreter call expression into a tuple of 3 elements:

* function to be called
* args to this function
* vars on the tape corresponding to these args

If arguments include free parameters (not SlotNumber or SSAValue), these are recorded
to the tape as constants
"""
function split_int_call!(tape::Tape, fr::Frame, frame_vars::Dict, ex)
arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args
# for whatever reason JuliaInterpreter wraps some nodes in the original code into QuoteNode
arr = [isa(x, QuoteNode) ? x.value : x for x in arr]
cf = @lookup(fr, arr[1])
cargs = [x isa Symbol ? x : @lookup(fr, x) for x in arr[2:end]]
cvars = Vector{Int}(undef, length(cargs))
for (i, x) in enumerate(arr[2:end])
# if isa(x, JuliaInterpreter.SlotNumber) || isa(x, JuliaInterpreter.SSAValue)
if haskey(frame_vars, x)
cvars[i] = frame_vars[x]
else
val = x isa Symbol ? x : @lookup(fr, x)
id = record!(tape, Constant, val)
cvars[i] = id
if val != x
# if constant appeared to be a SlotNumber or SSAValue
# store its mapping into frame_vars
frame_vars[x] = id
end
end
end
return cf, cargs, cvars
end


"""
Given a Frame and current expression, extract LHS location (SlotNumber or SSAValue)
"""
get_location(fr::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc)

is_int_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call))
is_int_assign_expr(ex) = Meta.isexpr(ex, :(=)) && (isa(ex.args[2], SlotNumber) || isa(ex.args[2], SSAValue))

is_interesting_expr(ex) = is_int_call_expr(ex) || is_int_assign_expr(ex) || Meta.isexpr(ex, :return)


function itrace!(f, tape::Tape, argvars...; primitives)
args, vars = zip(argvars...)
fr = enter_call(f, args...)
frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars))
is_interesting_expr(current_expr(fr)) || step_expr!(fr) # skip non-call expressions
ex = current_expr(fr)
while !Meta.isexpr(ex, :return)
if is_int_assign_expr(ex)
lhs, rhs = ex.args
frame_vars[lhs] = frame_vars[rhs]
step_expr!(fr)
elseif is_int_call_expr(ex)
# read as "current function", "current arguments", "current variables"
cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex)
loc = get_location(fr, ex)
# there are several special cases such as NamedTuples and constructors
# we replace these with calls to special helper functions
if cf isa UnionAll && cf <: NamedTuple
# replace cf with namedtuple function, adjust arguments
names = collect(cf.body.parameters)[1]
cf = namedtuple
cargs = [names; cargs]
names_var_id = record!(tape, Constant, names)
cvars = [names_var_id; cvars]
elseif cf isa DataType
# constructor, replace with a call to __new__ which we know how to differentiate
T = cf
cf = __new__
cargs = [T; cargs]
T_var_id = record!(tape, Constant, T)
cvars = [T_var_id; cvars]
elseif cf == Base.tuple
cf = __tuple__
elseif cf == Base.getfield
# similar to constuctors, there's a special case for __getfield__ in backprop
cf = __getfield__
end
# if current function is a primitive of a built-in, write it to the tape
# otherwise recurse into the current function
if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction)
step_expr!(fr)
retval = @lookup(fr, loc)
ret_id = record!(tape, Call, retval, cf, cvars)
frame_vars[loc] = ret_id # for slots it may overwrite old mapping
else
retval, ret_id = itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives)
frame_vars[loc] = ret_id # for slots it may overwrite old mapping
step_expr!(fr) # can we avoid this double execution?
end
else
step_expr!(fr)
end
ex = current_expr(fr)
end
retval = @lookup(fr, ex.args[1])
ret_id = frame_vars[ex.args[1]]
return retval, ret_id # return var ID of a result variable
end


"""
Trace function f with arguments args using JuliaInterpreter
"""
function itrace(f, args...; primitives=PRIMITIVES, optimize=true)
tape = Tape(guess_device(args))
argvars = Vector(undef, length(args))
for (i, arg) in enumerate(args)
id = record!(tape, Input, arg)
argvars[i] = (arg, id)
end
val, resultid = itrace!(f, tape, argvars...; primitives=primitives)
tape.resultid = resultid
if optimize
tape = simplify(tape)
end
return val, tape
end
Loading