DiffEqFlux Layers don't satisfy Lux API #727

Closed
avik-pal opened this issue Jun 18, 2022 · 15 comments

@avik-pal
Member

avik-pal commented Jun 18, 2022

The DiffEqFlux layers need to satisfy the interface described at https://lux.csail.mit.edu/dev/api/core/#Lux.AbstractExplicitLayer; otherwise the parameters/states returned by Lux.setup will be incorrect. As pointed out on Slack:

julia> ps, st = Lux.setup(rng, Chain(node,Dense(2=>3)))
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.11987843 -0.1679378; 0.36991563 0.41324985; 0.73272866 0.7062624], bias = Float32[0.0; 0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

ps.layer_1 should not be an empty NamedTuple.

The most recent manual for the interface is https://lux.csail.mit.edu/dev/manual/interface/

@YichengDWu
Contributor

YichengDWu commented Jun 19, 2022

Should be easy

using Lux, Random
using DiffEqFlux: NeuralODE

# Forward the Lux interface to the wrapped model (methods must extend Lux's functions)
Lux.initialparameters(rng::AbstractRNG, node::NeuralODE) = Lux.initialparameters(rng, node.model)
Lux.initialstates(rng::AbstractRNG, node::NeuralODE) = Lux.initialstates(rng, node.model)
Lux.parameterlength(node::NeuralODE) = Lux.parameterlength(node.model)
Lux.statelength(node::NeuralODE) = Lux.statelength(node.model)

To make setup work not only for Chain but also directly on NeuralODE, we need to add

function Lux.setup(rng::AbstractRNG, node::NeuralODE)
    return (Lux.initialparameters(rng, node), Lux.initialstates(rng, node))
end

@ChrisRackauckas
Member

Are you supposed to overload setup? I assume that should just follow from the interface.

@avik-pal
Member Author

You just need to define

using Lux, Random

Lux.initialparameters(rng::AbstractRNG, node::NeuralODE) = Lux.initialparameters(rng, node.model)
Lux.initialstates(rng::AbstractRNG, node::NeuralODE) = Lux.initialstates(rng, node.model)

@ChrisRackauckas
Member

We should put an abstract type on all of the AbstractNeuralDE types and then overload from there.

@YichengDWu
Contributor

> You just need to define
>
> initialparameters(rng::AbstractRNG, node::NeuralODE) = initialparameters(rng, node.model) 
> initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model)

For it to work, yes. But wouldn't it be nicer if the number of parameters could be printed automatically?

@YichengDWu
Contributor

> Are you supposed to overload setup? I assume that should just follow from the interface.

I was assuming NeuralODE was not a subtype of AbstractExplicitLayer. It should be unnecessary if you are going to subtype it.

@avik-pal
Member Author

No. Even if you are not subtyping, initialparameters and initialstates are the only functions that must be implemented; parameterlength and statelength are optional. setup should never be extended.
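To make the "two required, two optional" split concrete, here is a minimal self-contained sketch. It does not use Lux: ToyLayer, ToyDense, ToyNODE, and the local initialparameters/initialstates/parameterlength/statelength functions are hypothetical stand-ins that only mimic the shape of the interface, assuming the optional length functions can be derived from whatever the two required functions return.

```julia
using Random

# Hypothetical toy stand-in for the layer interface; not Lux's source code.
abstract type ToyLayer end

# Fallbacks for layers with nothing to learn or carry.
initialparameters(::AbstractRNG, ::ToyLayer) = NamedTuple()
initialstates(::AbstractRNG, ::ToyLayer) = NamedTuple()

# Optional helpers: by default, count the entries returned by the required functions.
_count(x::AbstractArray) = length(x)
_count(x::Number) = 1
_count(x::NamedTuple) = isempty(x) ? 0 : sum(_count, values(x))
parameterlength(l::ToyLayer) = _count(initialparameters(MersenneTwister(0), l))
statelength(l::ToyLayer) = _count(initialstates(MersenneTwister(0), l))

struct ToyDense <: ToyLayer
    in_dims::Int
    out_dims::Int
end
initialparameters(rng::AbstractRNG, d::ToyDense) =
    (weight = randn(rng, Float32, d.out_dims, d.in_dims), bias = zeros(Float32, d.out_dims))

# NeuralODE-like wrapper: only the two required functions are defined,
# both forwarding to the wrapped model.
struct ToyNODE{M<:ToyLayer} <: ToyLayer
    model::M
end
initialparameters(rng::AbstractRNG, n::ToyNODE) = initialparameters(rng, n.model)
initialstates(rng::AbstractRNG, n::ToyNODE) = initialstates(rng, n.model)

node = ToyNODE(ToyDense(2, 3))
parameterlength(node)  # 3*2 weights + 3 biases = 9
```

In this sketch the wrapper never defines a length function, yet a parameter count (the kind of number that could be printed automatically) still falls out of the required methods.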

@YichengDWu
Contributor

YichengDWu commented Jun 20, 2022

I would appreciate it if you could help me understand two questions:

  1. Is it still mandatory to implement initialstates if I just have one layer and only need to return NamedTuple()? I have implemented some layers without it, and they seem to just fall back to initialstates(::AbstractRNG, ::Any) = NamedTuple() in the source code.
  2. What are the bad consequences of extending setup?

@avik-pal
Member Author

It is meant to satisfy an interface.

  1. You are right, the default for initialstates is NamedTuple(), but this is undocumented, so it can be changed without being considered breaking.
  2. Extending setup will not solve the problem for most people and sets false expectations. For example, suppose you extend setup for a layer that is contained inside another layer. Calling Lux.setup on the outer layer will then give the inner custom layer empty parameters and states.
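The second point can be reproduced without Lux. In this minimal sketch, ToyLayer, ToyChain, ToyDense, BadNODE, GoodNODE, and the local setup are hypothetical stand-ins; the one assumption carried over is that a container builds its children's parameters by calling initialparameters on each child, never setup. Under that assumption, a wrapper that only extends setup looks fine on its own but gets an empty NamedTuple when nested, which matches the ps.layer_1 = NamedTuple() output in the original report.

```julia
using Random

# Hypothetical toy interface; not Lux's source code.
abstract type ToyLayer end
initialparameters(::AbstractRNG, ::ToyLayer) = NamedTuple()
initialstates(::AbstractRNG, ::ToyLayer) = NamedTuple()

# The framework's setup only goes through the interface functions.
setup(rng::AbstractRNG, l::ToyLayer) = (initialparameters(rng, l), initialstates(rng, l))

# A chain asks each child for initialparameters, not for setup.
struct ToyChain{T<:Tuple} <: ToyLayer
    layers::T
end
function initialparameters(rng::AbstractRNG, c::ToyChain)
    labels = ntuple(i -> Symbol("layer_", i), length(c.layers))
    return NamedTuple{labels}(map(l -> initialparameters(rng, l), c.layers))
end

struct ToyDense <: ToyLayer
    in_dims::Int
    out_dims::Int
end
initialparameters(rng::AbstractRNG, d::ToyDense) =
    (weight = randn(rng, Float32, d.out_dims, d.in_dims), bias = zeros(Float32, d.out_dims))

# Wrong: only setup is extended, so the interface fallback still returns NamedTuple().
struct BadNODE{M<:ToyLayer} <: ToyLayer
    model::M
end
setup(rng::AbstractRNG, n::BadNODE) = setup(rng, n.model)

# Right: the two required interface functions forward to the wrapped model.
struct GoodNODE{M<:ToyLayer} <: ToyLayer
    model::M
end
initialparameters(rng::AbstractRNG, n::GoodNODE) = initialparameters(rng, n.model)
initialstates(rng::AbstractRNG, n::GoodNODE) = initialstates(rng, n.model)

rng = MersenneTwister(0)
ps_alone, _ = setup(rng, BadNODE(ToyDense(2, 2)))  # looks fine on its own
ps_bad, _ = setup(rng, ToyChain((BadNODE(ToyDense(2, 2)), ToyDense(2, 3))))
ps_good, _ = setup(rng, ToyChain((GoodNODE(ToyDense(2, 2)), ToyDense(2, 3))))
```

Here ps_bad.layer_1 comes back as an empty NamedTuple while ps_good.layer_1 holds the wrapped model's weight and bias.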

@YichengDWu
Contributor

I really appreciate the clarification.

@ChrisRackauckas
Member

Flux doesn't care about the subtyping but Lux does, so we should subtype for Lux and then also make it a functor and we're 👍.

@ChrisRackauckas
Member

Copying over from #735. All of them should be AbstractExplicitLayers, which means they should behave exactly like Dense. They should have one state, take in a state, and return a state. They should take in a neural network definition and give you back a state from setup. Basically, each should act exactly like Dense and be able to swap in perfectly without any other code changes; if not, it's wrong. The only thing that should differ is the layer's constructor.

@Abhishek-1Bhatt let me know if you need me to do the first one.
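A rough sketch of what "act exactly like Dense" means in practice, assuming the Lux calling convention (x, ps, st) -> (y, st). Everything here is a hypothetical toy: ToyEulerNODE integrates with a fixed-step explicit Euler loop purely for illustration and is not how DiffEqFlux solves ODEs. The point is that the driver function apply_once runs unchanged with either layer; only the constructor differs.

```julia
using Random

# Hypothetical toy interface; not Lux's source code.
abstract type ToyLayer end
initialstates(::AbstractRNG, ::ToyLayer) = NamedTuple()
setup(rng::AbstractRNG, l::ToyLayer) = (initialparameters(rng, l), initialstates(rng, l))

struct ToyDense <: ToyLayer
    in_dims::Int
    out_dims::Int
end
initialparameters(rng::AbstractRNG, d::ToyDense) =
    (weight = randn(rng, Float32, d.out_dims, d.in_dims), bias = zeros(Float32, d.out_dims))
# Lux-style call: take input, parameters, and state; return output and (updated) state.
(d::ToyDense)(x, ps, st) = (ps.weight * x .+ ps.bias, st)

# Toy "neural ODE": fixed-step explicit Euler on du/dt = model(u) over tspan.
struct ToyEulerNODE{M<:ToyLayer} <: ToyLayer
    model::M
    tspan::Tuple{Float32,Float32}
    nsteps::Int
end
initialparameters(rng::AbstractRNG, n::ToyEulerNODE) = initialparameters(rng, n.model)
initialstates(rng::AbstractRNG, n::ToyEulerNODE) = initialstates(rng, n.model)
function (n::ToyEulerNODE)(x, ps, st)
    h = (n.tspan[2] - n.tspan[1]) / n.nsteps
    u = x
    for _ in 1:n.nsteps
        du, st = n.model(u, ps, st)  # thread the state through every call
        u = u .+ h .* du
    end
    return u, st  # same (output, state) contract as ToyDense
end

# Driver code that never inspects which layer it was given.
function apply_once(layer::ToyLayer, x)
    ps, st = setup(MersenneTwister(0), layer)
    y, _ = layer(x, ps, st)
    return y
end

x = Float32[1, 2]
y_dense = apply_once(ToyDense(2, 2), x)
y_node = apply_once(ToyEulerNODE(ToyDense(2, 2), (0f0, 1f0), 10), x)
```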

@avik-pal
Member Author

Once it gets built, http://lux.csail.mit.edu/previews/PR70/manual/interface should describe the recommended Lux interface. For DiffEqFlux, everything should really be a subtype of http://lux.csail.mit.edu/stable/api/core/#Lux.AbstractExplicitContainerLayer, and then there would be no need to define initialparameters and initialstates. (Just a heads up: there will be a small breaking change for the container layers in v0.5, which is still far out.)
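A toy sketch of why a container supertype removes the need for hand-written initialparameters/initialstates. ToyContainer is a hypothetical analogue of Lux.AbstractExplicitContainerLayer, not its real implementation: the type parameter names the struct fields that hold sub-layers, and one generic method builds the NamedTuple from them. The exact nesting that real Lux returns may differ between versions (see the v0.5 note above).

```julia
using Random

# Hypothetical toy interface; not Lux's source code.
abstract type ToyLayer end
initialparameters(::AbstractRNG, ::ToyLayer) = NamedTuple()
initialstates(::AbstractRNG, ::ToyLayer) = NamedTuple()

# Toy analogue of a container layer: `fields` names the struct fields holding sub-layers,
# and these two generic methods build the NamedTuples from them.
abstract type ToyContainer{fields} <: ToyLayer end
initialparameters(rng::AbstractRNG, c::ToyContainer{fields}) where {fields} =
    NamedTuple{fields}(map(f -> initialparameters(rng, getfield(c, f)), fields))
initialstates(rng::AbstractRNG, c::ToyContainer{fields}) where {fields} =
    NamedTuple{fields}(map(f -> initialstates(rng, getfield(c, f)), fields))

struct ToyDense <: ToyLayer
    in_dims::Int
    out_dims::Int
end
initialparameters(rng::AbstractRNG, d::ToyDense) =
    (weight = randn(rng, Float32, d.out_dims, d.in_dims), bias = zeros(Float32, d.out_dims))

# NeuralODE-like wrapper: subtyping the container is enough; no interface methods written.
struct ToyNODE{M<:ToyLayer} <: ToyContainer{(:model,)}
    model::M
    tspan::Tuple{Float32,Float32}
end

rng = MersenneTwister(0)
node = ToyNODE(ToyDense(2, 2), (0f0, 1f0))
ps = initialparameters(rng, node)  # (model = (weight = ..., bias = ...),)
st = initialstates(rng, node)      # (model = NamedTuple(),)
```

Non-layer fields such as tspan are simply ignored because they are not listed in the type parameter.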

@ba2tro
Contributor

ba2tro commented Jun 23, 2022

Ah yes, I had thought about it some time ago (https://julialang.slack.com/archives/C7T968HRU/p1655536943724979?thread_ts=1655535510.205359&cid=C7T968HRU), but we didn't discuss it, so we ended up subtyping AbstractExplicitLayer.

@ChrisRackauckas
Member

Done in #750


4 participants