From 1549fa82dfbb24e417502e7f21c2f3bf8b2b2613 Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Fri, 26 Jul 2019 02:16:38 +0300 Subject: [PATCH 1/6] Add basic itrace() --- Manifest.toml | 25 +++++++++-- Project.toml | 1 + src/itrace.jl | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 src/itrace.jl diff --git a/Manifest.toml b/Manifest.toml index c5090ae..b87bb4d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -9,10 +9,16 @@ git-tree-sha1 = "1d00c35118babf85c4a9a72c3d3550d498b64a03" repo-rev = "45a4f68c304c5ac39543338f7281812a689adf35" repo-url = "https://github.com/jrevels/Cassette.jl.git" uuid = "7057c7e9-c182-5462-911a-8362d720325c" -version = "0.2.2+" +version = "0.2.5+" + +[[CodeTracking]] +deps = ["InteractiveUtils", "Test", "UUIDs"] +git-tree-sha1 = "9b21a2dfe51ba71fdc5688039075819196595367" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "0.5.7" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[Espresso]] @@ -24,9 +30,15 @@ uuid = "6912e4f1-e036-58b0-9138-08d1e6358ea9" version = "0.6.0" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JuliaInterpreter]] +deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] +git-tree-sha1 = "ed46097f465a091f6b126966015048193791743a" +uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +version = "0.6.0" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -51,6 +63,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "0.5.2" +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -68,3 +83,7 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/Project.toml b/Project.toml index 9fc015b..59fb1bb 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ authors = ["Andrei Zhabinski "] [deps] Cassette = "7057c7e9-c182-5462-911a-8362d720325c" Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9" +JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/itrace.jl b/src/itrace.jl new file mode 100644 index 0000000..a858601 --- /dev/null +++ b/src/itrace.jl @@ -0,0 +1,115 @@ +import JuliaInterpreter +import JuliaInterpreter: enter_call, step_expr!, next_call!, @lookup, Frame + +include("core.jl") + + + +# function itrace(f, args...; primitives=PRIMITIVES, optimize=true) +# end + + +getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc] +current_expr(fr::Frame) = getexpr(fr, fr.pc) + + + +# if f in primitives +# # args = with_tagged_properties(ctx, tape, args) # only if f() is getproperty() +# args = with_free_args_as_constants(ctx, tape, args) +# arg_ids = [metadata(x, ctx) for x in args] +# arg_ids = Int[id isa Cassette.NoMetaData ? -1 : id for id in arg_ids] +# # execute call +# retval = fallback(ctx, f, [untag(x, ctx) for x in args]...) +# # record to the tape and tag with a newly created ID +# ret_id = record!(tape, Call, retval, f, arg_ids) +# retval = tag(retval, ctx, ret_id) +# elseif canrecurse(ctx, f, args...) +# retval = Cassette.recurse(ctx, f, args...) +# else +# retval = fallback(ctx, f, args...) +# end + +# iscall(ex) = || () + + +# """ +# Split JuliaInterpreter call expression into a tuple of 3 elements: + +# * function to be called +# * args to this function +# * vars on the tape corresponding to these args +# """ +# function split_int_call(tape::Tape, fr::Frame, ex) +# arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args +# cf = @lookup(fr, arr[1]) +# cargs = [@lookup(fr, a) for a in arr[2:end]] +# cvars = +# if +# f_args = +# else +# f_args = [@lookup(fr, a) for a in ] +# end +# return f_args[1], f_args[2:end] +# end + + + +function itrace!(f, tape::Tape, argvars...; primitives) + args, vars = zip(argvars...) + fr = enter_call(f, args...) + frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars)) + ex = current_expr(fr) + while !Meta.isexpr(ex, :return) + if Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) + arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args # TODO: move or simplify + # cf, cargs, cvars = function, args and vars of the current expression + cf = @lookup(fr, arr[1]) + cargs = [@lookup(fr, a) for a in arr[2:end]] + cvars = [frame_vars[wa] for wa in arr[2:end]] + if cf in primitives + # we will map result to this location (SlotNumber or SSAValue) + loc = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc) # TODO: check + retval = next_call!(fr) + ret_id = record!(tape, Call, retval, cf, cvars) + frame_vars[loc] = ret_id # for slots, may overwrite old mapping + else + # TODO: handle recursive call + itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives) + end + end + ex = current_expr(fr) + end + # TODO: handle return +end + + + +bar(x) = 2.0x + 1.0 + +function foo(x) + y = bar(x) + z = exp(y) +end + + +function itrace(f, args...; primitives=PRIMITIVES, optimize=true) + tape = Tape(guess_device(args)) + argvars = Vector(undef, length(args)) + for (i, arg) in enumerate(args) + id = record!(tape, Input, arg) + argvars[i] = (arg, id) + end + itrace!(f, tape, argvars...; primitives=primitives) +end + + +# NEXT STEPS: +# add bar() to primitives and finish non-recursive path of itrace! + +function main() + f = foo + args = (4.0,) + primitives = PRIMITIVES + _itrace(f, args, ; primitives=primitives) +end From cee2d7051477b6e5d52b25560634cbd2e7f970dd Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Sat, 27 Jul 2019 15:08:30 +0300 Subject: [PATCH 2/6] Make itrace() work for functions with keyword arguments --- src/itrace.jl | 142 ++++++++++++++++++++++++++------------------------ src/trace.jl | 2 + 2 files changed, 75 insertions(+), 69 deletions(-) diff --git a/src/itrace.jl b/src/itrace.jl index a858601..977fda7 100644 --- a/src/itrace.jl +++ b/src/itrace.jl @@ -4,112 +4,116 @@ import JuliaInterpreter: enter_call, step_expr!, next_call!, @lookup, Frame include("core.jl") - -# function itrace(f, args...; primitives=PRIMITIVES, optimize=true) -# end - - getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc] current_expr(fr::Frame) = getexpr(fr, fr.pc) +""" +Split JuliaInterpreter call expression into a tuple of 3 elements: + + * function to be called + * args to this function + * vars on the tape corresponding to these args + +If arguments include free parameters (not SlotNumber or SSAValue), these are recorded +to the tape as constants +""" +function split_int_call!(tape::Tape, fr::Frame, frame_vars::Dict, ex) + arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args + # for whatever reason JuliaInterpreter wraps some nodes in the original code into QuoteNode + arr = [isa(x, QuoteNode) ? x.value : x for x in arr] + cf = @lookup(fr, arr[1]) + cargs = [@lookup(fr, x) for x in arr[2:end]] + cvars = Vector{Int}(undef, length(cargs)) + for (i, x) in enumerate(arr[2:end]) + # if isa(x, JuliaInterpreter.SlotNumber) || isa(x, JuliaInterpreter.SSAValue) + if haskey(frame_vars, x) + cvars[i] = frame_vars[x] + else + val = @lookup(fr, x) + id = record!(tape, Constant, val) + cvars[i] = id + if val != x + # if constant appeared to be a SlotNumber or SSAValue + # store its mapping into frame_vars + frame_vars[x] = id + end + end + end + return cf, cargs, cvars +end + -# if f in primitives -# # args = with_tagged_properties(ctx, tape, args) # only if f() is getproperty() -# args = with_free_args_as_constants(ctx, tape, args) -# arg_ids = [metadata(x, ctx) for x in args] -# arg_ids = Int[id isa Cassette.NoMetaData ? -1 : id for id in arg_ids] -# # execute call -# retval = fallback(ctx, f, [untag(x, ctx) for x in args]...) -# # record to the tape and tag with a newly created ID -# ret_id = record!(tape, Call, retval, f, arg_ids) -# retval = tag(retval, ctx, ret_id) -# elseif canrecurse(ctx, f, args...) -# retval = Cassette.recurse(ctx, f, args...) -# else -# retval = fallback(ctx, f, args...) -# end - -# iscall(ex) = || () - - -# """ -# Split JuliaInterpreter call expression into a tuple of 3 elements: - -# * function to be called -# * args to this function -# * vars on the tape corresponding to these args -# """ -# function split_int_call(tape::Tape, fr::Frame, ex) -# arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args -# cf = @lookup(fr, arr[1]) -# cargs = [@lookup(fr, a) for a in arr[2:end]] -# cvars = -# if -# f_args = -# else -# f_args = [@lookup(fr, a) for a in ] -# end -# return f_args[1], f_args[2:end] -# end +""" +Given a Frame and current expression, extract LHS location (SlotNumber or SSAValue) +""" +get_location(fr::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc) +is_iint_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) function itrace!(f, tape::Tape, argvars...; primitives) args, vars = zip(argvars...) fr = enter_call(f, args...) - frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars)) + frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars)) + is_int_call_expr(current_expr(fr)) || next_call!(fr) # skip non-call expressions ex = current_expr(fr) while !Meta.isexpr(ex, :return) - if Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) - arr = Meta.isexpr(ex, :(=)) ? ex.args[2].args : ex.args # TODO: move or simplify - # cf, cargs, cvars = function, args and vars of the current expression - cf = @lookup(fr, arr[1]) - cargs = [@lookup(fr, a) for a in arr[2:end]] - cvars = [frame_vars[wa] for wa in arr[2:end]] - if cf in primitives - # we will map result to this location (SlotNumber or SSAValue) - loc = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc) # TODO: check - retval = next_call!(fr) - ret_id = record!(tape, Call, retval, cf, cvars) - frame_vars[loc] = ret_id # for slots, may overwrite old mapping + # println("--------------- $ex -------------") + if is_int_call_expr(ex) + cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex) + loc = get_location(fr, ex) + if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction) || !isa(cf, Function) + next_call!(fr) + retval = @lookup(fr, loc) + ret_id = record!(tape, Call, retval, cf, cvars) + frame_vars[loc] = ret_id # for slots it may overwrite old mapping else - # TODO: handle recursive call - itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives) + retval, ret_id = itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives) + frame_vars[loc] = ret_id # for slots it may overwrite old mapping + next_call!(fr) # can we avoid this double execution? end end ex = current_expr(fr) end - # TODO: handle return + retval = @lookup(fr, ex.args[1]) + ret_id = frame_vars[ex.args[1]] + return retval, ret_id # return var ID of a result variable end bar(x) = 2.0x + 1.0 +baz(x; y=3.5) = x - y -function foo(x) - y = bar(x) - z = exp(y) -end +function foo(a) + b = bar(a) + c = baz(b; y=4.5) + d = exp(c) + return d +end +""" +Trace function f with arguments args using JuliaInterpreter +""" function itrace(f, args...; primitives=PRIMITIVES, optimize=true) tape = Tape(guess_device(args)) argvars = Vector(undef, length(args)) - for (i, arg) in enumerate(args) + for (i, arg) in enumerate(args) id = record!(tape, Input, arg) argvars[i] = (arg, id) end - itrace!(f, tape, argvars...; primitives=primitives) + val, resultid = itrace!(f, tape, argvars...; primitives=primitives) + tape.resultid = resultid + return val, tape end -# NEXT STEPS: -# add bar() to primitives and finish non-recursive path of itrace! - function main() f = foo args = (4.0,) primitives = PRIMITIVES - _itrace(f, args, ; primitives=primitives) + push!(primitives, Core.kwfunc(baz)) + _, tape = itrace(f, args...; primitives=primitives) end diff --git a/src/trace.jl b/src/trace.jl index 4a83e64..ae2194e 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -60,6 +60,8 @@ const PRIMITIVES = Set([ println, Base.getproperty, Base.getfield, Base.indexed_iterate, # Core.kwfunc, broadcast, Broadcast.materialize, Broadcast.broadcasted, + Core.apply_type, Core.kwfunc, + tuple, __new__, __tuple__, __getfield__]) From 704b0d49d8c8b16b6c4f8ed63378a43b6af62bd6 Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Thu, 8 Aug 2019 13:59:40 +0300 Subject: [PATCH 3/6] Change default tracer to itrace() --- src/core.jl | 1 + src/grad.jl | 2 +- src/itrace.jl | 28 +++++----------------------- test/runtests.jl | 3 ++- test/test_itracer.jl | 30 ++++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 25 deletions(-) create mode 100644 test/test_itracer.jl diff --git a/src/core.jl b/src/core.jl index bb90bcc..763b448 100644 --- a/src/core.jl +++ b/src/core.jl @@ -13,6 +13,7 @@ include("devices.jl") include("tape.jl") include("tapeutils.jl") include("trace.jl") +include("itrace.jl") include("diffrules.jl") include("grad.jl") include("compile.jl") diff --git a/src/grad.jl b/src/grad.jl index 3f16715..8620887 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -249,7 +249,7 @@ end function _grad(f::Function, args...) - val, tape = trace(f, args...) + val, tape = itrace(f, args...) # calculate gradients tape = _grad(tape) # construct GradResult object that wraps tape and provides accessors for computed derivatives diff --git a/src/itrace.jl b/src/itrace.jl index 977fda7..1c57ba9 100644 --- a/src/itrace.jl +++ b/src/itrace.jl @@ -1,8 +1,6 @@ import JuliaInterpreter import JuliaInterpreter: enter_call, step_expr!, next_call!, @lookup, Frame -include("core.jl") - getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc] current_expr(fr::Frame) = getexpr(fr, fr.pc) @@ -49,7 +47,7 @@ Given a Frame and current expression, extract LHS location (SlotNumber or SSAVal """ get_location(fr::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc) -is_iint_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) +is_int_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) function itrace!(f, tape::Tape, argvars...; primitives) @@ -82,18 +80,6 @@ function itrace!(f, tape::Tape, argvars...; primitives) end - -bar(x) = 2.0x + 1.0 -baz(x; y=3.5) = x - y - - -function foo(a) - b = bar(a) - c = baz(b; y=4.5) - d = exp(c) - return d -end - """ Trace function f with arguments args using JuliaInterpreter """ @@ -106,14 +92,10 @@ function itrace(f, args...; primitives=PRIMITIVES, optimize=true) end val, resultid = itrace!(f, tape, argvars...; primitives=primitives) tape.resultid = resultid + if optimize + tape = simplify(tape) + end return val, tape end - -function main() - f = foo - args = (4.0,) - primitives = PRIMITIVES - push!(primitives, Core.kwfunc(baz)) - _, tape = itrace(f, args...; primitives=primitives) -end +# TODO: remove Call.kwargs diff --git a/test/runtests.jl b/test/runtests.jl index c8d18a0..09ca3ac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,11 @@ using Test using Yota -using Yota: Tape, Input, Call, Constant, trace, play!, transform, binarize_ops +using Yota: Tape, Input, Call, Constant, trace, itrace, play!, transform, binarize_ops using Yota: mean_grad, setfield_nested!, copy_with, simplegrad, remove_unused using Yota: find_field_source_var include("test_tracer.jl") +include("test_itracer.jl") include("gradcheck.jl") include("test_simple.jl") include("test_grad.jl") diff --git a/test/test_itracer.jl b/test/test_itracer.jl new file mode 100644 index 0000000..7810ff7 --- /dev/null +++ b/test/test_itracer.jl @@ -0,0 +1,30 @@ +@testset "itracer: calls" begin + val, tape = itrace(inc_mul, 2.0, 3.0) + @test val == inc_mul(2.0, 3.0) + @test length(tape) == 5 + @test tape[3] isa Constant +end + +@testset "itracer: bcast" begin + A = rand(3) + B = rand(3) + val, tape = itrace(inc_mul, A, B) + @test val == inc_mul(A, B) + # broadcasting may be lowered to different forms, + # so making no assumptions regarding the tape + + val, tape = itrace(inc_mul2, A, B) + @test val == inc_mul2(A, B) +end + +@testset "itracer: primitives" begin + x = 3.0 + val1, tape1 = itrace(non_primitive_caller, x) + val2, tape2 = itrace(non_primitive_caller, x; primitives=Set([non_primitive, sin])) + + @test val1 == val2 + @test any(op isa Call && op.fn == (*) for op in tape1) + @test tape2[2].fn == non_primitive + @test tape2[3].fn == sin + +end From 984c0cea70d2c10a7598966b0445dc4bdce94210 Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Sat, 10 Aug 2019 23:06:49 +0300 Subject: [PATCH 4/6] Fix some of the errors with keyword parameters --- src/compile.jl | 7 +------ src/helpers.jl | 2 ++ src/itrace.jl | 12 ++++++++++-- src/tape.jl | 1 - src/trace.jl | 3 ++- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/compile.jl b/src/compile.jl index cc5128c..6b709ae 100644 --- a/src/compile.jl +++ b/src/compile.jl @@ -18,12 +18,7 @@ end function to_exnode(op::Call) arg_names = map(make_name, op.args) - # fns = maybe_to_symbol(op.fn) - if isempty(op.kwargs) - ex = Expr(:call, op.fn, arg_names...) - else - ex = Expr(:call, op.fn, Espresso.make_kw_params(op.kwargs), arg_names...) - end + ex = Expr(:call, op.fn, arg_names...) return ExNode{:call}(make_name(op), ex; val=op.val) end diff --git a/src/helpers.jl b/src/helpers.jl index 7ab5530..98068a3 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -76,3 +76,5 @@ unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y([x], y, untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T = transpose(ds) untranspose_vec(ds::Adjoint{T, <:AbstractVector{T}}) where T = adjoint(ds) untranspose_vec(ds::AbstractMatrix) = dropdims(transpose(ds); dims=2) + +namedtuple(names, values) = NamedTuple{names}(values) diff --git a/src/itrace.jl b/src/itrace.jl index 1c57ba9..67502ce 100644 --- a/src/itrace.jl +++ b/src/itrace.jl @@ -1,5 +1,5 @@ import JuliaInterpreter -import JuliaInterpreter: enter_call, step_expr!, next_call!, @lookup, Frame +import JuliaInterpreter: enter_call, step_expr!, next_call!, @lookup, Frame, SSAValue, SlotNumber getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc] @@ -61,7 +61,15 @@ function itrace!(f, tape::Tape, argvars...; primitives) if is_int_call_expr(ex) cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex) loc = get_location(fr, ex) - if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction) || !isa(cf, Function) + if cf isa UnionAll && cf <: NamedTuple + # replace cf with namedtuple function, adjust arguments + names = collect(cf.body.parameters)[1] + cf = namedtuple + cargs = [names; cargs] + names_var_id = record!(tape, Constant, names) + cvars = [names_var_id; cvars] + end + if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction) next_call!(fr) retval = @lookup(fr, loc) ret_id = record!(tape, Call, retval, cf, cvars) diff --git a/src/tape.jl b/src/tape.jl index 78c4498..869def8 100644 --- a/src/tape.jl +++ b/src/tape.jl @@ -47,7 +47,6 @@ mutable struct Call <: AbstractOp val::Any fn::Union{Function, Type} args::Vector{Int} - kwargs::Dict # currently not used end Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}; kwargs=Dict()) = diff --git a/src/trace.jl b/src/trace.jl index ae2194e..7c34c45 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -62,7 +62,8 @@ const PRIMITIVES = Set([ broadcast, Broadcast.materialize, Broadcast.broadcasted, Core.apply_type, Core.kwfunc, tuple, - __new__, __tuple__, __getfield__]) + __new__, __tuple__, __getfield__, + namedtuple]) struct TapeBox From 6015e04bec0eb7aaa0a5892c27876dec2ee38bef Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Mon, 12 Aug 2019 01:56:56 +0300 Subject: [PATCH 5/6] Make all tests pass with itrace() --- src/diffrules.jl | 12 +++++------- src/grad.jl | 3 +-- src/itrace.jl | 23 +++++++++++++++++++---- src/tape.jl | 13 +++++-------- 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/diffrules.jl b/src/diffrules.jl index ff1fd7c..e7443a5 100644 --- a/src/diffrules.jl +++ b/src/diffrules.jl @@ -244,7 +244,7 @@ end @diffrule tuple(x,y,z,t) z ds[3] @diffrule tuple(x,y,z,t) t ds[4] -# __tuple__ (current tracer implemetation uses it instead of normal tuple) +# __tuple__ (cassette tracer implemetation uses it instead of normal tuple) @diffrule __tuple__(x) x ds[1] @diffrule __tuple__(x,y) x ds[1] @diffrule __tuple__(x,y) y ds[2] @@ -313,28 +313,26 @@ end # binary substraction @diffrule -(x::Real, y::Real) x ds -# @diffrule -(x::Real, y::AbstractArray) x sum(ds) -# @diffrule -(x::AbstractArray, y::Real) x ones(size(x)) .* ds @diffrule -(x::AbstractArray, y::AbstractArray) x ds @diffrule -(x::Real , y::Real) y -ds -# @diffrule -(x::Real, y::AbstractArray) y -ones(size(y)) .* ds -# @diffrule -(x::AbstractArray, y::Real) y -sum(ds) @diffrule -(x::AbstractArray, y::AbstractArray) y -ds # sum() and mean() -# @diffrule sum(x::Real) x ds @diffrule sum(x::AbstractArray) x sum_grad(x, ds) @diffrule Base._sum(x::AbstractArray, y::Int) x sum_grad(x, ds) @diffrule Base._sum(x::AbstractArray, y::Int) y zero(eltype(x)) +@diffrule Core.kwfunc(sum)(_dims, _, x::AbstractArray) x sum_grad(x, ds) # special sums @diffrule sum(_fn::typeof(log), x::AbstractArray) x sum_grad(x, ds) ./ x -# @diffrule Statistics.mean(x::Real) x ds @diffrule Statistics.mean(x::AbstractArray) x mean_grad(x, ds) @diffrule Statistics._mean(x::AbstractArray, y::Int) x mean_grad(x, ds) @diffrule Statistics._mean(x::AbstractArray, y::Int) y zero(eltype(x)) +@diffrule Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) x mean_grad(x, ds) +@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _dims +@nodiff Core.kwfunc(Statistics.mean)(_dims, _, x::AbstractArray) _ # diag @diffrule diag(x::AbstractMatrix) x Diagonal(ds) diff --git a/src/grad.jl b/src/grad.jl index 8620887..d654b33 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -85,7 +85,6 @@ setderiv!(tape::Tape, op::AbstractOp, grad_op::AbstractOp) = (tape.derivs[op.id] Espresso.to_expr(tape::Tape, op::Call) = begin - @assert isempty(op.kwargs) "Oops, functions with kwargs aren't supported just yet" Expr(:call, op.fn, [Symbol("%$i") for i in op.args]...) end @@ -181,7 +180,7 @@ function back!(tape::Tape) end end elseif op.fn == __getfield__ - # unstructuring if tuples is lowere into pretty weird code sequence + # unstructuring of tuples is lowered into pretty weird code sequence # ending with __getfield__; similar to getproperty(), we find source var # for the corresponding tuple argument and backprop to it if haskey(tape.derivs, op.id) diff --git a/src/itrace.jl b/src/itrace.jl index 67502ce..9994cc7 100644 --- a/src/itrace.jl +++ b/src/itrace.jl @@ -21,14 +21,14 @@ function split_int_call!(tape::Tape, fr::Frame, frame_vars::Dict, ex) # for whatever reason JuliaInterpreter wraps some nodes in the original code into QuoteNode arr = [isa(x, QuoteNode) ? x.value : x for x in arr] cf = @lookup(fr, arr[1]) - cargs = [@lookup(fr, x) for x in arr[2:end]] + cargs = [x isa Symbol ? x : @lookup(fr, x) for x in arr[2:end]] cvars = Vector{Int}(undef, length(cargs)) for (i, x) in enumerate(arr[2:end]) # if isa(x, JuliaInterpreter.SlotNumber) || isa(x, JuliaInterpreter.SSAValue) if haskey(frame_vars, x) cvars[i] = frame_vars[x] else - val = @lookup(fr, x) + val = x isa Symbol ? x : @lookup(fr, x) id = record!(tape, Constant, val) cvars[i] = id if val != x @@ -59,8 +59,11 @@ function itrace!(f, tape::Tape, argvars...; primitives) while !Meta.isexpr(ex, :return) # println("--------------- $ex -------------") if is_int_call_expr(ex) + # read as "current function", "current arguments", "current variables" cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex) loc = get_location(fr, ex) + # there are several special cases such as NamedTuples and constructors + # we replace these with calls to special helper functions if cf isa UnionAll && cf <: NamedTuple # replace cf with namedtuple function, adjust arguments names = collect(cf.body.parameters)[1] @@ -68,7 +71,21 @@ function itrace!(f, tape::Tape, argvars...; primitives) cargs = [names; cargs] names_var_id = record!(tape, Constant, names) cvars = [names_var_id; cvars] + elseif cf isa DataType + # constructor, replace with a call to __new__ which we know how to differentiate + T = cf + cf = __new__ + cargs = [T; cargs] + T_var_id = record!(tape, Constant, T) + cvars = [T_var_id; cvars] + elseif cf == Base.tuple + cf = __tuple__ + elseif cf == Base.getfield + # similar to constuctors, there's a special case for __getfield__ in backprop + cf = __getfield__ end + # if current function is a primitive of a built-in, write it to the tape + # otherwise recurse into the current function if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction) next_call!(fr) retval = @lookup(fr, loc) @@ -105,5 +122,3 @@ function itrace(f, args...; primitives=PRIMITIVES, optimize=true) end return val, tape end - -# TODO: remove Call.kwargs diff --git a/src/tape.jl b/src/tape.jl index 869def8..3b067c4 100644 --- a/src/tape.jl +++ b/src/tape.jl @@ -49,16 +49,14 @@ mutable struct Call <: AbstractOp args::Vector{Int} end -Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}; kwargs=Dict()) = - Call(id, val, fn, args, kwargs) +# Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}) = +# Call(id, val, fn, args) Base.getproperty(op::Input, f::Call) = f == :typ ? typeof(op.val) : getfield(op, f) function Base.show(io::IO, op::Call) arg_str = join(["%$(id)" for id in op.args], ", ") - kwarg_str = (isempty(op.kwargs) ? "" : "; " * - join(["$k=$v" for (k, v) in op.kwargs], ", ")) - print(io, "%$(op.id) = $(op.fn)($arg_str$kwarg_str)::$(op.typ)") + print(io, "%$(op.id) = $(op.fn)($arg_str)::$(op.typ)") end @@ -158,7 +156,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false) arg_id = record!(tape, Constant, x) new_op_args[i] = arg_id end - end + end fn = ex.args[1] fn = device_function(tape.device, fn) if bcast @@ -202,8 +200,7 @@ function replace_in_args!(tape::Tape, st::Dict) for (i, op) in enumerate(tape) if op isa Call new_args = [get(st, x, x) for x in op.args] - new_kwargs = Dict(k => get(st, x, x) for (k, x) in op.kwargs) - tape[i] = copy_with(op, args=new_args, kwargs=new_kwargs) + tape[i] = copy_with(op, args=new_args) end end end From 76a87ceb0f3b723b84c07183113821ecba214c45 Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Wed, 14 Aug 2019 00:19:16 +0300 Subject: [PATCH 6/6] Make itrace() the default tracer --- README.md | 4 ++- src/core.jl | 3 +- src/grad.jl | 2 +- src/{trace.jl => trace/cassette.jl} | 30 +--------------- src/{itrace.jl => trace/interp.jl} | 20 +++++++---- src/trace/trace.jl | 35 +++++++++++++++++++ test/runtests.jl | 4 +-- ...{test_tracer.jl => test_trace_cassette.jl} | 0 .../{test_itracer.jl => test_trace_interp.jl} | 0 9 files changed, 57 insertions(+), 41 deletions(-) rename src/{trace.jl => trace/cassette.jl} (84%) rename src/{itrace.jl => trace/interp.jl} (86%) create mode 100644 src/trace/trace.jl rename test/{test_tracer.jl => test_trace_cassette.jl} (100%) rename test/{test_itracer.jl => test_trace_interp.jl} (100%) diff --git a/README.md b/README.md index fb0e41c..00e3b05 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ print(tape) # %3 = broadcast(%2, %1)::Array{Float64,1} # %4 = sum(%3)::Float64 ``` -`trace` uses [Cassette.jl](https://github.com/jrevels/Cassette.jl/) to collect function calls during execution. Functions are divided into 2 groups: +`trace` uses [JuliaInterpreter.jl](https://github.com/JuliaDebug/JuliaInterpreter.jl) to collect function calls during execution. Functions are divided into 2 groups: * primitive, which are recorded to the tape; * non-primitive, which are traced-through down to primitive ones. @@ -112,6 +112,8 @@ compile!(tape) # 492.063 ns (2 allocations: 144 bytes) ``` +Note that `trace()` is an alias to `itrace()` - JuliaInterpreter-based tracer. Older versions of Yota used another implementation with identical interface and capabilities, but based on [Cassette.jl](https://github.com/jrevels/Cassette.jl). This implementation is still available by name `ctrace()`. + ## CuArrays support (experimental) Yota should work with CuArrays out of the box, although integration is not well tested yet. diff --git a/src/core.jl b/src/core.jl index 763b448..1247e34 100644 --- a/src/core.jl +++ b/src/core.jl @@ -12,8 +12,7 @@ include("helpers.jl") include("devices.jl") include("tape.jl") include("tapeutils.jl") -include("trace.jl") -include("itrace.jl") +include("trace/trace.jl") include("diffrules.jl") include("grad.jl") include("compile.jl") diff --git a/src/grad.jl b/src/grad.jl index d654b33..0ba1537 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -248,7 +248,7 @@ end function _grad(f::Function, args...) - val, tape = itrace(f, args...) + val, tape = trace(f, args...) # calculate gradients tape = _grad(tape) # construct GradResult object that wraps tape and provides accessors for computed derivatives diff --git a/src/trace.jl b/src/trace/cassette.jl similarity index 84% rename from src/trace.jl rename to src/trace/cassette.jl index 7c34c45..70b98c7 100644 --- a/src/trace.jl +++ b/src/trace/cassette.jl @@ -13,23 +13,6 @@ Cassette.hastagging(::Type{<:TraceCtx}) = true # CUSTOM PASS # ######################################################################## -function __new__(T, args...) - # @show T - # @show args - # note: we also add __new__() to the list of primitives so it's not overdubbed recursively - if T <: NamedTuple - return T(args) - else - return T(args...) - end -end - - -__tuple__(args...) = tuple(args...) -__getfield__(args...) = getfield(args...) - - - is_gref_call(a, fn_name) = a isa GlobalRef && a.name == fn_name @@ -55,17 +38,6 @@ end # TRACE # ######################################################################## -const PRIMITIVES = Set([ - *, /, +, -, sin, cos, sum, Base._sum, - println, - Base.getproperty, Base.getfield, Base.indexed_iterate, # Core.kwfunc, - broadcast, Broadcast.materialize, Broadcast.broadcasted, - Core.apply_type, Core.kwfunc, - tuple, - __new__, __tuple__, __getfield__, - namedtuple]) - - struct TapeBox tape::Tape primitives::Set{Any} @@ -81,7 +53,7 @@ foo(x) = 2.0x + 1.0 val, tape = trace(foo, 4.0) ``` """ -function trace(f, args...; primitives=PRIMITIVES, optimize=true) +function ctrace(f, args...; primitives=PRIMITIVES, optimize=true) # create tape tape = Tape(guess_device(args)) box = TapeBox(tape, primitives) diff --git a/src/itrace.jl b/src/trace/interp.jl similarity index 86% rename from src/itrace.jl rename to src/trace/interp.jl index 9994cc7..1a2cb48 100644 --- a/src/itrace.jl +++ b/src/trace/interp.jl @@ -1,5 +1,5 @@ import JuliaInterpreter -import JuliaInterpreter: enter_call, step_expr!, next_call!, @lookup, Frame, SSAValue, SlotNumber +import JuliaInterpreter: enter_call, step_expr!, @lookup, Frame, SSAValue, SlotNumber getexpr(fr::Frame, pc::Int) = fr.framecode.src.code[pc] @@ -48,17 +48,23 @@ Given a Frame and current expression, extract LHS location (SlotNumber or SSAVal get_location(fr::Frame, ex) = Meta.isexpr(ex, :(=)) ? ex.args[1] : JuliaInterpreter.SSAValue(fr.pc) is_int_call_expr(ex) = Meta.isexpr(ex, :call) || (Meta.isexpr(ex, :(=)) && Meta.isexpr(ex.args[2], :call)) +is_int_assign_expr(ex) = Meta.isexpr(ex, :(=)) && (isa(ex.args[2], SlotNumber) || isa(ex.args[2], SSAValue)) + +is_interesting_expr(ex) = is_int_call_expr(ex) || is_int_assign_expr(ex) || Meta.isexpr(ex, :return) function itrace!(f, tape::Tape, argvars...; primitives) args, vars = zip(argvars...) fr = enter_call(f, args...) frame_vars = Dict{Any, Int}(JuliaInterpreter.SlotNumber(i + 1) => v for (i, v) in enumerate(vars)) - is_int_call_expr(current_expr(fr)) || next_call!(fr) # skip non-call expressions + is_interesting_expr(current_expr(fr)) || step_expr!(fr) # skip non-call expressions ex = current_expr(fr) while !Meta.isexpr(ex, :return) - # println("--------------- $ex -------------") - if is_int_call_expr(ex) + if is_int_assign_expr(ex) + lhs, rhs = ex.args + frame_vars[lhs] = frame_vars[rhs] + step_expr!(fr) + elseif is_int_call_expr(ex) # read as "current function", "current arguments", "current variables" cf, cargs, cvars = split_int_call!(tape, fr, frame_vars, ex) loc = get_location(fr, ex) @@ -87,15 +93,17 @@ function itrace!(f, tape::Tape, argvars...; primitives) # if current function is a primitive of a built-in, write it to the tape # otherwise recurse into the current function if cf in primitives || isa(cf, Core.Builtin) || isa(cf, Core.IntrinsicFunction) - next_call!(fr) + step_expr!(fr) retval = @lookup(fr, loc) ret_id = record!(tape, Call, retval, cf, cvars) frame_vars[loc] = ret_id # for slots it may overwrite old mapping else retval, ret_id = itrace!(cf, tape, zip(cargs, cvars)...; primitives=primitives) frame_vars[loc] = ret_id # for slots it may overwrite old mapping - next_call!(fr) # can we avoid this double execution? + step_expr!(fr) # can we avoid this double execution? end + else + step_expr!(fr) end ex = current_expr(fr) end diff --git a/src/trace/trace.jl b/src/trace/trace.jl new file mode 100644 index 0000000..915b829 --- /dev/null +++ b/src/trace/trace.jl @@ -0,0 +1,35 @@ +function __new__(T, args...) + # @show T + # @show args + # note: we also add __new__() to the list of primitives so it's not overdubbed recursively + if T <: NamedTuple + return T(args) + else + return T(args...) + end +end + + +__tuple__(args...) = tuple(args...) +__getfield__(args...) = getfield(args...) + +const PRIMITIVES = Set([ + # *, /, +, -, sin, cos, sum, Base._sum, + print, println, + Base.getproperty, Base.getfield, Base.indexed_iterate, + # broadcasting + broadcast, Broadcast.materialize, Broadcast.broadcasted, + # functions with kw arguments + Core.apply_type, Core.kwfunc, + tuple, + # for loop primitives + Colon(), Base.iterate, Base.not_int, ===, + # our own special functions + __new__, __tuple__, __getfield__, namedtuple]) + + +include("cassette.jl") +include("interp.jl") + + +trace = itrace diff --git a/test/runtests.jl b/test/runtests.jl index 09ca3ac..f4f1ccf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,8 +4,8 @@ using Yota: Tape, Input, Call, Constant, trace, itrace, play!, transform, binari using Yota: mean_grad, setfield_nested!, copy_with, simplegrad, remove_unused using Yota: find_field_source_var -include("test_tracer.jl") -include("test_itracer.jl") +include("test_trace_cassette.jl") +include("test_trace_interp.jl") include("gradcheck.jl") include("test_simple.jl") include("test_grad.jl") diff --git a/test/test_tracer.jl b/test/test_trace_cassette.jl similarity index 100% rename from test/test_tracer.jl rename to test/test_trace_cassette.jl diff --git a/test/test_itracer.jl b/test/test_trace_interp.jl similarity index 100% rename from test/test_itracer.jl rename to test/test_trace_interp.jl