"argextype only works on argument-position values" on Julia 1.5 #78

Open
dfdx opened this issue Aug 10, 2020 · 0 comments
dfdx commented Aug 10, 2020

Here's a @dynamo I use to trace code execution (full MWE below):

@dynamo function (t::IRTracer)(fargs...)
    ir = IR(fargs...)
    ir === nothing && return   # intrinsic functions have no IR
    for (v, st) in ir
        ex = st.expr
        if Meta.isexpr(ex, :call)
            ir[v] = Expr(:call, record_or_recurse!, self, ex.args...)
        else
            ir[v] = Expr(:call, identity, ex)
        end
    end
    return ir
end

Essentially, it does two things:

  1. Replaces every call to f with record_or_recurse!(..., f, ...), which either records f in the list of operations or recursively applies the transformation to f.
  2. Replaces all other statements (e.g. constants) with a call to the identity function.
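
To make the two rules concrete, here is a minimal sketch that applies them to plain Julia expressions instead of IRTools statements. The helper name rewrite_statement is hypothetical, and :record_or_recurse! and :self are placeholder symbols standing in for the real function and the IRTools self object used in the dynamo.

```julia
# Hypothetical helper: applies the two rewrite rules to a single statement.
# `:record_or_recurse!` and `:self` are placeholder symbols, not the real
# function and IRTools `self` used in the dynamo.
function rewrite_statement(ex)
    if Meta.isexpr(ex, :call)
        # Rule 1: f(args...) becomes record_or_recurse!(self, f, args...)
        return Expr(:call, :record_or_recurse!, :self, ex.args...)
    else
        # Rule 2: anything else (e.g. a constant) becomes identity(ex)
        return Expr(:call, :identity, ex)
    end
end

call_rewritten = rewrite_statement(:(sum(x)))   # :(record_or_recurse!(self, sum, x))
const_rewritten = rewrite_statement(1)          # :(identity(1))
```

The real dynamo does the same thing per IR statement, which is why a constant such as (:dims,) ends up wrapped in a call to identity.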

When I apply this tracer to the function f = x -> sum(x; dims=1) (or to any other function that uses keyword arguments), the following error is printed, although the code runs fine and returns the correct result:

Internal error: encountered unexpected error in runtime:
AssertionError(msg="argextype only works on argument-position values")
argextype at ./compiler/utilities.jl:166
argextype at ./compiler/utilities.jl:158 [inlined]
call_sig at ./compiler/ssair/inlining.jl:882
process_simple! at ./compiler/ssair/inlining.jl:956
assemble_inline_todo! at ./compiler/ssair/inlining.jl:999
ssa_inlining_pass! at ./compiler/ssair/inlining.jl:74 [inlined]
run_passes at ./compiler/ssair/driver.jl:138
optimize at ./compiler/optimize.jl:174
typeinf at ./compiler/typeinfer.jl:33
typeinf_edge at ./compiler/typeinfer.jl:484
...

Original function IR (@code_ir f(x)):

1: (%1, %2)
  %3 = (:dims,)
  %4 = Core.apply_type(Core.NamedTuple, %3)
  %5 = Core.tuple(1)
  %6 = (%4)(%5)
  %7 = Core.kwfunc(Main.sum)
  %8 = (%7)(%6, Main.sum, %2)
  return %8
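
For reference, the IR above can be reproduced by hand in ordinary Julia. This is only a sketch assuming Julia 1.5, the version in this report; later Julia versions changed how keyword calls are lowered, so Core.kwfunc may behave differently there. The variable names are illustrative and do not come from the IR.

```julia
x = rand(2, 4)

# Step-by-step version of the lowered keyword call, mirroring %3 through %8.
kwnames = (:dims,)                               # %3: tuple of keyword names
NT = Core.apply_type(Core.NamedTuple, kwnames)   # %4: NamedTuple{(:dims,)}
kwvals = Core.tuple(1)                           # %5: tuple of keyword values
kwargs = NT(kwvals)                              # %6: (dims = 1,)
kwsum = Core.kwfunc(sum)                         # %7: keyword-accepting method of sum
result = kwsum(kwargs, sum, x)                   # %8: same as sum(x; dims=1)
```

Note that %3 is the only statement in this IR that is not a call, so it is the only one that the tracer wraps in identity.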

Transformed IR (@code_ir t(f, x), where t::IRTracer):

1: (%1, %2)
  %3 = Base.getfield(%2, 1)
  %4 = Base.getfield(%2, 2)
  %5 = (identity)((:dims,))
  %6 = (record_or_recurse!)(%1, Core.apply_type, Core.NamedTuple, %5)
  %7 = (record_or_recurse!)(%1, Core.tuple, 1)
  %8 = (record_or_recurse!)(%1, %6, %7)
  %9 = (record_or_recurse!)(%1, Core.kwfunc, Main.sum)
  %10 = (record_or_recurse!)(%1, %9, %8, Main.sum, %4)
  return %10

If I comment out either of the transformations above (the one for function calls or the one for constants), the error disappears.

My best guess so far is that the compiler tries to infer the type of the dims argument and asserts that it is still a constant, but since I replaced it with a call to identity((:dims,)), the compiler pass fails.

Does this theory sound reasonable? If so, is there some metadata about :dims var that I should update to make this work?


MWE:

import IRTools: IR, @code_ir, @dynamo, self, var


const PRIMITIVES = Set([
    Core.kwfunc(sum),
    Core.apply_type,
])


mutable struct IRTracer
    primitives::Set{Any}
    ops::Vector{Any}
end

function IRTracer(;primitives=PRIMITIVES)
    return IRTracer(primitives, [])
end

Base.show(io::IO, t::IRTracer) = print(io, "IRTracer($(length(t.ops)))")



function record_or_recurse!(t::IRTracer, fargs...)
    fn, args = fargs[1], fargs[2:end]
    if fn in t.primitives || (fn isa Type && fn <: NamedTuple)
        res = fn(args...)
        push!(t.ops, fargs)
    else
        res = t(fn, args...)
    end
    return res
end


@dynamo function (t::IRTracer)(fargs...)
    ir = IR(fargs...)
    ir === nothing && return   # intrinsic functions have no IR
    for (v, st) in ir
        ex = st.expr
        if Meta.isexpr(ex, :call)
            ir[v] = Expr(:call, record_or_recurse!, self, ex.args...)
        else
            ir[v] = Expr(:call, identity, ex)
        end
    end
    return ir
end

x = rand(2, 4)
t = IRTracer()
f = x -> sum(x; dims=1)

t(f, x)