-
-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixed NeuralDSDE #724
Fixed NeuralDSDE #724
Conversation
Fixed dereferencing of parameters in NeuralDSDE here for SciML/SciMLSensitivity.jl#623 to work , Lux compatible constructors for all layers have been added in SciML/pull/722
Now it expects |
@@ -179,13 +179,13 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain} | |||
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) | |||
end | |||
|
|||
function (n::NeuralDSDE{M})(x,p1,p2,st1,st2) where {M<:Lux.AbstractExplicitLayer} | |||
function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
julia can unpack it
function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer} | |
function (n::NeuralDSDE{M})(x,(p1, p2),st1,st2) where {M<:Lux.AbstractExplicitLayer} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually we need them to be packed in one p so that we can pass it to the SDEProblem
https://github.com/Abhishek-1Bhatt/DiffEqFlux.jl/blob/b3b7ae7208f03918bdaaf67b39b6775515dcc6e0/src/neural_de.jl#L193
which will then call our functions with p where we will dereference individual p1 and p2 from p
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's why I added parentheses it works with one p
and Julia will unpack it like
p1, p2 = p
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but that is not the point where we want to break it apart(unpack it). If you check the line I referenced above, we want to have one p till there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, sorry. Now, I understand!
Fixed dereferencing of parameters in NeuralDSDE here for SciML/SciMLSensitivity.jl#623 to work , Lux compatible constructors for all layers have been added in /pull/722