
Scaling and memory consumption on GPU #687

Closed
shm117 opened this issue Feb 17, 2022 · 1 comment

Comments

@shm117

shm117 commented Feb 17, 2022

Inspired by the Neural ODE on GPU section of the documentation (https://diffeqflux.sciml.ai/stable/GPUs/), I tried to learn a NODE on data generated by a simulation of the 1D heat equation. The execution time per iteration is roughly 10-20x higher on the GPU than on the CPU. I expected the GPU to be much faster, since the tutorial data takes about 1 second per iteration on the GPU and my data is not so large that it would justify more than 20 seconds per iteration.
Also, the memory consumption for a relatively small model with 2 layers and about 40k parameters is already around 3 GB (observed with nvidia-smi), which is far higher than I expected. Am I doing something wrong here, or is this behavior expected?
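
For reference, here is a minimal sketch of how the per-iteration time and memory could also be checked from inside Julia rather than via nvidia-smi (it assumes the `loss_neuralode` and `p` defined in the code below; `CUDA.@time` and `CUDA.memory_status` are CUDA.jl utilities):

```julia
# Rough measurement sketch, assuming `loss_neuralode` and `p` from the script below.
CUDA.@time loss_neuralode(p)                                # one forward solve, incl. GPU allocations
CUDA.@time Flux.gradient(θ -> first(loss_neuralode(θ)), p)  # one forward + reverse pass
CUDA.memory_status()                                        # prints CUDA.jl's view of GPU memory usage
```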

Julia v1.7.2
DiffEqFlux v1.45.1

heat_ss.csv ([attachment](https://github.com/SciML/DiffEqFlux.jl/files/8089740/heat_ss.csv)): each column is a snapshot in time, each row is a fixed position in the spatial domain.

Edit 1: I get the following warnings when running the code:
WARNING: both Flux and Iterators export "flatten"; uses of it in module DiffEqFlux must be qualified
WARNING: both Flux and Distributions export "params"; uses of it in module DiffEqFlux must be qualified

Edit 2: loading the CSV file as Float32 does not improve the performance substantially.
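
For completeness, the Float32 variant of the loading step looked roughly like this (a sketch; the rest of the script below is unchanged, and it matches the Float32 `tspan`):

```julia
# Float32 variant of the data loading (sketch; everything else in the script stays the same)
heatData = readdlm("heat_ss.csv", ',', Float32)
```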

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, CUDA, DiffEqSensitivity
using DelimitedFiles

# load data
heatData = readdlm("heat_ss.csv", ',', Float64)

# optional downsampling factors for the spatial and temporal dimensions
samplefactor_state = 1
samplefactor_time = 1

ss_heatData = heatData[range(1, end, step = samplefactor_state), range(1, end, step = samplefactor_time)]

state_dim = size(ss_heatData)[1]

u0 = ss_heatData[:, 1] |> gpu
datasize = size(ss_heatData)[2]
tspan = (0.0f0, 20.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

CUDA.allowscalar(false) # Make sure no slow scalar operations are occurring

dudt2 = FastChain(
                  FastDense(state_dim, 200, tanh),
                  FastDense(200, state_dim)) #|> gpu

ss_heatData = ss_heatData|> gpu
p = initial_params(dudt2) |> gpu
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

function predict_neuralode(p)
    gpu(prob_neuralode(u0, p))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ss_heatData .- pred)
    return loss, pred
end

# Callback function to observe training
list_plots = []
iter = 0
callback = function (p, l, pred; doplot = false)
  global list_plots, iter
  if iter == 0
    list_plots = []
  end
  iter += 1
  display(l)
  # plot current prediction against data
  plt = scatter(tsteps, Array(ss_heatData[1, :]), label = "data")
  scatter!(plt, tsteps, Array(pred[1, :]), label = "prediction")
  push!(list_plots, plt)
  if doplot
    display(plot(plt))
  end
  return false
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, p,
                                          ADAM(0.05), cb = callback,
                                          maxiters = 10)

```


@ChrisRackauckas
Member

Because of how the data is set up here, it's doing matrix-vector products and thus can't load the GPU kernels well. This is actually a good case for SimpleChains.jl and CPU, as described in more detail at https://julialang.org/blog/2022/04/simple-chains/, or one would want to fill the GPU by batching over time using multiple shooting.
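
For anyone finding this later, here is a rough sketch of the batching-over-time suggestion using DiffEqFlux's `multiple_shoot` helper, following the multiple-shooting tutorial in the DiffEqFlux docs (the exact signature may differ between versions; `group_size`, `continuity_term`, and the helper names below are illustrative, while `dudt2`, `u0`, `p`, `tspan`, `tsteps`, and `ss_heatData` come from the script above):

```julia
# Sketch: train on short time windows via multiple shooting instead of one long solve.
# `group_size` and `continuity_term` are illustrative, untuned values.
prob_node = ODEProblem((u, p, t) -> dudt2(u, p), u0, tspan, p)

group_size = 5          # time points per shooting segment
continuity_term = 200   # penalty weight enforcing continuity between segments

loss_segment(data, pred) = sum(abs2, data .- pred)

function loss_multiple_shoot(p)
    return multiple_shoot(p, ss_heatData, tsteps, prob_node, loss_segment, Tsit5(),
                          group_size; continuity_term = continuity_term)
end

result = DiffEqFlux.sciml_train(loss_multiple_shoot, p, ADAM(0.05), maxiters = 10)
```

The intent, per the comment above, is to give the GPU more work per iteration by training over many time segments at once rather than a single long trajectory.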
