Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new ReactantState architecture in super_simple_simulation.jl #12

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

giordano
Copy link
Collaborator

@giordano giordano commented Feb 8, 2025

I don't know if this is how you wanted to use ReactantState, but this seems to work for me now.

@giordano giordano requested a review from glwagner February 8, 2025 12:43
Copy link
Member

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm modulo seed


arch = GPU() # CPU() to run on CPU
arch = isempty(find_library(["libcuda.so.1", "libcuda.so"])) ? CPU() : GPU()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💪

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you also want to set the Reactant default to cpu vs gpu based on this check?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Reactant does it automatically already.

@glwagner
Copy link
Collaborator

glwagner commented Feb 8, 2025

I suggest also adding a traced clock to these:

    FT = Float64
    t = ConcreteRNumber(zero(FT))
    iter = ConcreteRNumber(0)
    stage = ConcreteRNumber(0)
    last_Δt = ConcreteRNumber(zero(FT))
    last_stage_Δt = ConcreteRNumber(zero(FT))
    r_clock = Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt)

then pass clock into the constructor for HydrostaticFreeSurfaceModel:

r_model = HydrostaticFreeSurfaceModel(; grid=r_grid, clock=r_clock, momentum_advection=WENO())

Once we get the next PR through I will make Oceananigans do this automatically.

@giordano
Copy link
Collaborator Author

giordano commented Feb 8, 2025

I suggest also adding a traced clock to these:

I did it, but only for the Reactant model, is that right?

@giordano giordano force-pushed the mg/reactantstate branch 2 times, most recently from 02f2021 to 260e029 Compare February 8, 2025 16:03
@giordano
Copy link
Collaborator Author

giordano commented Feb 8, 2025

@glwagner with the clock:

ERROR: LoadError: MethodError: no method matching ifelse(::Reactant.TracedRNumber{Bool}, ::Reactant.TracedRNumber{Float64}, ::Float64)

Closest candidates are:
  ifelse(::Bool, ::Any, ::Any)
   @ Base essentials.jl:647
  ifelse(::Reactant.TracedRNumber{Bool}, ::Reactant.TracedRNumber{T}, ::Reactant.TracedRNumber{T}) where T
   @ Reactant ~/.julia/packages/Reactant/QDd2t/src/TracedRNumber.jl:184
  ifelse(::Reactant.TracedRNumber{Bool}, ::Reactant.TracedRNumber{T1}, ::Reactant.TracedRNumber{T2}) where {T1, T2}
   @ Reactant ~/.julia/packages/Reactant/QDd2t/src/TracedRNumber.jl:168

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/QDd2t/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::typeof(ifelse), ::Reactant.TracedRNumber{Bool}, ::Reactant.TracedRNumber{Float64}, ::Float64)
    @ Reactant ~/.julia/packages/Reactant/QDd2t/src/utils.jl:767
  [3] min
    @ ./operators.jl:490 [inlined]
  [4] min(none::Float64, none::Reactant.TracedRNumber{Float64})
    @ Reactant ./<missing>:0
  [5] promote_to
    @ ~/.julia/packages/Reactant/QDd2t/src/TracedRNumber.jl:86 [inlined]
  [6] isless
    @ ~/.julia/packages/Reactant/QDd2t/src/TracedRNumber.jl:145 [inlined]
  [7] min
    @ ./operators.jl:490 [inlined]
  [8] call_with_reactant(::typeof(min), ::Float64, ::Reactant.TracedRNumber{Float64})
    @ Reactant ~/.julia/packages/Reactant/QDd2t/src/utils.jl:0
  [9] aligned_time_step
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:50 [inlined]
 [10] aligned_time_step(none::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}}, none::Float64)
    @ Reactant ./<missing>:0
 [11] getproperty
    @ ./Base.jl:37 [inlined]
 [12] aligned_time_step
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:42 [inlined]
 [13] call_with_reactant(::typeof(Oceananigans.Simulations.aligned_time_step), ::Simulation{…}, ::Float64)
    @ Reactant ~/.julia/packages/Reactant/QDd2t/src/utils.jl:0
 [14] time_step!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:115 [inlined]
 [15] time_step!(none::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
    @ Reactant ./<missing>:0
 [16] time_ns
    @ ./Base.jl:114 [inlined]
 [17] time_step!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:113 [inlined]
 [18] call_with_reactant(::typeof(time_step!), ::Simulation{…})
    @ Reactant ~/.julia/packages/Reactant/QDd2t/src/utils.jl:0
 [19] #run!#7
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:102 [inlined]
 [20] var"#run!#7"(none::Bool, none::typeof(run!), none::Simulation{…})
    @ Reactant ./<missing>:0
 [21] #run!#7
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:93 [inlined]
 [22] call_with_reactant(::Oceananigans.Simulations.var"##run!#7", ::Bool, ::typeof(run!), ::Simulation{…})
    @ Reactant ~/.julia/packages/Reactant/QDd2t/src/utils.jl:0
 [23] run!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:91 [inlined]
 [24] run!(none::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
    @ Reactant ./<missing>:0
 [25] run!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:91 [inlined]
 [26] call_with_reactant(::typeof(run!), ::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
    @ Reactant ~/.julia/packages/Reactant/QDd2t/src/utils.jl:0
 [27] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/QDd2t/src/TracedUtils.jl:332
 [28] make_mlir_fn
    @ ~/.julia/packages/Reactant/QDd2t/src/TracedUtils.jl:152 [inlined]
 [29] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Simulation{…}}, callcache::Dict{Vector, @NamedTuple{…}}; optimize::Bool, no_nan::Bool, backend::String)
    @ Reactant.Compiler ~/.julia/packages/Reactant/QDd2t/src/Compiler.jl:591
 [30] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/QDd2t/src/Compiler.jl:566 [inlined]
 [31] compile_xla(f::Function, args::Tuple{Simulation{…}}; client::Nothing, kwargs::@Kwargs{no_nan::Bool, optimize::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QDd2t/src/Compiler.jl:1362
 [32] compile_xla
    @ ~/.julia/packages/Reactant/QDd2t/src/Compiler.jl:1341 [inlined]
 [33] compile(f::Function, args::Tuple{Simulation{…}}; sync::Bool, kwargs::@Kwargs{client::Nothing, no_nan::Bool, optimize::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QDd2t/src/Compiler.jl:1401
 [34] top-level scope
    @ ~/.julia/packages/Reactant/QDd2t/src/Compiler.jl:937
 [35] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [36] top-level scope
    @ REPL[3]:1
in expression starting at /Users/mose/repo/GB-25/oceananigans-dynamical-core/super_simple_simulation.jl:53
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Copy link
Member

wsmoses commented Feb 8, 2025

We should add an override for that in Reactant, cc @avik-pal .

seems easy enough

@avik-pal
Copy link
Member

avik-pal commented Feb 8, 2025

EnzymeAD/Reactant.jl#712

@giordano
Copy link
Collaborator Author

giordano commented Feb 8, 2025

ERROR: LoadError: TypeError: non-boolean (Reactant.TracedRNumber{Bool}) used in boolean context
Stacktrace:
  [1] aligned_time_step
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:53 [inlined]
  [2] aligned_time_step(none::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}}, none::Float64)
    @ Reactant ./<missing>:0
  [3] getproperty
    @ ./Base.jl:37 [inlined]
  [4] aligned_time_step
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:42 [inlined]
  [5] call_with_reactant(::typeof(Oceananigans.Simulations.aligned_time_step), ::Simulation{…}, ::Float64)
    @ Reactant ~/.julia/packages/Reactant/uwEkW/src/utils.jl:0
  [6] time_step!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:115 [inlined]
  [7] time_step!(none::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
    @ Reactant ./<missing>:0
  [8] time_ns
    @ ./Base.jl:114 [inlined]
  [9] time_step!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:113 [inlined]
 [10] call_with_reactant(::typeof(time_step!), ::Simulation{…})
    @ Reactant ~/.julia/packages/Reactant/uwEkW/src/utils.jl:0
 [11] #run!#7
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:102 [inlined]
 [12] var"#run!#7"(none::Bool, none::typeof(run!), none::Simulation{…})
    @ Reactant ./<missing>:0
 [13] #run!#7
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:93 [inlined]
 [14] call_with_reactant(::Oceananigans.Simulations.var"##run!#7", ::Bool, ::typeof(run!), ::Simulation{…})
    @ Reactant ~/.julia/packages/Reactant/uwEkW/src/utils.jl:0
 [15] run!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:91 [inlined]
 [16] run!(none::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
    @ Reactant ./<missing>:0
 [17] run!
    @ ~/.julia/packages/Oceananigans/6qZwL/src/Simulations/run.jl:91 [inlined]
 [18] call_with_reactant(::typeof(run!), ::Simulation{HydrostaticFreeSurfaceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
    @ Reactant ~/.julia/packages/Reactant/uwEkW/src/utils.jl:0
 [19] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/uwEkW/src/TracedUtils.jl:332
 [20] make_mlir_fn
    @ ~/.julia/packages/Reactant/uwEkW/src/TracedUtils.jl:152 [inlined]
 [21] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Simulation{…}}, callcache::Dict{Vector, @NamedTuple{…}}; optimize::Bool, no_nan::Bool, backend::String)
    @ Reactant.Compiler ~/.julia/packages/Reactant/uwEkW/src/Compiler.jl:591
 [22] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/uwEkW/src/Compiler.jl:566 [inlined]
 [23] compile_xla(f::Function, args::Tuple{Simulation{…}}; client::Nothing, kwargs::@Kwargs{no_nan::Bool, optimize::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/uwEkW/src/Compiler.jl:1362
 [24] compile_xla
    @ ~/.julia/packages/Reactant/uwEkW/src/Compiler.jl:1341 [inlined]
 [25] compile(f::Function, args::Tuple{Simulation{…}}; sync::Bool, kwargs::@Kwargs{client::Nothing, no_nan::Bool, optimize::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/uwEkW/src/Compiler.jl:1401
 [26] top-level scope
    @ ~/.julia/packages/Reactant/uwEkW/src/Compiler.jl:937
 [27] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [28] top-level scope
    @ REPL[1]:1
in expression starting at /Users/mose/repo/GB-25/oceananigans-dynamical-core/super_simple_simulation.jl:53

@glwagner can we remove the clock for the time being? 😃

@giordano
Copy link
Collaborator Author

giordano commented Feb 8, 2025

@wsmoses
Copy link
Member

wsmoses commented Feb 8, 2025

@glwagner can that be made into an ifelse

@giordano
Copy link
Collaborator Author

giordano commented Feb 8, 2025

That helps only partially, there are more functions which use more complicated if conditionals: https://github.com/CliMA/Oceananigans.jl/blob/20085593ca3583645939e8d850d0aeffa766cb02/src/Simulations/simulation.jl#L190-L230.

@wsmoses
Copy link
Member

wsmoses commented Feb 8, 2025

Yeah those we’ll likely have to convert to be traceable by hand for the time being

cc @jumerckx @Pangoraw re auto tracing

@glwagner
Copy link
Collaborator

I missed a bunch of messages here for some reason. We can remove the "aligned time step" feature. We don't need it.

@glwagner
Copy link
Collaborator

That helps only partially, there are more functions which use more complicated if conditionals: https://github.com/CliMA/Oceananigans.jl/blob/20085593ca3583645939e8d850d0aeffa766cb02/src/Simulations/simulation.jl#L190-L230.

What pattern is supportable here?

On removing Clock, we cannot remove it completely for a climate case because we need to know the time of day/year for forcing the simulation. But we can use a different pattern for testing the code and bypass any part of Simulation that relies on the counter state. Many solutions are possible.

@giordano
Copy link
Collaborator Author

What pattern is supportable here?

I'm not sure. ifelse would work, but it wouldn't be particularly idiomatic here. In today's Reactant meeting we talked about how to automatically trace ifs, @jumerckx and @Pangoraw maybe can comment on the feasibility of a quick implementation.

@glwagner
Copy link
Collaborator

glwagner commented Feb 12, 2025

What pattern is supportable here?

I'm not sure. ifelse would work, but it wouldn't be particularly idiomatic here. In today's Reactant meeting we talked about how to automatically trace ifs, @jumerckx and @Pangoraw maybe can comment on the feasibility of a quick implementation.

I wonder if we should be taking a different approach where we do

r_time_step = @compile time_step!(model, dt)

and then Simulation runs with r_time_step.

Simulation is really just a utility for time-stepping loops so I don't know if we need to compile through with Reactant. I'm not 100% about this. For output and whatnot, we do have things outside time_step!(model, dt).

stage = ConcreteRNumber(0)
last_Δt = ConcreteRNumber(zero(FT))
last_stage_Δt = ConcreteRNumber(zero(FT))
r_clock = Clock(; time=t, iteration=iter, stage, last_Δt, last_stage_Δt)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glwagner to be clear, with CliMA/Oceananigans.jl#4096 how would I modify this? So that I can test it quickly locally

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants