Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rough Optimisers sketch #17

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft

Rough Optimisers sketch #17

wants to merge 2 commits into from

Conversation

ericphanson
Copy link
Member

Serializing the full nested structure with Arrow

@femtomc
Copy link

femtomc commented May 13, 2022

@ericphanson what do you think about a load_leaves function (in analogy to load_weights!):

function load_leaves(state, leaves::Array{Optimisers.Leaf})
    counter = 1; fmap(state) do s
        if Functors.isleaf(s)
            new = leaves[counter]
            counter += 1
        else
            new = s
        end
        new
    end
end

Here, state is state from Optimisers.jl -- it's a NamedTuple. So this just traverses that state and replaces the leaves.

If there is an in-place version possible -- we might go for that, however.

1 similar comment
@femtomc
Copy link

femtomc commented May 13, 2022

@ericphanson what do you think about a load_leaves function (in analogy to load_weights!):

function load_leaves(state, leaves::Array{Optimisers.Leaf})
    counter = 1; fmap(state) do s
        if Functors.isleaf(s)
            new = leaves[counter]
            counter += 1
        else
            new = s
        end
        new
    end
end

Here, state is state from Optimisers.jl -- it's a NamedTuple. So this just traverses that state and replaces the leaves.

If there is an in-place version possible -- we might go for that, however.

@ericphanson
Copy link
Member Author

I think this could work too. I don’t really know how Optimisers works- is all the info always in the Leaf objects?

@ToucheSir
Copy link

ToucheSir commented May 13, 2022

is all the info always in the Leaf objects?

No, there is no structural information in the leaves so you're relying on a stable topology (depends on the use case, bias vs no bias has caused issues before) + traversal order (guaranteed for now).

@femtomc
Copy link

femtomc commented May 20, 2022

@ericphanson my suggestion above is not a good one -- I didn't realize that you just serialize the entire state (I thought you still needed to walk and get leaves).

@ericphanson
Copy link
Member Author

ericphanson commented May 20, 2022

Should we just go with this full Arrow serialization appoach then? We need to...

  • add Optimisers.jl to our package dependencies + add ArrowTypes methods via piracy
  • add roundtripping tests to runtests.jl for the simple model there
  • add roundtripping tests to example.jl for the more complicated model there - and get it to pass (it doesn't currently)
  • update the README to document how to serialize optimizers

OR:

  • add the small ArrowTypes.jl dependency to Optimisers + ArrowTypes methods for optimizer objects
  • add tests there
  • still add a roundtripping test to LegolasFlux's example to show how LegolasFlux can be used to serialize a model in combination with Arrow serialization of the optimizer
  • still update the LegolasFlux readme

@femtomc
Copy link

femtomc commented May 20, 2022

Normally, I would prefer the latter (not take on a new dependency, but instead upstream) -- but because this is LegolasFlux, and Optimisers.jl and Flux.jl are closely related and developed by overlapping parties -- perhaps we go ahead and do (1)

If your upstream apache/arrow-julia#323 gets removed, I think we can expunge any piracy here?

@femtomc
Copy link

femtomc commented May 20, 2022

Assuming that Flux models + optimisers are enough to cover almost all use cases of Flux (now including training state), it seems acceptable. I'm aware that it's a slippery slope to start including this sort of thing for an ecosystem which could continue to grow ... but perhaps this truly does cover almost all of the serialization use cases.

@ericphanson
Copy link
Member Author

If your upstream apache/arrow-julia#323 gets removed, I think we can expunge any piracy here?

No, the macro is ok either place, just a utility. The problem is LegolasFlux can’t define ArrowTypes methods for Optimisers objects. Either ArrowTypes or Optimisers has to do that for it not be piracy.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants