Merge pull request #270 from MilesCranmer/auto-heap-size
Automatically set heap size hint on workers
MilesCranmer authored Dec 24, 2023
2 parents 93b1b26 + 0becbf4 commit c878d66
Showing 12 changed files with 469 additions and 391 deletions.
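For context: the new default documented in src/MLJInterface.jl below sets the per-worker heap-size hint to Sys.free_memory() / numprocs. A minimal sketch of how such a hint could be turned into a worker flag, assuming a hypothetical helper name (the PR's actual implementation is not shown in the hunks below):

    using Distributed

    # Hypothetical helper: compute a --heap-size-hint flag from an explicit
    # byte count, defaulting to an even split of free memory across workers.
    function heap_size_hint_flag(numprocs::Integer,
                                 heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing)
        hint = something(heap_size_hint_in_bytes,
                         round(Int, Sys.free_memory() / numprocs))
        mb = max(1, div(hint, 1024 * 1024))  # the flag accepts sizes like "500M"
        return "--heap-size-hint=$(mb)M"
    end

    # Usage sketch: addprocs(4; lazy=false, exeflags=`$(heap_size_hint_flag(4))`)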
104 changes: 74 additions & 30 deletions src/Configure.jl
@@ -44,7 +44,21 @@ function assert_operators_well_defined(T, options::Options)
end

# Check for errors before they happen
function test_option_configuration(T, options::Options)
function test_option_configuration(
parallelism, datasets::Vector{D}, saved_state, options::Options
) where {T,D<:Dataset{T}}
if options.deterministic && parallelism != :serial
error("Determinism is only guaranteed for serial mode.")
end
if parallelism == :multithreading && Threads.nthreads() == 1
@warn "You are using multithreading mode, but only one thread is available. Try starting julia with `--threads=auto`."
end
if any(d -> d.X_units !== nothing || d.y_units !== nothing, datasets) &&
options.dimensional_constraint_penalty === nothing &&
saved_state === nothing
@warn "You are using dimensional constraints, but `dimensional_constraint_penalty` was not set. The default penalty of `1000.0` will be used."
end

for op in (options.operators.binops..., options.operators.unaops...)
if is_anonymous_function(op)
throw(
@@ -81,25 +95,18 @@ function test_dataset_configuration(
)
end

if size(dataset.X, 2) > 10000
if !options.batching
debug(
verbosity > 0,
"Note: you are running with more than 10,000 datapoints. You should consider turning on batching (`options.batching`), and also if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form.",
)
end
if size(dataset.X, 2) > 10000 && !options.batching && verbosity > 0
@info "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (`options.batching`), and also if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form."
end

if !(typeof(options.elementwise_loss) <: SupervisedLoss)
if dataset.weighted
if !(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
throw(
AssertionError(
"When you create a custom loss function, and are using weights, you need to define your loss function with three scalar arguments: f(prediction, target, weight).",
),
)
end
end
if !(typeof(options.elementwise_loss) <: SupervisedLoss) &&
dataset.weighted &&
!(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
throw(
AssertionError(
"When you create a custom loss function, and are using weights, you need to define your loss function with three scalar arguments: f(prediction, target, weight).",
),
)
end
end
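For reference, a custom weighted loss passing the three-argument check above could be as simple as (illustrative example, not part of this diff):

    myloss(prediction, target, weight) = weight * abs2(prediction - target)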

@@ -188,14 +195,15 @@ end

function copy_definition_to_workers(op, procs, options::Options, verbosity)
name = nameof(op)
debug_inline(verbosity > 0, "Copying definition of $op to workers...")
verbosity > 0 && @info "Copying definition of $op to workers..."
src_ms = methods(op).ms
# Thanks https://discourse.julialang.org/t/easy-way-to-send-custom-function-to-distributed-workers/22118/2
@everywhere procs @eval function $name end
for m in src_ms
@everywhere procs @eval $m
end
return debug(verbosity > 0, "Finished!")
verbosity > 0 && @info "Finished!"
return nothing
end

function test_function_on_workers(example_inputs, op, procs)
@@ -209,7 +217,7 @@ function test_function_on_workers(example_inputs, op, procs)
end

function activate_env_on_workers(procs, project_path::String, options::Options, verbosity)
debug(verbosity > 0, "Activating environment on workers.")
verbosity > 0 && @info "Activating environment on workers."
@everywhere procs begin
Base.MainInclude.eval(
quote
@@ -223,7 +231,7 @@ end
function import_module_on_workers(procs, filename::String, options::Options, verbosity)
included_local = !("SymbolicRegression" in [k.name for (k, v) in Base.loaded_modules])
if included_local
debug_inline(verbosity > 0, "Importing local module ($filename) on workers...")
verbosity > 0 && @info "Importing local module ($filename) on workers..."
@everywhere procs begin
# Parse functions on every worker node
Base.MainInclude.eval(
@@ -233,18 +241,18 @@ function import_module_on_workers(procs, filename::String, options::Options, verbosity)
end,
)
end
debug(verbosity > 0, "Finished!")
verbosity > 0 && @info "Finished!"
else
debug_inline(verbosity > 0, "Importing installed module on workers...")
verbosity > 0 && @info "Importing installed module on workers..."
@everywhere procs begin
Base.MainInclude.eval(using SymbolicRegression)
Base.MainInclude.eval(:(using SymbolicRegression))
end
debug(verbosity > 0, "Finished!")
verbosity > 0 && @info "Finished!"
end
end

function test_module_on_workers(procs, options::Options, verbosity)
debug_inline(verbosity > 0, "Testing module on workers...")
verbosity > 0 && @info "Testing module on workers..."
futures = []
for proc in procs
push!(
@@ -255,14 +263,15 @@ function test_module_on_workers(procs, options::Options, verbosity)
for future in futures
fetch(future)
end
return debug(verbosity > 0, "Finished!")
verbosity > 0 && @info "Finished!"
return nothing
end

function test_entire_pipeline(
procs, dataset::Dataset{T}, options::Options, verbosity
) where {T<:DATA_TYPE}
futures = []
debug_inline(verbosity > 0, "Testing entire pipeline on workers...")
verbosity > 0 && @info "Testing entire pipeline on workers..."
for proc in procs
push!(
futures,
@@ -293,5 +302,40 @@ function test_entire_pipeline(
for future in futures
fetch(future)
end
return debug(verbosity > 0, "Finished!")
verbosity > 0 && @info "Finished!"
return nothing
end

function configure_workers(;
procs::Union{Vector{Int},Nothing},
numprocs::Int,
addprocs_function::Function,
options::Options,
project_path,
file,
exeflags::Cmd,
verbosity,
example_dataset::Dataset,
runtests::Bool,
)
(procs, we_created_procs) = if procs === nothing
(addprocs_function(numprocs; lazy=false, exeflags), true)
else
(procs, false)
end

if we_created_procs
activate_env_on_workers(procs, project_path, options, verbosity)
import_module_on_workers(procs, file, options, verbosity)
end
move_functions_to_workers(procs, options, example_dataset, verbosity)
if runtests
test_module_on_workers(procs, options, verbosity)
end

if runtests
test_entire_pipeline(procs, example_dataset, options, verbosity)
end

return (procs, we_created_procs)
end
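A sketch of how this new entry point might be driven (the actual call site is outside these hunks; options and dataset are assumed to be in scope, and the keyword values are illustrative):

    using Distributed

    procs, we_created_procs = configure_workers(;
        procs=nothing,                      # nothing => spawn workers here
        numprocs=4,
        addprocs_function=Distributed.addprocs,
        options,
        project_path=dirname(Base.active_project()),
        file=@__FILE__,
        exeflags=``,
        verbosity=1,
        example_dataset=dataset,
        runtests=true,
    )
    # Pass `procs` back in on a later call to reuse the same workers.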
13 changes: 9 additions & 4 deletions src/Dataset.jl
@@ -10,7 +10,7 @@ import DynamicQuantities:
DEFAULT_DIM_BASE_TYPE
import ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units

import ..UtilsModule: subscriptify
import ..UtilsModule: subscriptify, get_base_type
import ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
#! format: off
import ...deprecate_varmap
@@ -103,12 +103,12 @@ function Dataset(
display_variable_names=variable_names,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type{Linit}=Nothing,
loss_type::Type{L}=Nothing,
X_units::Union{AbstractVector,Nothing}=nothing,
y_units=nothing,
# Deprecated:
varMap=nothing,
) where {T<:DATA_TYPE,Linit}
) where {T<:DATA_TYPE,L}
Base.require_one_based_indexing(X)
y !== nothing && Base.require_one_based_indexing(y)
# Deprecation warning:
@@ -142,7 +142,12 @@
sum(y) / n
end
end
out_loss_type = (Linit === Nothing) ? T : Linit
out_loss_type = if L === Nothing
T <: Complex ? get_base_type(T) : T
else
L
end

use_baseline = true
baseline = one(out_loss_type)
y_si_units = get_si_units(T, y_units)
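The practical effect of the new default: complex-valued data now gets a real-valued loss type. A sketch of what the imported get_base_type presumably does (its definition lives in src/Utils.jl, outside this diff):

    # Illustrative definition only:
    get_base_type(::Type{Complex{BT}}) where {BT} = BT

    # With loss_type left at Nothing:
    #   T = Float64     -> out_loss_type = Float64
    #   T = ComplexF64  -> out_loss_type = Float64  (losses stay real-valued)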
11 changes: 4 additions & 7 deletions src/HallOfFame.jl
@@ -4,7 +4,7 @@ import DynamicExpressions: Node, string_tree
import ..UtilsModule: split_string
import ..CoreModule: MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu
import ..ComplexityModule: compute_complexity
import ..PopMemberModule: PopMember, copy_pop_member
import ..PopMemberModule: PopMember
import ..LossFunctionsModule: eval_loss
import ..InterfaceDynamicExpressionsModule: format_dimensions
using Printf: @sprintf
@@ -60,12 +60,9 @@ function HallOfFame(
)
end

function copy_hall_of_fame(
hof::HallOfFame{T,L}
)::HallOfFame{T,L} where {T<:DATA_TYPE,L<:LOSS_TYPE}
function Base.copy(hof::HallOfFame)
return HallOfFame(
[copy_pop_member(member) for member in hof.members],
[exists for exists in hof.exists],
[copy(member) for member in hof.members], [exists for exists in hof.exists]
)
end

@@ -98,7 +95,7 @@ function calculate_pareto_frontier(
end
end
if betterThanAllSmaller
push!(dominating, copy_pop_member(member))
push!(dominating, copy(member))
end
end
return dominating
7 changes: 7 additions & 0 deletions src/MLJInterface.jl
@@ -39,6 +39,7 @@ function modelexpr(model_name::Symbol)
numprocs::Union{Int,Nothing} = nothing
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing
runtests::Bool = true
loss_type::L = Nothing
selection_method::Function = choose_best
@@ -166,6 +167,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options)
numprocs=m.numprocs,
procs=m.procs,
addprocs_function=m.addprocs_function,
heap_size_hint_in_bytes=m.heap_size_hint_in_bytes,
runtests=m.runtests,
saved_state=(old_fitresult === nothing ? nothing : old_fitresult.state),
return_state=true,
@@ -490,6 +492,11 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
which is the number of processes to use, as well as the `lazy` keyword argument.
For example, if set up on a slurm cluster, you could pass
`addprocs_function = addprocs_slurm`, which will set up slurm processes.
- `heap_size_hint_in_bytes::Union{Int,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
flag on Julia processes, recommending garbage collection once a process
is close to the recommended size. This is important for long-running distributed
jobs where each process has an independent memory, and can help avoid
out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
- `runtests::Bool=true`: Whether to run (quick) tests before starting the
search, to see if there will be any problems during the equation search
related to the host environment.
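A hedged usage example for the new keyword (SRRegressor is the package's MLJ model constructor; it is not part of this hunk):

    using SymbolicRegression

    model = SRRegressor(;
        parallelism=:multiprocessing,
        numprocs=4,
        heap_size_hint_in_bytes=2^30,  # ask each worker to stay near 1 GiB
    )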
7 changes: 3 additions & 4 deletions src/Migration.jl
@@ -3,7 +3,7 @@ module MigrationModule
using StatsBase: StatsBase
import ..CoreModule: Options, DATA_TYPE, LOSS_TYPE
import ..PopulationModule: Population
import ..PopMemberModule: PopMember, copy_pop_member_reset_birth
import ..PopMemberModule: PopMember, reset_birth!
import ..UtilsModule: poisson_sample

"""
@@ -33,9 +33,8 @@ function migrate!(
migrants = StatsBase.sample(migrant_candidates, num_replace; replace=true)

for (i, migrant) in zip(locations, migrants)
base_pop.members[i] = copy_pop_member_reset_birth(
migrant; deterministic=options.deterministic
)
base_pop.members[i] = copy(migrant)
reset_birth!(base_pop.members[i]; options.deterministic)
end
return nothing
end
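One syntax note on the new call above: reset_birth!(x; options.deterministic) uses Julia's implicit keyword-argument shorthand (available since Julia 1.5) and is equivalent to:

    reset_birth!(base_pop.members[i]; deterministic=options.deterministic)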
15 changes: 5 additions & 10 deletions src/PopMember.jl
@@ -109,10 +109,8 @@ function PopMember(
)
end

function copy_pop_member(
p::PopMember{T,L}
)::PopMember{T,L} where {T<:DATA_TYPE,L<:LOSS_TYPE}
tree = copy_node(p.tree)
function Base.copy(p::PopMember{T,L})::PopMember{T,L} where {T<:DATA_TYPE,L<:LOSS_TYPE}
tree = copy(p.tree)
score = copy(p.score)
loss = copy(p.loss)
birth = copy(p.birth)
@@ -122,12 +120,9 @@ function copy_pop_member(
return PopMember{T,L}(tree, score, loss, birth, complexity, ref, parent)
end

function copy_pop_member_reset_birth(
p::PopMember{T,L}; deterministic::Bool
)::PopMember{T,L} where {T<:DATA_TYPE,L<:LOSS_TYPE}
new_member = copy_pop_member(p)
new_member.birth = get_birth_order(; deterministic=deterministic)
return new_member
function reset_birth!(p::PopMember; deterministic::Bool)
p.birth = get_birth_order(; deterministic)
return p
end

# Can read off complexity directly from pop members
Expand Down
6 changes: 3 additions & 3 deletions src/Population.jl
@@ -8,7 +8,7 @@ import ..ComplexityModule: compute_complexity
import ..LossFunctionsModule: score_func, update_baseline_loss!
import ..AdaptiveParsimonyModule: RunningSearchStatistics
import ..MutationFunctionsModule: gen_random_tree
import ..PopMemberModule: PopMember, copy_pop_member
import ..PopMemberModule: PopMember
import ..UtilsModule: bottomk_fast, argmin_fast
# A list of members of the population, with easy constructors,
# which allow for random generation of new populations
@@ -92,8 +92,8 @@ function Population(
)
end

function copy_population(pop::P)::P where {P<:Population}
return Population([copy_pop_member(pm) for pm in pop.members])
function Base.copy(pop::P)::P where {P<:Population}
return Population([copy(pm) for pm in pop.members])
end

# Sample random members of the population, and make a new one
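Taken together with the PopMember and HallOfFame hunks above, these overloads fold the bespoke copy_* helpers into the standard Base.copy generic, so call sites now read uniformly:

    copy(member)   # was copy_pop_member(member)
    copy(hof)      # was copy_hall_of_fame(hof)
    copy(pop)      # was copy_population(pop)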
1 change: 0 additions & 1 deletion src/RegularizedEvolution.jl
@@ -1,6 +1,5 @@
module RegularizedEvolutionModule

import Random: shuffle!
import DynamicExpressions: string_tree
import ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
import ..PopMemberModule: PopMember
(Diff truncated: the remaining 4 of the 12 changed files are not shown here.)
