Skip to content

Commit

Permalink
Merge pull request #40 from dfdx/int-tracer
Browse files Browse the repository at this point in the history
New tracer based on JuliaInterpreter
  • Loading branch information
dfdx authored Aug 13, 2019
2 parents 0a9bedc + 76a87ce commit fcfae2b
Show file tree
Hide file tree
Showing 15 changed files with 242 additions and 57 deletions.
25 changes: 22 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ git-tree-sha1 = "1d00c35118babf85c4a9a72c3d3550d498b64a03"
repo-rev = "45a4f68c304c5ac39543338f7281812a689adf35"
repo-url = "https://github.com/jrevels/Cassette.jl.git"
uuid = "7057c7e9-c182-5462-911a-8362d720325c"
version = "0.2.2+"
version = "0.2.5+"

[[CodeTracking]]
deps = ["InteractiveUtils", "Test", "UUIDs"]
git-tree-sha1 = "9b21a2dfe51ba71fdc5688039075819196595367"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "0.5.7"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Espresso]]
Expand All @@ -24,9 +30,15 @@ uuid = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
version = "0.6.0"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "ed46097f465a091f6b126966015048193791743a"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.6.0"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -51,6 +63,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

Expand All @@ -68,3 +83,7 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ print(tape)
# %3 = broadcast(%2, %1)::Array{Float64,1}
# %4 = sum(%3)::Float64
```
`trace` uses [Cassette.jl](https://github.com/jrevels/Cassette.jl/) to collect function calls during execution. Functions are divided into 2 groups:
`trace` uses [JuliaInterpreter.jl](https://github.com/JuliaDebug/JuliaInterpreter.jl) to collect function calls during execution. Functions are divided into 2 groups:

* primitive, which are recorded to the tape;
* non-primitive, which are traced-through down to primitive ones.
Expand All @@ -112,6 +112,8 @@ compile!(tape)
# 492.063 ns (2 allocations: 144 bytes)
```

Note that `trace()` is an alias to `itrace()` - JuliaInterpreter-based tracer. Older versions of Yota used another implementation with identical interface and capabilities, but based on [Cassette.jl](https://github.com/jrevels/Cassette.jl). This implementation is still available by name `ctrace()`.

## CuArrays support (experimental)

Yota should work with CuArrays out of the box, although integration is not well tested yet.
Expand Down
7 changes: 1 addition & 6 deletions src/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ end

function to_exnode(op::Call)
arg_names = map(make_name, op.args)
# fns = maybe_to_symbol(op.fn)
if isempty(op.kwargs)
ex = Expr(:call, op.fn, arg_names...)
else
ex = Expr(:call, op.fn, Espresso.make_kw_params(op.kwargs), arg_names...)
end
ex = Expr(:call, op.fn, arg_names...)
return ExNode{:call}(make_name(op), ex; val=op.val)
end

Expand Down
2 changes: 1 addition & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include("helpers.jl")
include("devices.jl")
include("tape.jl")
include("tapeutils.jl")
include("trace.jl")
include("trace/trace.jl")
include("diffrules.jl")
include("grad.jl")
include("compile.jl")
Expand Down
12 changes: 5 additions & 7 deletions src/diffrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ end
@diffrule tuple(x,y,z,t) z ds[3]
@diffrule tuple(x,y,z,t) t ds[4]

# __tuple__ (current tracer implemetation uses it instead of normal tuple)
# __tuple__ (cassette tracer implemetation uses it instead of normal tuple)
@diffrule __tuple__(x) x ds[1]
@diffrule __tuple__(x,y) x ds[1]
@diffrule __tuple__(x,y) y ds[2]
Expand Down Expand Up @@ -313,28 +313,26 @@ end

# binary substraction
@diffrule -(x::Real, y::Real) x ds
# @diffrule -(x::Real, y::AbstractArray) x sum(ds)
# @diffrule -(x::AbstractArray, y::Real) x ones(size(x)) .* ds
@diffrule -(x::AbstractArray, y::AbstractArray) x ds
@diffrule -(x::Real , y::Real) y -ds
# @diffrule -(x::Real, y::AbstractArray) y -ones(size(y)) .* ds
# @diffrule -(x::AbstractArray, y::Real) y -sum(ds)
@diffrule -(x::AbstractArray, y::AbstractArray) y -ds


# sum() and mean()
# @diffrule sum(x::Real) x ds
@diffrule sum(x::AbstractArray) x sum_grad(x, ds)
@diffrule Base._sum(x::AbstractArray, y::Int) x sum_grad(x, ds)
@diffrule Base._sum(x::AbstractArray, y::Int) y zero(eltype(x))
@diffrule Core.kwfunc(sum)(_dims, _, x::AbstractArray) x sum_grad(x, ds)

# special sums
@diffrule sum(_fn::typeof(log), x::AbstractArray) x sum_grad(x, ds) ./ x

# @diffrule Statistics.mean(x::Real) x ds
@diffrule Statistics.mean(x::AbstractArray) x mean_grad(x, ds)
@diffrule Statistics._mean(x::AbstractArray, y::Int) x mean_grad(x, ds)
@diffrule Statistics._mean(x::AbstractArray, y::Int) y zero(eltype(x))
@diffrule Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) x mean_grad(x, ds)
@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _dims
@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _

# diag
@diffrule diag(x::AbstractMatrix) x Diagonal(ds)
Expand Down
3 changes: 1 addition & 2 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ setderiv!(tape::Tape, op::AbstractOp, grad_op::AbstractOp) = (tape.derivs[op.id]


Espresso.to_expr(tape::Tape, op::Call) = begin
@assert isempty(op.kwargs) "Oops, functions with kwargs aren't supported just yet"
Expr(:call, op.fn, [Symbol("%$i") for i in op.args]...)
end

Expand Down Expand Up @@ -181,7 +180,7 @@ function back!(tape::Tape)
end
end
elseif op.fn == __getfield__
# unstructuring if tuples is lowere into pretty weird code sequence
# unstructuring of tuples is lowered into pretty weird code sequence
# ending with __getfield__; similar to getproperty(), we find source var
# for the corresponding tuple argument and backprop to it
if haskey(tape.derivs, op.id)
Expand Down
2 changes: 2 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,5 @@ unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y([x], y,
untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T = transpose(ds)
untranspose_vec(ds::Adjoint{T, <:AbstractVector{T}}) where T = adjoint(ds)
untranspose_vec(ds::AbstractMatrix) = dropdims(transpose(ds); dims=2)

namedtuple(names, values) = NamedTuple{names}(values)
14 changes: 5 additions & 9 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,16 @@ mutable struct Call <: AbstractOp
val::Any
fn::Union{Function, Type}
args::Vector{Int}
kwargs::Dict # currently not used
end

Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}; kwargs=Dict()) =
Call(id, val, fn, args, kwargs)
# Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}) =
# Call(id, val, fn, args)
Base.getproperty(op::Input, f::Call) = f == :typ ? typeof(op.val) : getfield(op, f)


function Base.show(io::IO, op::Call)
arg_str = join(["%$(id)" for id in op.args], ", ")
kwarg_str = (isempty(op.kwargs) ? "" : "; " *
join(["$k=$v" for (k, v) in op.kwargs], ", "))
print(io, "%$(op.id) = $(op.fn)($arg_str$kwarg_str)::$(op.typ)")
print(io, "%$(op.id) = $(op.fn)($arg_str)::$(op.typ)")
end


Expand Down Expand Up @@ -159,7 +156,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false)
arg_id = record!(tape, Constant, x)
new_op_args[i] = arg_id
end
end
end
fn = ex.args[1]
fn = device_function(tape.device, fn)
if bcast
Expand Down Expand Up @@ -203,8 +200,7 @@ function replace_in_args!(tape::Tape, st::Dict)
for (i, op) in enumerate(tape)
if op isa Call
new_args = [get(st, x, x) for x in op.args]
new_kwargs = Dict(k => get(st, x, x) for (k, x) in op.kwargs)
tape[i] = copy_with(op, args=new_args, kwargs=new_kwargs)
tape[i] = copy_with(op, args=new_args)
end
end
end
Expand Down
27 changes: 1 addition & 26 deletions src/trace.jl → src/trace/cassette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,6 @@ Cassette.hastagging(::Type{<:TraceCtx}) = true
# CUSTOM PASS #
########################################################################

function __new__(T, args...)
# @show T
# @show args
# note: we also add __new__() to the list of primitives so it's not overdubbed recursively
if T <: NamedTuple
return T(args)
else
return T(args...)
end
end


__tuple__(args...) = tuple(args...)
__getfield__(args...) = getfield(args...)




is_gref_call(a, fn_name) = a isa GlobalRef && a.name == fn_name

Expand All @@ -55,14 +38,6 @@ end
# TRACE #
########################################################################

const PRIMITIVES = Set([
*, /, +, -, sin, cos, sum, Base._sum,
println,
Base.getproperty, Base.getfield, Base.indexed_iterate, # Core.kwfunc,
broadcast, Broadcast.materialize, Broadcast.broadcasted,
__new__, __tuple__, __getfield__])


struct TapeBox
tape::Tape
primitives::Set{Any}
Expand All @@ -78,7 +53,7 @@ foo(x) = 2.0x + 1.0
val, tape = trace(foo, 4.0)
```
"""
function trace(f, args...; primitives=PRIMITIVES, optimize=true)
function ctrace(f, args...; primitives=PRIMITIVES, optimize=true)
# create tape
tape = Tape(guess_device(args))
box = TapeBox(tape, primitives)
Expand Down
132 changes: 132 additions & 0 deletions src/trace/interp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import JuliaInterpreter
import JuliaInterpreter: enter_call, step_expr!, @lookup, Frame, SSAValue, SlotNumber


getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc]
current_expr(fr::Frame) = getexpr(fr, fr.pc)


"""
Split JuliaInterpreter call expression into a tuple of 3 elements:
* function to be called
* args to this function
* vars on the tape corresponding to these args
If arguments include free parameters (not SlotNumber or SSAValue), these are recorded
to the tape as constants
"""
function split_int_call!(tape::Tape, fr::Frame, frame_vars::Dict, ex)
arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args
# for whatever reason JuliaInterpreter wraps some nodes in the original code into QuoteNode
arr = [isa(x, QuoteNode) ? x.value : x for x in arr]
cf = @lookup(fr, arr[1])
cargs = [x isa Symbol ? x : @lookup(fr, x) for x in arr[2:end]]
cvars = Vector{Int}(undef, length(cargs))
for (i, x) in enumerate(arr[2:end])
# if isa(x, JuliaInterpreter.SlotNumber) || isa(x, JuliaInterpreter.SSAValue)
if haskey(frame_vars, x)
cvars[i] = frame_vars[x]
else
val = x isa Symbol ? x : @lookup(fr, x)
id = record!(tape, Constant, val)
cvars[i] = id
if val != x
# if constant appeared to be a SlotNumber or SSAValue
# store its mapping into frame_vars
frame_vars[x] = id
end
end
end
return cf, cargs, cvars
end


"""
Given a Frame and current expression, extract LHS location (SlotNumber or SSAValue)
"""
get_location(fr::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc)

is_int_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call))
is_int_assign_expr(ex) = Meta.isexpr(ex, :(=)) && (isa(ex.args[2], SlotNumber) || isa(ex.args[2], SSAValue))

is_interesting_expr(ex) = is_int_call_expr(ex) || is_int_assign_expr(ex) || Meta.isexpr(ex, :return)


function itrace!(f, tape::Tape, argvars...; primitives)
args, vars = zip(argvars...)
fr = enter_call(f, args...)
frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars))
is_interesting_expr(current_expr(fr)) || step_expr!(fr) # skip non-call expressions
ex = current_expr(fr)
while !Meta.isexpr(ex, :return)
if is_int_assign_expr(ex)
lhs, rhs = ex.args
frame_vars[lhs] = frame_vars[rhs]
step_expr!(fr)
elseif is_int_call_expr(ex)
# read as "current function", "current arguments", "current variables"
cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex)
loc = get_location(fr, ex)
# there are several special cases such as NamedTuples and constructors
# we replace these with calls to special helper functions
if cf isa UnionAll && cf <: NamedTuple
# replace cf with namedtuple function, adjust arguments
names = collect(cf.body.parameters)[1]
cf = namedtuple
cargs = [names; cargs]
names_var_id = record!(tape, Constant, names)
cvars = [names_var_id; cvars]
elseif cf isa DataType
# constructor, replace with a call to __new__ which we know how to differentiate
T = cf
cf = __new__
cargs = [T; cargs]
T_var_id = record!(tape, Constant, T)
cvars = [T_var_id; cvars]
elseif cf == Base.tuple
cf = __tuple__
elseif cf == Base.getfield
# similar to constuctors, there's a special case for __getfield__ in backprop
cf = __getfield__
end
# if current function is a primitive of a built-in, write it to the tape
# otherwise recurse into the current function
if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction)
step_expr!(fr)
retval = @lookup(fr, loc)
ret_id = record!(tape, Call, retval, cf, cvars)
frame_vars[loc] = ret_id # for slots it may overwrite old mapping
else
retval, ret_id = itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives)
frame_vars[loc] = ret_id # for slots it may overwrite old mapping
step_expr!(fr) # can we avoid this double execution?
end
else
step_expr!(fr)
end
ex = current_expr(fr)
end
retval = @lookup(fr, ex.args[1])
ret_id = frame_vars[ex.args[1]]
return retval, ret_id # return var ID of a result variable
end


"""
Trace function f with arguments args using JuliaInterpreter
"""
function itrace(f, args...; primitives=PRIMITIVES, optimize=true)
tape = Tape(guess_device(args))
argvars = Vector(undef, length(args))
for (i, arg) in enumerate(args)
id = record!(tape, Input, arg)
argvars[i] = (arg, id)
end
val, resultid = itrace!(f, tape, argvars...; primitives=primitives)
tape.resultid = resultid
if optimize
tape = simplify(tape)
end
return val, tape
end
Loading

0 comments on commit fcfae2b

Please sign in to comment.