Repeated calls to replace! have quadratic complexity #142

Open
mschauer opened this issue Mar 21, 2024 · 3 comments

@mschauer

replace!(tape, i => [rr_op, val_op, pb_op]; rebind_to=2)

If I understand correctly, this is because the remainder of the tape is rewritten on every call. Maybe there is a potential efficiency gain here. Ignore this if it has already been considered.

@dfdx
Owner

dfdx commented Mar 22, 2024

Yes, it's a known issue and a notable historical artifact. Initially, Yota was designed to work with its custom rules that required only adding new operations onto the tape and not replacing anything. However, with ChainRules we have to replace a single operation:

%2 = foo(%1)

with 3 operations:

%2 = rrule(foo, %1)
%3 = getfield(%2, 1)
%4 = getfield(%2, 2)

Since variable IDs change, we need to rebind!() all the affected operations (+ context), which by itself has quadratic complexity.

To avoid this effect, we would have to either drop ChainRules or at least redesign the tape to support multiple output variables. Unfortunately, both of these options are unrealistic at this point of development.
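To make the cost concrete, here is a minimal sketch of the effect on a toy tape. It does not use Umlaut's real data structures: `ToyOp`, `make_chain`, `toy_replace!` and `total_rewrites` are hypothetical names invented for this illustration. Each replacement swaps one op for three and then renumbers every op after it, so repeating it for every op on the tape adds up to quadratic work.

```julia
# Toy model, NOT Umlaut's real tape: an op is its own ID plus the IDs of its arguments.
mutable struct ToyOp
    id::Int
    args::Vector{Int}
end

# A chain %1 -> %2 -> ... -> %n where each op consumes the previous one.
make_chain(n) = [ToyOp(i, i == 1 ? Int[] : [i - 1]) for i in 1:n]

# Replace the op at position `i` with 3 ops (an rrule call plus two getfields),
# then renumber everything after it. Returns how many ops had to be rewritten.
function toy_replace!(ops::Vector{ToyOp}, i::Int)
    old = ops[i]
    new_ops = [ToyOp(i, copy(old.args)),  # %i   = rrule(f, args...)
               ToyOp(i + 1, [i]),         # %i+1 = getfield(%i, 1)
               ToyOp(i + 2, [i])]         # %i+2 = getfield(%i, 2)
    splice!(ops, i:i, new_ops)
    touched = 0
    for j in (i + 3):length(ops)
        op = ops[j]
        op.id = j
        # Users of the old %i now read the value %i+1; all later IDs shift by 2.
        op.args = [a == i ? i + 1 : (a > i ? a + 2 : a) for a in op.args]
        touched += 1
    end
    return touched
end

# Replace every original op in turn and count the total number of rewrites.
function total_rewrites(n)
    ops = make_chain(n)
    total = 0
    pos = 1
    for _ in 1:n
        total += toy_replace!(ops, pos)
        pos += 3  # the next original op now sits 3 slots further along
    end
    return total
end
```

In this toy model the k-th replacement rewrites n - k ops, so `total_rewrites(n)` equals n(n - 1)/2, for example 4950 for n = 100.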

@mschauer
Author

mschauer commented Mar 22, 2024

Thanks, for now I have my

%2 = my_rule(foo, %1)
%3 = getfield(%2, 1)
%4 = getfield(%2, 2)

in the Umlaut.record_primitive! code. Perhaps it can still be done in a single pass without redesigning the tape, by avoiding the replace! function and keeping a dictionary of ID changes. I'll think about it. By the way, I wrote a message on Slack. Thank you for Umlaut.jl!
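The dictionary idea could look roughly like the sketch below. It is again a toy model rather than Umlaut's API: `ToyCall` and `expand_single_pass` are hypothetical names, and every op is treated as a call to be expanded (a real tape would pass its inputs through unchanged). A new op list is built in one pass, and `idmap` records which new ID holds the value of each old ID, so no op is ever renumbered twice.

```julia
# Toy model, NOT Umlaut's API: a call has an ID, a function name, and argument IDs.
struct ToyCall
    id::Int
    fn::Symbol
    args::Vector{Int}
end

# Build a new op list in a single pass. `idmap` maps each old ID to the new ID
# that holds the same value, so arguments are translated by one dictionary lookup.
function expand_single_pass(ops::Vector{ToyCall})
    out = ToyCall[]
    idmap = Dict{Int,Int}()
    for op in ops
        args = [idmap[a] for a in op.args]          # translate old IDs to new IDs
        r = length(out) + 1
        push!(out, ToyCall(r, :rrule, args))        # %r   = rrule(fn, args...)
        push!(out, ToyCall(r + 1, :getfield, [r]))  # %r+1 = getfield(%r, 1), the value
        push!(out, ToyCall(r + 2, :getfield, [r]))  # %r+2 = getfield(%r, 2), the pullback
        idmap[op.id] = r + 1                        # later ops should read the value
    end
    return out, idmap
end

# A chain of 4 calls where each one consumes the previous result.
chain = [ToyCall(i, :foo, i == 1 ? Int[] : [i - 1]) for i in 1:4]
expanded, idmap = expand_single_pass(chain)
```

Each op is visited once and each argument costs one lookup, so the work grows linearly with the tape length in this model. Whether the same bookkeeping survives Umlaut's context and branching logic is not something this sketch addresses.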

@dfdx
Owner

dfdx commented Mar 22, 2024

True! I experimented with a single-pass AD for some time, but eventually decided to keep them separate because of corner cases. For example, Julia has a pretty special way to represent vararg functions:

julia> foo(xs) = print(xs...)
foo (generic function with 2 methods)

julia> Umlaut.getcode(foo, (Vector{Float64},))
1 1 ─ %1 = Core._apply_iterate(Base.iterate, Main.print, _2)::Core.Const(nothing)
  │
  └──      return %1

Usually, you don't want to handle things like Core._apply_iterate() yourself, check for consistency between inputs, etc., so Umlaut does most of the hard parts automatically:

julia> trace(foo, [1, 2, 3.0])
1.02.03.0(nothing, Tape{Umlaut.BaseCtx}
  inp %1::typeof(foo)
  inp %2::Vector{Float64}
  %3 = check_variable_length(%2, 3, 2)::Nothing 
  %4 = __to_tuple__(%2)::Tuple{Float64, Float64, Float64} 
  %5 = getfield(%4, 1)::Float64 
  %6 = getfield(%4, 2)::Float64 
  %7 = getfield(%4, 3)::Float64 
  %8 = print(%5, %6, %7)::Nothing 
)

But if you override too many of the "internal" functions like record_primitive!(), you risk breaking this behavior. After a number of such issues, I decided to keep the simpler 2-stage design, with the forward and reverse passes kept separate.

On the other hand, you don't need to cover all corner cases, so you have a good chance of getting it done!
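For readers unfamiliar with the vararg lowering mentioned above, the following sketch reproduces by hand the kind of unrolling visible in the trace. It uses `+` instead of `print` so the results can be compared, and it relies on `Core._apply_iterate`, which is an internal Julia function and not a stable API. The variable names are chosen for illustration only.

```julia
xs = [1, 2, 3.0]

# What Julia's lowering produces for a splatted call like `+(xs...)`.
lowered = Core._apply_iterate(Base.iterate, +, xs)

# A hand-unrolled equivalent in the spirit of the trace: convert the vector
# to a tuple of known length, then pass each element as a separate argument.
t = Tuple(xs)
unrolled = +(getfield(t, 1), getfield(t, 2), getfield(t, 3))

@assert lowered == unrolled == +(xs...)
```

The unrolled form is only valid for inputs of the same length as the traced one, which is presumably why the trace also records a check_variable_length call.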
