diff --git a/Manifest.toml b/Manifest.toml index c5090ae..b87bb4d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -9,10 +9,16 @@ git-tree-sha1 = "1d00c35118babf85c4a9a72c3d3550d498b64a03" repo-rev = "45a4f68c304c5ac39543338f7281812a689adf35" repo-url = "https://github.com/jrevels/Cassette.jl.git" uuid = "7057c7e9-c182-5462-911a-8362d720325c" -version = "0.2.2+" +version = "0.2.5+" + +[[CodeTracking]] +deps = ["InteractiveUtils", "Test", "UUIDs"] +git-tree-sha1 = "9b21a2dfe51ba71fdc5688039075819196595367" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "0.5.7" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[Espresso]] @@ -24,9 +30,15 @@ uuid = "6912e4f1-e036-58b0-9138-08d1e6358ea9" version = "0.6.0" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JuliaInterpreter]] +deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] +git-tree-sha1 = "ed46097f465a091f6b126966015048193791743a" +uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +version = "0.6.0" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -51,6 +63,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "0.5.2" +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -68,3 +83,7 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/Project.toml b/Project.toml index 9fc015b..59fb1bb 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ authors = ["Andrei Zhabinski "] [deps] Cassette = "7057c7e9-c182-5462-911a-8362d720325c" Espresso = 
"6912e4f1-e036-58b0-9138-08d1e6358ea9" +JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/README.md b/README.md index fb0e41c..00e3b05 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ print(tape) # %3 = broadcast(%2, %1)::Array{Float64,1} # %4 = sum(%3)::Float64 ``` -`trace` uses [Cassette.jl](https://github.com/jrevels/Cassette.jl/) to collect function calls during execution. Functions are divided into 2 groups: +`trace` uses [JuliaInterpreter.jl](https://github.com/JuliaDebug/JuliaInterpreter.jl) to collect function calls during execution. Functions are divided into 2 groups: * primitive, which are recorded to the tape; * non-primitive, which are traced-through down to primitive ones. @@ -112,6 +112,8 @@ compile!(tape) # 492.063 ns (2 allocations: 144 bytes) ``` +Note that `trace()` is an alias for `itrace()`, the JuliaInterpreter-based tracer. Older versions of Yota used another implementation with an identical interface and capabilities, but based on [Cassette.jl](https://github.com/jrevels/Cassette.jl). That older implementation is still available under the name `ctrace()`. + ## CuArrays support (experimental) Yota should work with CuArrays out of the box, although integration is not well tested yet. diff --git a/src/compile.jl b/src/compile.jl index cc5128c..6b709ae 100644 --- a/src/compile.jl +++ b/src/compile.jl @@ -18,12 +18,7 @@ end function to_exnode(op::Call) arg_names = map(make_name, op.args) - # fns = maybe_to_symbol(op.fn) - if isempty(op.kwargs) - ex = Expr(:call, op.fn, arg_names...) - else - ex = Expr(:call, op.fn, Espresso.make_kw_params(op.kwargs), arg_names...) - end + ex = Expr(:call, op.fn, arg_names...)
return ExNode{:call}(make_name(op), ex; val=op.val)
 end
diff --git a/src/core.jl b/src/core.jl
index bb90bcc..1247e34 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -12,7 +12,7 @@ include("helpers.jl")
 include("devices.jl")
 include("tape.jl")
 include("tapeutils.jl")
-include("trace.jl")
+include("trace/trace.jl")
 include("diffrules.jl")
 include("grad.jl")
 include("compile.jl")
diff --git a/src/diffrules.jl b/src/diffrules.jl
index ff1fd7c..e7443a5 100644
--- a/src/diffrules.jl
+++ b/src/diffrules.jl
@@ -244,7 +244,7 @@ end
 @diffrule tuple(x,y,z,t) z ds[3]
 @diffrule tuple(x,y,z,t) t ds[4]
-# __tuple__ (current tracer implemetation uses it instead of normal tuple)
+# __tuple__ (cassette tracer implementation uses it instead of normal tuple)
 @diffrule __tuple__(x) x ds[1]
 @diffrule __tuple__(x,y) x ds[1]
 @diffrule __tuple__(x,y) y ds[2]
@@ -313,28 +313,29 @@ end
 # binary substraction
 @diffrule -(x::Real, y::Real) x ds
-# @diffrule -(x::Real, y::AbstractArray) x sum(ds)
-# @diffrule -(x::AbstractArray, y::Real) x ones(size(x)) .* ds
 @diffrule -(x::AbstractArray, y::AbstractArray) x ds
 @diffrule -(x::Real , y::Real) y -ds
-# @diffrule -(x::Real, y::AbstractArray) y -ones(size(y)) .* ds
-# @diffrule -(x::AbstractArray, y::Real) y -sum(ds)
 @diffrule -(x::AbstractArray, y::AbstractArray) y -ds
 # sum() and mean()
-# @diffrule sum(x::Real) x ds
 @diffrule sum(x::AbstractArray) x sum_grad(x, ds)
 @diffrule Base._sum(x::AbstractArray, y::Int) x sum_grad(x, ds)
 @diffrule Base._sum(x::AbstractArray, y::Int) y zero(eltype(x))
+@diffrule Core.kwfunc(sum)(_dims, _, x::AbstractArray) x sum_grad(x, ds)
+# the keyword container (_dims) and the function argument (_) carry no gradient
+@nodiff Core.kwfunc(sum)(_dims, _, x::AbstractArray) _dims
+@nodiff Core.kwfunc(sum)(_dims, _, x::AbstractArray) _
 # special sums
 @diffrule sum(_fn::typeof(log), x::AbstractArray) x sum_grad(x, ds) ./ x
-# @diffrule Statistics.mean(x::Real) x ds
 @diffrule Statistics.mean(x::AbstractArray) x mean_grad(x, ds)
 @diffrule Statistics._mean(x::AbstractArray, y::Int) x mean_grad(x, ds)
 @diffrule Statistics._mean(x::AbstractArray, y::Int) y zero(eltype(x))
+@diffrule Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) x mean_grad(x, ds)
+@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _dims
+@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _
 # diag
 @diffrule diag(x::AbstractMatrix) x Diagonal(ds)
diff --git a/src/grad.jl b/src/grad.jl
index 3f16715..0ba1537 100644
--- a/src/grad.jl
+++ b/src/grad.jl
@@ -85,7 +85,6 @@ setderiv!(tape::Tape, op::AbstractOp, grad_op::AbstractOp) = (tape.derivs[op.id]
 Espresso.to_expr(tape::Tape, op::Call) = begin
-    @assert isempty(op.kwargs) "Oops, functions with kwargs aren't supported just yet"
     Expr(:call, op.fn, [Symbol("%$i") for i in op.args]...)
 end
@@ -181,7 +180,7 @@ function back!(tape::Tape)
                 end
             end
         elseif op.fn == __getfield__
-            # unstructuring if tuples is lowere into pretty weird code sequence
+            # unstructuring of tuples is lowered into pretty weird code sequence
             # ending with __getfield__; similar to getproperty(), we find source var
             # for the corresponding tuple argument and backprop to it
             if haskey(tape.derivs, op.id)
diff --git a/src/helpers.jl b/src/helpers.jl
index 7ab5530..98068a3 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -76,3 +76,5 @@ unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y([x], y,
 untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T = transpose(ds)
 untranspose_vec(ds::Adjoint{T, <:AbstractVector{T}}) where T = adjoint(ds)
 untranspose_vec(ds::AbstractMatrix) = dropdims(transpose(ds); dims=2)
+
+namedtuple(names, values) = NamedTuple{names}(values)
diff --git a/src/tape.jl b/src/tape.jl
index 78c4498..3b067c4 100644
--- a/src/tape.jl
+++ b/src/tape.jl
@@ -47,19 +47,16 @@ mutable struct Call <: AbstractOp
     val::Any
     fn::Union{Function, Type}
     args::Vector{Int}
-    kwargs::Dict # currently not used
 end
-Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}; kwargs=Dict()) =
-    Call(id, val, fn, args, kwargs)
+# Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}) =
+#     Call(id, val, fn, args)
Base.getproperty(op::Call, f::Symbol) = f == :typ ? typeof(op.val) : getfield(op, f)  # `typ` is a virtual property computed from `val`
 function Base.show(io::IO, op::Call)
     arg_str = join(["%$(id)" for id in op.args], ", ")
-    kwarg_str = (isempty(op.kwargs) ? "" : "; " *
-                 join(["$k=$v" for (k, v) in op.kwargs], ", "))
-    print(io, "%$(op.id) = $(op.fn)($arg_str$kwarg_str)::$(op.typ)")
+    print(io, "%$(op.id) = $(op.fn)($arg_str)::$(op.typ)")
 end
@@ -159,7 +156,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false)
             arg_id = record!(tape, Constant, x)
             new_op_args[i] = arg_id
         end
-    end
+    end
     fn = ex.args[1]
     fn = device_function(tape.device, fn)
     if bcast
@@ -203,8 +200,7 @@ function replace_in_args!(tape::Tape, st::Dict)
     for (i, op) in enumerate(tape)
         if op isa Call
             new_args = [get(st, x, x) for x in op.args]
-            new_kwargs = Dict(k => get(st, x, x) for (k, x) in op.kwargs)
-            tape[i] = copy_with(op, args=new_args, kwargs=new_kwargs)
+            tape[i] = copy_with(op, args=new_args)
         end
     end
 end
diff --git a/src/trace.jl b/src/trace/cassette.jl
similarity index 85%
rename from src/trace.jl
rename to src/trace/cassette.jl
index 4a83e64..70b98c7 100644
--- a/src/trace.jl
+++ b/src/trace/cassette.jl
@@ -13,23 +13,6 @@ Cassette.hastagging(::Type{<:TraceCtx}) = true
 # CUSTOM PASS #
 ########################################################################
-function __new__(T, args...)
-    # @show T
-    # @show args
-    # note: we also add __new__() to the list of primitives so it's not overdubbed recursively
-    if T <: NamedTuple
-        return T(args)
-    else
-        return T(args...)
-    end
-end
-
-
-__tuple__(args...) = tuple(args...)
-__getfield__(args...) = getfield(args...)
-
-
-
 is_gref_call(a, fn_name) = a isa GlobalRef && a.name == fn_name
@@ -55,14 +38,6 @@ end
 # TRACE #
 ########################################################################
-const PRIMITIVES = Set([
-    *, /, +, -, sin, cos, sum, Base._sum,
-    println,
-    Base.getproperty, Base.getfield, Base.indexed_iterate, # Core.kwfunc,
-    broadcast, Broadcast.materialize, Broadcast.broadcasted,
-    __new__, __tuple__, __getfield__])
-
-
 struct TapeBox
     tape::Tape
     primitives::Set{Any}
@@ -78,7 +53,7 @@ foo(x) = 2.0x + 1.0
 val, tape = trace(foo, 4.0)
 ```
 """
-function trace(f, args...; primitives=PRIMITIVES, optimize=true)
+function ctrace(f, args...; primitives=PRIMITIVES, optimize=true)
     # create tape
     tape = Tape(guess_device(args))
     box = TapeBox(tape, primitives)
diff --git a/src/trace/interp.jl b/src/trace/interp.jl
new file mode 100644
index 0000000..1a2cb48
--- /dev/null
+++ b/src/trace/interp.jl
@@ -0,0 +1,153 @@
+import JuliaInterpreter
+import JuliaInterpreter: enter_call, step_expr!, @lookup, Frame, SSAValue, SlotNumber
+
+
+getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc]
+current_expr(fr::Frame) = getexpr(fr, fr.pc)
+
+
+"""
+Split JuliaInterpreter call expression into a tuple of 3 elements:
+
+  * function to be called
+  * args to this function
+  * vars on the tape corresponding to these args
+
+If arguments include free parameters (not SlotNumber or SSAValue), these are recorded
+to the tape as constants
+"""
+function split_int_call!(tape::Tape, fr::Frame, frame_vars::Dict, ex)
+    arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args
+    # for whatever reason JuliaInterpreter wraps some nodes in the original code into QuoteNode
+    arr = [isa(x, QuoteNode) ? x.value : x for x in arr]
+    cf = @lookup(fr, arr[1])
+    cargs = [x isa Symbol ? x : @lookup(fr, x) for x in arr[2:end]]
+    cvars = Vector{Int}(undef, length(cargs))
+    for (i, x) in enumerate(arr[2:end])
+        # if isa(x, JuliaInterpreter.SlotNumber) || isa(x, JuliaInterpreter.SSAValue)
+        if haskey(frame_vars, x)
+            cvars[i] = frame_vars[x]
+        else
+            val = x isa Symbol ? x : @lookup(fr, x)
+            id = record!(tape, Constant, val)
+            cvars[i] = id
+            if val !== x
+                # the looked-up value is not the node itself, i.e. the constant came from
+                # a SlotNumber, SSAValue or similar reference: store its mapping into
+                # frame_vars; `!==` always yields a Bool, unlike `!=` (e.g. for `missing`)
+                frame_vars[x] = id
+            end
+        end
+    end
+    return cf, cargs, cvars
+end
+
+
+"""
+Given a Frame and current expression, extract LHS location (SlotNumber or SSAValue)
+"""
+get_location(fr::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc)
+
+is_int_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call))
+is_int_assign_expr(ex) = Meta.isexpr(ex, :(=)) && (isa(ex.args[2], SlotNumber) || isa(ex.args[2], SSAValue))
+
+is_interesting_expr(ex) = is_int_call_expr(ex) || is_int_assign_expr(ex) || Meta.isexpr(ex, :return)
+
+
+"""
+    itrace!(f, tape::Tape, argvars...; primitives)
+
+Interpret the call of `f` expression by expression. Calls to functions from `primitives`
+and to built-ins are recorded onto `tape`, all other calls are traced recursively.
+Each element of `argvars` is a pair `(argument value, ID of its variable on the tape)`.
+
+Return a tuple `(retval, ret_id)`: the value returned by `f` and the ID of the tape
+variable that holds it.
+"""
+function itrace!(f, tape::Tape, argvars...; primitives)
+    args, vars = zip(argvars...)
+    fr = enter_call(f, args...)
+    frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars))
+    is_interesting_expr(current_expr(fr)) || step_expr!(fr) # skip non-call expressions
+    ex = current_expr(fr)
+    while !Meta.isexpr(ex, :return)
+        if is_int_assign_expr(ex)
+            lhs, rhs = ex.args
+            frame_vars[lhs] = frame_vars[rhs]
+            step_expr!(fr)
+        elseif is_int_call_expr(ex)
+            # read as "current function", "current arguments", "current variables"
+            cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex)
+            loc = get_location(fr, ex)
+            # there are several special cases such as NamedTuples and constructors
+            # we replace these with calls to special helper functions
+            if cf isa UnionAll && cf <: NamedTuple
+                # replace cf with namedtuple function, adjust arguments
+                names = collect(cf.body.parameters)[1]
+                cf = namedtuple
+                cargs = [names; cargs]
+                names_var_id = record!(tape, Constant, names)
+                cvars = [names_var_id; cvars]
+            elseif cf isa DataType
+                # constructor, replace with a call to __new__ which we know how to differentiate
+                T = cf
+                cf = __new__
+                cargs = [T; cargs]
+                T_var_id = record!(tape, Constant, T)
+                cvars = [T_var_id; cvars]
+            elseif cf == Base.tuple
+                cf = __tuple__
+            elseif cf == Base.getfield
+                # similar to constructors, there's a special case for __getfield__ in backprop
+                cf = __getfield__
+            end
+            # if current function is a primitive or a built-in, write it to the tape
+            # otherwise recurse into the current function
+            if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction)
+                step_expr!(fr)
+                retval = @lookup(fr, loc)
+                ret_id = record!(tape, Call, retval, cf, cvars)
+                frame_vars[loc] = ret_id # for slots it may overwrite old mapping
+            else
+                retval, ret_id = itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives)
+                frame_vars[loc] = ret_id # for slots it may overwrite old mapping
+                step_expr!(fr) # can we avoid this double execution?
+            end
+        else
+            step_expr!(fr)
+        end
+        ex = current_expr(fr)
+    end
+    ret_ex = ex.args[1]
+    retval = @lookup(fr, ret_ex)
+    # the returned node may be a literal or another value that was never recorded;
+    # put it onto the tape as a constant instead of failing with a KeyError
+    ret_id = haskey(frame_vars, ret_ex) ? frame_vars[ret_ex] : record!(tape, Constant, retval)
+    return retval, ret_id # return var ID of a result variable
+end
+
+
+"""
+    itrace(f, args...; primitives=PRIMITIVES, optimize=true)
+
+Trace function `f` with arguments `args` using JuliaInterpreter.
+
+Return a tuple `(val, tape)`: the result of the call and the tape of recorded
+operations. Functions listed in `primitives` (and built-ins) are written to the tape
+as is, everything else is traced through. If `optimize` is true, the tape is passed
+through `simplify` before it is returned.
+"""
+function itrace(f, args...; primitives=PRIMITIVES, optimize=true)
+    tape = Tape(guess_device(args))
+    argvars = Vector(undef, length(args))
+    for (i, arg) in enumerate(args)
+        id = record!(tape, Input, arg)
+        argvars[i] = (arg, id)
+    end
+    val, resultid = itrace!(f, tape, argvars...; primitives=primitives)
+    tape.resultid = resultid
+    if optimize
+        tape = simplify(tape)
+    end
+    return val, tape
+end
diff --git a/src/trace/trace.jl b/src/trace/trace.jl
new file mode 100644
index 0000000..915b829
--- /dev/null
+++ b/src/trace/trace.jl
@@ -0,0 +1,35 @@
+function __new__(T, args...)
+    # @show T
+    # @show args
+    # note: we also add __new__() to the list of primitives so it's not overdubbed recursively
+    if T <: NamedTuple
+        return T(args)
+    else
+        return T(args...)
+    end
+end
+
+
+__tuple__(args...) = tuple(args...)
+__getfield__(args...) = getfield(args...)
+
+const PRIMITIVES = Set([
+    # *, /, +, -, sin, cos, sum, Base._sum,
+    print, println,
+    Base.getproperty, Base.getfield, Base.indexed_iterate,
+    # broadcasting
+    broadcast, Broadcast.materialize, Broadcast.broadcasted,
+    # functions with kw arguments
+    Core.apply_type, Core.kwfunc,
+    tuple,
+    # for loop primitives
+    Colon(), Base.iterate, Base.not_int, ===,
+    # our own special functions
+    __new__, __tuple__, __getfield__, namedtuple])
+
+
+include("cassette.jl")
+include("interp.jl")
+
+
+const trace = itrace  # default tracer; `const` avoids an untyped, dynamically dispatched global
diff --git a/test/runtests.jl b/test/runtests.jl
index c8d18a0..f4f1ccf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,10 +1,11 @@
 using Test
 using Yota
-using Yota: Tape, Input, Call, Constant, trace, play!, transform, binarize_ops
+using Yota: Tape, Input, Call, Constant, trace, itrace, play!, transform, binarize_ops
 using Yota: mean_grad, setfield_nested!, copy_with, simplegrad, remove_unused
 using Yota: find_field_source_var
-include("test_tracer.jl")
+include("test_trace_cassette.jl")
+include("test_trace_interp.jl")
 include("gradcheck.jl")
 include("test_simple.jl")
 include("test_grad.jl")
diff --git a/test/test_tracer.jl b/test/test_trace_cassette.jl
similarity index 100%
rename from test/test_tracer.jl
rename to test/test_trace_cassette.jl
diff --git a/test/test_trace_interp.jl b/test/test_trace_interp.jl
new file mode 100644
index 0000000..7810ff7
--- /dev/null
+++ b/test/test_trace_interp.jl
@@ -0,0 +1,30 @@
+@testset "itracer: calls" begin
+    val, tape = itrace(inc_mul, 2.0, 3.0)
+    @test val == inc_mul(2.0, 3.0)
+    @test length(tape) == 5
+    @test tape[3] isa Constant
+end
+
+@testset "itracer: bcast" begin
+    A = rand(3)
+    B = rand(3)
+    val, tape = itrace(inc_mul, A, B)
+    @test val == inc_mul(A, B)
+    # broadcasting may be lowered to different forms,
+    # so making no assumptions regarding the tape
+
+    val, tape = itrace(inc_mul2, A, B)
+    @test val == inc_mul2(A, B)
+end
+
+@testset "itracer: primitives" begin
+    x = 3.0
+    val1, tape1 = itrace(non_primitive_caller, x)
+    val2, tape2 = itrace(non_primitive_caller, x; primitives=Set([non_primitive, sin]))
+
+    @test val1 == val2
+    @test any(op isa Call && op.fn == (*) for op in tape1)
+    @test tape2[2].fn == non_primitive
+    @test tape2[3].fn == sin
+
+end