gradients with aliased variables #991

Open
CarloLucibello opened this issue Jun 10, 2021 · 9 comments

CarloLucibello commented Jun 10, 2021

I was trying to figure out how to properly handle and update Flux's layers with tied weights (FluxML/Flux.jl#1592).

So first of all I wanted to check how Zygote handles aliased objects. Here are 6 examples. Maybe it's all expected and intended but I find the last 3 in particular a bit surprising.
@oxinabox is this what we want?

julia> using Zygote

julia> x = [1]
1-element Vector{Int64}:
 1

julia> xt = x'
1×1 adjoint(::Vector{Int64}) with eltype Int64:
 1

# 1.
julia> gradient(() -> sum(x' .* x), Params([x])).grads
IdDict{Any, Any} with 2 entries:
  :(Main.x) => [2]
  [1]       => [2]

# 2.
julia> gradient(() -> sum(xt .* x), Params([x])).grads
IdDict{Any, Any} with 3 entries:
  :(Main.x)  => [1]
  [1]        => [1]
  :(Main.xt) => [1]

# 3.
julia> gradient(() -> sum(xt .* x), Params([x,xt])).grads
IdDict{Any, Any} with 4 entries:
  [1]        => [1]
  :(Main.x)  => [1]
  [1]        => [1]
  :(Main.xt) => [1]

# 4.
julia> gradient(() -> sum(xt.parent .* x), Params([x])).grads
IdDict{Any, Any} with 2 entries:
  :(Main.x) => [1]
  [1]       => [2]

# 5.
julia> gradient(() -> sum(xt.parent .* x), Params([x, xt])).grads
IdDict{Any, Any} with 3 entries:
  [1]       => nothing  # this is xt
  :(Main.x) => [1]
  [1]       => [2]           # this is x

# 6.
julia> gradient(() -> sum(xt.parent .* x), Params([xt])).grads
IdDict{Any, Any} with 3 entries:
  [1]        => (parent = [1],)
  :(Main.x)  => [1]
  :(Main.xt) => (parent = [1],)

CarloLucibello commented Jun 10, 2021

I guess the most disturbing is 5. Shouldn't it return

  [1]       => (parent = [1],)  # this is xt
  :(Main.x) => [1]
  [1]       => [1]           # this is x

instead?

@oxinabox
Member

Putting aliased memory in Params feels like it's not going to be OK.
I would need a fair bit of time to think about these.

darsnack commented Jun 18, 2021

(never mind, the thing I was missing is scribbling the wrong variables on my napkin)

CarloLucibello commented Jun 19, 2021

For a user-defined struct we have

julia> struct A; x; end

julia> x = rand(2); a = A(x);

julia> Base.sum(a) = sum(a.x)

julia> gradient(() -> sum(a), Params([x])).grads
IdDict{Any, Any} with 1 entry:
  [0.573261, 0.457937] => 2-element Fill{Float64}: entries equal to 1.0

while for Adjoint something is wrong

julia> xt = Adjoint(x)
1×2 adjoint(::Vector{Float64}) with eltype Float64:
 0.573261  0.457937

julia> gradient(() -> sum(xt), Params([x])).grads
IdDict{Any, Any} with 2 entries:
  [0.573261, 0.457937] => nothing
  :(Main.xt)           => 1×2 Fill{Float64}: entries equal to 1.0
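
As a side note for readers, the aliasing involved here can be sketched without Zygote. The snippet below only uses Base and LinearAlgebra and assumes nothing about Zygote's internals: it shows that `x'` wraps `x` without copying, and that an `IdDict` (the container shown in the `.grads` output above) keys on object identity, so the wrapper and its parent are distinct keys even though they share memory.

```julia
using LinearAlgebra

x  = [1.0, 2.0]
xt = x'                       # lazy Adjoint wrapper, no copy is made

# The wrapper's parent is the very same array object.
same_object = parent(xt) === x

# Mutating x is visible through xt because the memory is shared.
x[1] = 10.0
seen_through_wrapper = xt[1, 1]

# An IdDict compares keys by identity, so x and xt are different keys.
d = IdDict{Any,Any}(x => :x)
wrapper_is_key = haskey(d, xt)
parent_is_key  = haskey(d, parent(xt))
```

Under this reading, a gradient stored under the key `xt` is not automatically associated with `x`; whether Zygote should bridge that gap is the question raised in this issue.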

@DhairyaLGandhi
Member

This seems expected... The grads actually also track global params as a GlobalRef to capture tied variables.

darsnack commented Jun 19, 2021

They make sense, but that doesn't make them right/useful.

I tried creating similar problems with explicit params yesterday, and I just could not find an example that didn't work. So rather than spend time fixing this issue, we could transition to explicit params across the ecosystem.

@CarloLucibello
Member Author

It seems hard not to consider the last example in #991 (comment) a bug. I don't even know precisely why it happens; probably when we hit an AbstractArray{<:Number} in Zygote we don't look for internal structure. Is that the case?

> I tried creating similar problems with explicit params yesterday, and I just could not find an example that didn't work. So rather than spend time fixing this issue, we could transition to explicit params across the ecosystem.

I'm not totally sure explicit gradients are a convenient fit for every situation; I'd like to see a diverse set of use cases where they replace params. In the last example, the explicit gradient is at least consistent, although not very useful:

julia> gradient(x -> sum(a), x)
(nothing,)

julia> gradient(x -> sum(xt), x)
(nothing,)
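
For readers, one way to see why both calls return `nothing` is that the closures never use their argument. The sketch below is plain Julia with no Zygote involved; the names `ignores_arg` and `uses_arg` are made up for illustration.

```julia
using LinearAlgebra

x  = [0.5, 0.25]
xt = x'

# Same shape as the closures above: the argument y is never used,
# the value comes from the captured global xt.
ignores_arg = y -> sum(xt)

# A variant that builds the adjoint from its argument instead.
uses_arg = y -> sum(y')

# Passing a different array changes only the second function's value.
shifted = x .+ 1.0
a0, a1 = ignores_arg(x), ignores_arg(shifted)
b0, b1 = uses_arg(x), uses_arg(shifted)
```

Since the output of `ignores_arg` does not depend on the value passed in, there is no derivative with respect to that argument to report.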

@darsnack
Member

I think this illustrates why I consider explicit params better. It's obvious why the last example returned nothing. For the same reason, the Adjoint case returns nothing, but it is less obvious because we expect implicit params to pick up connections that aren't there in the function being differentiated.

One option is to add some kind of post-processing step where Params finds these connections and applies a fix. But I feel that it is hard to do correctly in the generic case.

@darsnack
Member

For example, something like FluxML/Flux.jl#1592 works out nicely. Similar to the examples above, if we have

m1 = Dense(5, 2)
m2 = Dense(transpose(m1.weight))
m = Chain(m1, m2)
dm = gradient(m -> sum(m(ones(Float32, 5))), m)[1]

Zygote will see the weight of m1 as w1 = w and the weight of m2 as w2 = transpose(w). It returns gradients w.r.t. w1 and w2 (as if they were not tied). But when we consider the part that Zygote doesn't see (w1 = w), we have

from multivariate chain rule
dL/dw = dL/dw1 * dw1/dw + dL/dw2 * dw2/dw
dw1/dw = 1
dw2/dw = 1 (up to transpose)

=> dL/dw = dL/dw1 + dL/dw2

The last equation is automatically done by simple optimizers like gradient descent provided you use lazy wrappers like transpose or views. (@oxinabox can correct me if I am wrong here, my AD knowledge is very limited).
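
The accumulation claimed above can be sketched in plain Julia. This is only an illustration under assumed values: `g1` and `g2` are made-up stand-ins for the gradients an AD would return for `w1` and `w2`, and `η` is an arbitrary step size.

```julia
using LinearAlgebra

w  = [1.0 2.0; 3.0 4.0]
w1 = w                      # same array object as w
w2 = transpose(w)           # lazy wrapper sharing w's memory

# Made-up gradients w.r.t. w1 and w2, each in its own layout.
g1 = fill(0.5, 2, 2)
g2 = [1.0 2.0; 3.0 4.0]
η  = 0.1

# What one tied update should give: dL/dw = dL/dw1 + transpose(dL/dw2).
expected = w .- η .* (g1 .+ transpose(g2))

# Two independent in-place gradient-descent steps, one per "parameter".
w1 .-= η .* g1
w2 .-= η .* g2              # writes through the wrapper into w

matches = w ≈ expected
```

Because both in-place steps write into the same memory, the two contributions add up in `w`. With a copy instead of a lazy wrapper (for example `collect(transpose(w))`), the second step would not reach `w`.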

I guess it isn't automatic for complex optimizers that track momentum, etc. But it seems like then we should be handling it on the optimizer side, not the AD. This is where I think explicit params is nicer. What I wrote above is true for implicit params as well (e.g. Example 3 in the main issue) when Params contains x, xt. The trouble with implicit params is that you get all these other cases, issues with hashing, etc. that make dealing with the final equation I wrote above harder on the optimizer side.
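
To illustrate why a nonlinear update rule behaves differently, here is a minimal sketch using a sign-based step. This is a deliberately simplified stand-in for optimizers that rescale gradients, not the update rule of any Flux optimizer, and the gradient values are made up.

```julia
# Made-up gradients for two tied copies of the same parameter.
g1 = [1.0]
g2 = [-3.0]

# Hypothetical nonlinear update direction: only the sign of the gradient is kept.
direction(g) = sign.(g)

# Treating the copies as independent parameters, then letting the
# in-place steps accumulate in shared memory:
separate = direction(g1) .+ direction(g2)

# Summing the gradients first (the tied gradient), then applying the rule:
tied = direction(g1 .+ g2)
```

Here `separate` and `tied` disagree, so once the update rule is nonlinear the contributions would need to be combined before the optimizer sees them, which fits the suggestion above that this belongs on the optimizer side.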
