I am trying to use deep equilibrium (DEQ) models to fit a single toy data point by wrapping the neural net in a SteadyStateProblem. However, when I run code like this:
using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
u0 = Float32[0.0; 0.0]
tspan = (0.0f0, 10.0f0)
ann = FastChain(
    FastDense(4, 4, relu),
    FastDense(4, 4, tanh))
p1 = initial_params(ann)
n = length(p1)
ps = Float32[p1;u0]
function dudt_(du, u, p, t)
    # Solving the equation f(u) - u = du = 0
    du = ann(u, p[1:n]) - u
end
ode = ODEProblem(dudt_, u0, tspan, ps)
ss = SteadyStateProblem(ode)
function predict(x)
    Array(solve(ss, DynamicSS(Rodas5()), u0 = [u0;x], p = ps, sensealg=SteadyStateAdjoint()))
end
# https://medium.com/coffee-in-a-klein-bottle/deep-learning-with-julia-e7f15ad5080b
#Auxiliary functions for generating our data
function generate_real_data(n)
    x1 = rand(1,n) .- 0.5
    x2 = (x1 .* x1)*3 .+ randn(1,n)*0.1
    return vcat(x1,x2)
end
function generate_fake_data(n)
    θ = 2*π*rand(1,n)
    r = rand(1,n)/3
    x1 = @. r*cos(θ)
    x2 = @. r*sin(θ)+0.5
    return vcat(x1,x2)
end
# Creating our data
train_size = 1
real = generate_real_data(train_size)
fake = generate_fake_data(train_size)
# Organizing the data in batches
X = hcat(real,fake)
temp = vcat(ones(train_size),zeros(train_size))
Y = vcat(temp, temp, temp, temp)
data = Flux.Data.DataLoader((X, reshape(Y, 4, size(X)[2])), batchsize=1,shuffle=true)
opt = Descent(0.05)
function loss(x, y)
    ŷ = predict(x)
    sum((y .- ŷ).^2)
end
epochs = 100
for i in 1:epochs
    Flux.train!(loss, ps, data, opt)
    println(mean(ann(real)),mean(ann(fake))) # Print model prediction
end
I get the following error. Any help would be appreciated!
using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using Statistics
u0 = Float32[0.0; 0.0]
tspan = (0.0f0, 10.0f0)
ann = FastChain(
    FastDense(4, 4, relu),
    FastDense(4, 4, tanh))
p1 = initial_params(ann)
n = length(p1)
ps = Float32[p1;u0]
dudt_(u, p, t) = ann(u, p[1:n]) - u
# Or
#function dudt_(du, u, p, t)
#    # Solving the equation f(u) - u = du = 0
#    du .= ann(u, p[1:n]) - u
#end
ode = ODEProblem(dudt_, u0, tspan, ps)
ss = SteadyStateProblem(ode)
function predict(x)
    Array(solve(ss, DynamicSS(Rodas5()), u0 = [u0;x], p = ps, sensealg=SteadyStateAdjoint()))
end
# https://medium.com/coffee-in-a-klein-bottle/deep-learning-with-julia-e7f15ad5080b
#Auxiliary functions for generating our data
function generate_real_data(n)
    x1 = rand(n) .- 0.5
    x2 = (x1 .* x1)*3 .+ randn(n)*0.1
    return vcat(x1,x2)
end
function generate_fake_data(n)
    θ = 2*π*rand(n)
    r = rand(n)/3
    x1 = @. r*cos(θ)
    x2 = @. r*sin(θ)+0.5
    return vcat(x1,x2)
end
# Creating our data
train_size = 1
real = generate_real_data(train_size)
fake = generate_fake_data(train_size)
# Organizing the data in batches
X = hcat(real,fake)
temp = vcat(ones(train_size),zeros(train_size))
Y = vcat(temp, temp, temp, temp)
data = Flux.Data.DataLoader((X, reshape(Y, 4, size(X)[2])), batchsize=1,shuffle=true)
opt = ADAM(0.05)
function loss(x, y)
    ŷ = predict(x)
    @show sum((y .- ŷ).^2)
end
epochs = 1000
for i in 1:epochs
    Flux.train!(loss, Flux.Params([ps]), data, opt)
    #println(mean(ann([u0;real],ps[1:n])),mean(ann([u0;fake],ps[1:n]))) # Print model prediction
end
This works. The main issue, the stack overflow, was a recently introduced bug with a quick fix, SciML/SciMLBase.jl#56, which is now tagged. Other issues in your code:
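du .= ann(u, p[1:n]) - u instead of du = ann(u, p[1:n]) - u. Remember that f(du,u,p,t) is the in-place (mutating) form, so it must write its result into du; du = ... only rebinds the local variable and leaves the solver's array untouched. In this case it makes more sense to use the out-of-place (non-mutating) form f(u,p,t), i.e. dudt_(u, p, t) = ann(u, p[1:n]) - u.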
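println(mean(ann([u0;real],ps[1:n])),mean(ann([u0;fake],ps[1:n]))): whatever you were printing wasn't taking in the right values. A FastChain is called as ann(x, p), so ann(real) was missing the parameters, and the input has to be padded with u0 to the length 4 that the first FastDense layer expects.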
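To make the in-place versus out-of-place distinction concrete, here is a small plain-Julia sketch that does not use DifferentialEquations.jl at all. The helper g and the functions rhs_rebind!, rhs_inplace! and rhs_oop are hypothetical stand-ins for ann(u, p[1:n]) - u; the point is only that du = ... rebinds a local name while du .= ... writes into the array the caller passed in.

```julia
# Hypothetical right-hand side g(u) = 2u, standing in for ann(u, p[1:n]) - u.
g(u) = 2 .* u

# Wrong: `du = ...` rebinds the local name `du`; the caller's array is untouched.
function rhs_rebind!(du, u, p, t)
    du = g(u)
    return nothing
end

# Right: `du .= ...` writes element-wise into the array the solver passed in.
function rhs_inplace!(du, u, p, t)
    du .= g(u)
    return nothing
end

# Out-of-place form: no `du` argument, simply return the derivative.
rhs_oop(u, p, t) = g(u)

u = [1.0, 2.0]
du1 = zeros(2); rhs_rebind!(du1, u, nothing, 0.0)
du2 = zeros(2); rhs_inplace!(du2, u, nothing, 0.0)
println(du1)                       # [0.0, 0.0]  (unchanged)
println(du2)                       # [2.0, 4.0]
println(rhs_oop(u, nothing, 0.0))  # [2.0, 4.0]
```

A solver calling the in-place form only ever sees what was written into du, which is why the first version silently produces a zero derivative.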
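The DynamicSS(Rodas5()) solver used in this thread finds the DEQ's equilibrium by time-stepping du/dt = f(u) - u until the derivative is negligible, so the result satisfies f(u) = u. As a rough, dependency-free illustration of that idea (not how DynamicSS is implemented: it uses a real ODE solver and its own termination criteria), the sketch below applies forward Euler to a hypothetical contractive layer f(u) = tanh.(W*u .+ b). The values of W, b, dt and tol, and the helper steady_state, are all made up for the example.

```julia
using LinearAlgebra  # for norm

# Hypothetical contractive layer, standing in for ann(u, p).
W = [0.5 -0.2; 0.1 0.3]
b = [0.1, -0.2]
f(u) = tanh.(W * u .+ b)

# Integrate du/dt = f(u) - u with forward Euler until the derivative is tiny.
# At that point u is (approximately) a fixed point: f(u) ≈ u.
function steady_state(f, u0; dt = 0.1, tol = 1e-8, maxiters = 100_000)
    u = copy(u0)
    for _ in 1:maxiters
        du = f(u) .- u
        norm(du) < tol && return u  # steady state reached
        u .+= dt .* du              # one explicit Euler step
    end
    error("no steady state within maxiters")
end

u_star = steady_state(f, zeros(2))
println(norm(f(u_star) .- u_star))  # prints a value below tol = 1e-8
```

Because tanh is 1-Lipschitz and W has norm below 1 here, f is a contraction and the iteration converges; a trained network gives no such guarantee, which is one reason a stiff-capable integrator like Rodas5 is a sensible default.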