Merge pull request #287 from MilesCranmer/compathelper/new_version/2024-02-04-00-09-49-895-01848451222

Update DynamicExpressions.jl
MilesCranmer authored Mar 9, 2024
2 parents 9438597 + 645141b commit 3250679
Showing 13 changed files with 138 additions and 103 deletions.
7 changes: 5 additions & 2 deletions Project.toml
@@ -36,11 +36,13 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"

[compat]
Aqua = "0.7"
Bumper = "0.6"
Compat = "^4.2"
DynamicExpressions = "0.15"
DynamicExpressions = "0.16"
DynamicQuantities = "0.10 - 0.12"
JSON3 = "1"
LineSearches = "7"
LoopVectorization = "0.12"
LossFunctions = "0.10, 0.11"
MLJModelInterface = "1.5, 1.6, 1.7, 1.8"
MacroTools = "0.4, 0.5"
@@ -58,6 +60,7 @@ julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -71,4 +74,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
50 changes: 34 additions & 16 deletions src/Configure.jl
@@ -203,26 +203,44 @@ function activate_env_on_workers(procs, project_path::String, options::Options,
end

function import_module_on_workers(procs, filename::String, options::Options, verbosity)
included_local = !("SymbolicRegression" in [k.name for (k, v) in Base.loaded_modules])
if included_local
verbosity > 0 && @info "Importing local module ($filename) on workers..."
@everywhere procs begin
# Parse functions on every worker node
Base.MainInclude.eval(
quote
include($$filename)
using .SymbolicRegression
end,
)
loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
expr = if included_as_local
quote
include($filename)
using .SymbolicRegression
end
verbosity > 0 && @info "Finished!"
else
verbosity > 0 && @info "Importing installed module on workers..."
@everywhere procs begin
Base.MainInclude.eval(:(using SymbolicRegression))
quote
using SymbolicRegression
end
verbosity > 0 && @info "Finished!"
end

# Need to import any extension code, if loaded on head node
relevant_extensions = [
:SymbolicUtils, :Bumper, :LoopVectorization, :Zygote, :CUDA, :Enzyme
]
filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions)
# HACK TODO – this workaround is very fragile. Likely need to submit a bug report
# to JuliaLang.

for ext in relevant_extensions
push!(
expr.args,
quote
using $ext: $ext
end,
)
end

verbosity > 0 && if isempty(relevant_extensions)
@info "Importing SymbolicRegression on workers."
else
@info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
end
@everywhere procs Base.MainInclude.eval($expr)
return verbosity > 0 && @info "Finished!"
end

function test_module_on_workers(
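
The rewritten `import_module_on_workers` above builds a single quoted expression, appends `using` statements for whichever extension packages are already loaded on the head node, and evaluates the result once per worker, rather than issuing a separate `@everywhere` block in each branch. A standalone sketch of that pattern using only `Distributed`; `worker_procs`, `candidate_extensions`, and the stand-in packages (`Statistics`, `LinearAlgebra`, `SparseArrays`) are placeholders, not SymbolicRegression internals.

```julia
using Distributed

worker_procs = addprocs(2)

# What is loaded on the head node right now:
loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

# Base expression every worker should run:
expr = quote
    using Statistics  # stand-in for `using SymbolicRegression`
end

# Append `using` statements only for packages the head node has loaded:
candidate_extensions = [:LinearAlgebra, :SparseArrays]
for ext in filter(m -> String(m) in loaded_modules_head_worker, candidate_extensions)
    push!(expr.args, :(using $ext: $ext))
end

# Evaluate the assembled expression once on each worker:
@everywhere worker_procs Base.MainInclude.eval($expr)
```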
91 changes: 34 additions & 57 deletions src/ConstantOptimization.jl
@@ -2,30 +2,12 @@ module ConstantOptimizationModule

using LineSearches: LineSearches
using Optim: Optim
using DynamicExpressions: count_constants
using DynamicExpressions: Node, count_constants, get_constant_refs
using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
using ..UtilsModule: get_birth_order
using ..LossFunctionsModule: score_func, eval_loss, batch_sample
using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample
using ..PopMemberModule: PopMember

# Proxy function for optimization
@inline function opt_func(
x, dataset::Dataset{T,L}, tree, constant_nodes, options, idx
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
_set_constants!(x, constant_nodes)
# TODO(mcranmer): This should use score_func batching.
loss = eval_loss(tree, dataset, options; regularization=false, idx=idx)
return loss::L
end

function _set_constants!(x::AbstractArray{T}, constant_nodes) where {T}
for (xi, node) in zip(x, constant_nodes)
node.val::T = xi
end
return nothing
end

# Use Nelder-Mead to optimize the constants in an equation
function optimize_constants(
dataset::Dataset{T,L}, member::P, options::Options
)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
@@ -42,62 +24,57 @@ function dispatch_optimize_constants(
) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
nconst = count_constants(member.tree)
nconst == 0 && return (member, 0.0)
if T <: Complex
# TODO: Make this more general. Also, do we even need Newton here at all??
algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())#order=3))
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
elseif nconst == 1
if nconst == 1 && !(T <: Complex)
algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking())
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
else
if options.optimizer_algorithm == "NelderMead"
algorithm = Optim.NelderMead(; linesearch=LineSearches.BackTracking())
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
elseif options.optimizer_algorithm == "BFGS"
algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())#order=3))
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
else
error("Optimization function not implemented.")
end
end
return _optimize_constants(
dataset,
member,
options,
options.optimizer_algorithm,
options.optimizer_options,
idx,
)
end

function _optimize_constants(
dataset, member::P, options, algorithm, optimizer_options, idx
)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}}
tree = member.tree
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
x0 = [n.val::T for n in constant_nodes]
f(x) = opt_func(x, dataset, tree, constant_nodes, options, idx)
result = Optim.optimize(f, x0, algorithm, optimizer_options)
num_evals = 0.0
num_evals += result.f_calls
eval_fraction = options.batching ? (options.batch_size / dataset.n) : 1.0
f(t) = eval_loss(t, dataset, options; regularization=false, idx=idx)::L
baseline = f(tree)
result = Optim.optimize(f, tree, algorithm, optimizer_options)
num_evals = result.f_calls * eval_fraction
# Try other initial conditions:
for i in 1:(options.optimizer_nrestarts)
new_start = x0 .* (T(1) .+ T(1//2) * randn(T, size(x0, 1)))
tmpresult = Optim.optimize(f, new_start, algorithm, optimizer_options)
num_evals += tmpresult.f_calls
for _ in 1:(options.optimizer_nrestarts)
tmptree = copy(tree)
foreach(tmptree) do node
if node.degree == 0 && node.constant
node.val = (node.val) * (T(1) + T(1//2) * randn(T))
end
end
tmpresult = Optim.optimize(
f, tmptree, algorithm, optimizer_options; make_copy=false
)
num_evals += tmpresult.f_calls * eval_fraction

if tmpresult.minimum < result.minimum
result = tmpresult
end
end

if Optim.converged(result)
_set_constants!(result.minimizer, constant_nodes)
member.score, member.loss = score_func(dataset, member, options)
num_evals += 1
if result.minimum < baseline
member.tree = result.minimizer
member.loss = eval_loss(member.tree, dataset, options; regularization=true, idx=idx)
member.score = loss_to_score(
member.loss, dataset.use_baseline, dataset.baseline_loss, member, options
)
member.birth = get_birth_order(; deterministic=options.deterministic)
else
_set_constants!(x0, constant_nodes)
num_evals += eval_fraction
end

return member, num_evals
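
In the rewritten `_optimize_constants` above, `Optim.optimize` is called on the expression tree itself (presumably via a DynamicExpressions 0.16 / Optim integration), evaluation counts are scaled by the batch fraction, and each restart multiplies every constant by `1 + 0.5 * randn()`. The restart strategy in isolation, sketched against a plain parameter vector rather than a tree; `optimize_with_restarts` and the Rosenbrock stand-in for `eval_loss` are hypothetical, while the Optim and LineSearches calls are existing API.

```julia
using Optim, LineSearches

rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2  # stand-in for eval_loss

function optimize_with_restarts(f, x0, algorithm; nrestarts=3)
    best = Optim.optimize(f, x0, algorithm)
    for _ in 1:nrestarts
        # Same multiplicative jitter the PR applies to each constant node:
        perturbed = x0 .* (1 .+ 0.5 .* randn(length(x0)))
        trial = Optim.optimize(f, perturbed, algorithm)
        if Optim.minimum(trial) < Optim.minimum(best)
            best = trial
        end
    end
    return best
end

algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())
result = optimize_with_restarts(rosenbrock, [0.3, 0.8], algorithm)
Optim.minimizer(result), Optim.minimum(result)
```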
2 changes: 1 addition & 1 deletion src/DimensionalAnalysis.jl
@@ -119,7 +119,7 @@ end
@inline function deg0_eval(
x::AbstractVector{T}, x_units::Vector{Q}, t::AbstractExpressionNode{T}
) where {T,R,Q<:AbstractQuantity{T,R}}
t.constant && return WildcardQuantity{Q}(Quantity(t.val::T, R), true, false)
t.constant && return WildcardQuantity{Q}(Quantity(t.val, R), true, false)
return WildcardQuantity{Q}(
(@inbounds x[t.feature]) * (@inbounds x_units[t.feature]), false, false
)
16 changes: 9 additions & 7 deletions src/InterfaceDynamicExpressions.jl
@@ -3,6 +3,7 @@ module InterfaceDynamicExpressionsModule
using Printf: @sprintf
using DynamicExpressions: DynamicExpressions
using DynamicExpressions: OperatorEnum, GenericOperatorEnum, AbstractExpressionNode
using DynamicExpressions.StringsModule: needs_brackets
using DynamicQuantities: dimension, ustrip
using ..CoreModule: Options
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
@@ -54,7 +55,9 @@ which speed up evaluation significantly.
function eval_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
return eval_tree_array(tree, X, options.operators; turbo=options.turbo, kws...)
return eval_tree_array(
tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
)
end

"""
@@ -155,8 +158,8 @@ Convert an equation to a string.
return string_tree(
tree,
options.operators;
f_variable=(feature, vnames) -> string_variable(feature, vnames, X_sym_units),
f_constant=val -> string_constant(val, vprecision, "[⋅]"),
f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
f_constant=(val,) -> string_constant(val, vprecision, "[⋅]"),
variable_names=display_variable_names,
kws...,
)
@@ -165,7 +168,7 @@ Convert an equation to a string.
tree,
options.operators;
f_variable=string_variable,
f_constant=val -> string_constant(val, vprecision, ""),
f_constant=(val,) -> string_constant(val, vprecision, ""),
variable_names=display_variable_names,
kws...,
)
@@ -191,8 +194,7 @@ function string_variable(feature, variable_names, variable_units=nothing)
return base
end
function string_constant(val, ::Val{precision}, unit_placeholder) where {precision}
does_not_need_brackets = typeof(val) <: Real
if does_not_need_brackets
if typeof(val) <: Real
return sprint_precision(val, Val(precision)) * unit_placeholder
else
return "(" * string(val) * ")" * unit_placeholder
@@ -283,7 +285,7 @@ function define_alias_operators(operators)
end

function (tree::AbstractExpressionNode)(X, options::Options; kws...)
return tree(X, options.operators; turbo=options.turbo, kws...)
return tree(X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...)
end
function DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
tree::AbstractExpressionNode, X, options::Options; kws...
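
The two overloads patched above (`eval_tree_array` and the tree call syntax) now forward `options.bumper` alongside `options.turbo`. A short usage sketch: `Node`, `eval_tree_array`, and the call overload are existing API, while the `bumper` keyword on `Options` is an assumption taken from this diff.

```julia
using SymbolicRegression

options = Options(; binary_operators=[+, *], unary_operators=[cos], bumper=true)  # `bumper` assumed

x1 = Node{Float64}(; feature=1)
tree = cos(x1 * 3.2) + 0.5  # operators are overloaded for Node once Options exists

X = randn(Float64, 1, 32)

output, completed = eval_tree_array(tree, X, options)  # forwards turbo/bumper internally
output2 = tree(X, options)                             # call overload patched above
```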
1 change: 1 addition & 0 deletions src/MLJInterface.jl
@@ -1,6 +1,7 @@
module MLJInterfaceModule

using Optim: Optim
using LineSearches: LineSearches
using MLJModelInterface: MLJModelInterface as MMI
using DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node
using DynamicQuantities:
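
The MLJ interface now imports `Optim` and `LineSearches` directly, consistent with `dispatch_optimize_constants` above passing `options.optimizer_algorithm` straight to the optimizer. Presumably this lets `optimizer_algorithm` be given as an `Optim` algorithm object rather than only a string; a hedged sketch in which that keyword value is an assumption, while `SRRegressor`, the other keywords, and the MLJ calls are existing API.

```julia
using MLJ, SymbolicRegression, Optim, LineSearches

model = SRRegressor(;
    binary_operators=[+, -, *],
    unary_operators=[exp],
    niterations=20,
    # Assumed: an Optim optimizer object in place of the old "BFGS" string.
    optimizer_algorithm=Optim.BFGS(; linesearch=LineSearches.BackTracking()),
)

X = (; x1=randn(100), x2=randn(100))
y = exp.(X.x1) .- 2 .* X.x2

mach = machine(model, X, y)
fit!(mach)
report(mach)  # inspect the discovered equations
```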
6 changes: 3 additions & 3 deletions src/MutationFunctions.jl
@@ -68,13 +68,13 @@ function mutate_constant(
makeConstBigger = rand(Bool)

if makeConstBigger
node.val::T *= factor
node.val *= factor
else
node.val::T /= factor
node.val /= factor
end

if rand() > options.probability_negate_constant
node.val::T *= -1
node.val *= -1
end

return tree