Update DynamicExpressions.jl #287

Merged
Project.toml (7 changes: 5 additions & 2 deletions)
@@ -36,11 +36,13 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"

[compat]
Aqua = "0.7"
Bumper = "0.6"
Compat = "^4.2"
DynamicExpressions = "0.15"
DynamicExpressions = "0.16"
DynamicQuantities = "0.10 - 0.12"
JSON3 = "1"
LineSearches = "7"
LoopVectorization = "0.12"
LossFunctions = "0.10, 0.11"
MLJModelInterface = "1.5, 1.6, 1.7, 1.8"
MacroTools = "0.4, 0.5"
@@ -58,6 +60,7 @@ julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -71,4 +74,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
src/Configure.jl (50 changes: 34 additions & 16 deletions)
@@ -203,26 +203,44 @@ function activate_env_on_workers(procs, project_path::String, options::Options,
end

function import_module_on_workers(procs, filename::String, options::Options, verbosity)
included_local = !("SymbolicRegression" in [k.name for (k, v) in Base.loaded_modules])
if included_local
verbosity > 0 && @info "Importing local module ($filename) on workers..."
@everywhere procs begin
# Parse functions on every worker node
Base.MainInclude.eval(
quote
include($$filename)
using .SymbolicRegression
end,
)
loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
expr = if included_as_local
quote
include($filename)
using .SymbolicRegression
end
verbosity > 0 && @info "Finished!"
else
verbosity > 0 && @info "Importing installed module on workers..."
@everywhere procs begin
Base.MainInclude.eval(:(using SymbolicRegression))
quote
using SymbolicRegression
end
verbosity > 0 && @info "Finished!"
end

# Need to import any extension code, if loaded on head node
relevant_extensions = [
:SymbolicUtils, :Bumper, :LoopVectorization, :Zygote, :CUDA, :Enzyme
]
filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions)
# HACK TODO – this workaround is very fragile. Likely need to submit a bug report
# to JuliaLang.

for ext in relevant_extensions
push!(
expr.args,
quote
using $ext: $ext
end,
)
end

verbosity > 0 && if isempty(relevant_extensions)
@info "Importing SymbolicRegression on workers."
else
@info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
end
@everywhere procs Base.MainInclude.eval($expr)
return verbosity > 0 && @info "Finished!"
end

function test_module_on_workers(
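Not from the PR, but a minimal standalone version of the pattern above: the refactored `import_module_on_workers` builds one quoted expression, appends a `using` statement for each extension loaded on the head node, and evaluates the result once per worker.

```julia
# Sketch of the pattern above: grow a quoted block, then evaluate it
# inside Main on every worker via Base.MainInclude.eval.
using Distributed

procs = addprocs(2)
expr = quote
    using SymbolicRegression
end
for ext in [:LoopVectorization, :Bumper]  # placeholders for detected extensions
    push!(expr.args, :(using $ext: $ext))
end
@everywhere procs Base.MainInclude.eval($expr)
```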
src/ConstantOptimization.jl (91 changes: 34 additions & 57 deletions)
@@ -2,30 +2,12 @@ module ConstantOptimizationModule

using LineSearches: LineSearches
using Optim: Optim
using DynamicExpressions: count_constants
using DynamicExpressions: Node, count_constants, get_constant_refs
using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
using ..UtilsModule: get_birth_order
using ..LossFunctionsModule: score_func, eval_loss, batch_sample
using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample
using ..PopMemberModule: PopMember

# Proxy function for optimization
@inline function opt_func(
x, dataset::Dataset{T,L}, tree, constant_nodes, options, idx
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
_set_constants!(x, constant_nodes)
# TODO(mcranmer): This should use score_func batching.
loss = eval_loss(tree, dataset, options; regularization=false, idx=idx)
return loss::L
end

function _set_constants!(x::AbstractArray{T}, constant_nodes) where {T}
for (xi, node) in zip(x, constant_nodes)
node.val::T = xi
end
return nothing
end

# Use Nelder-Mead to optimize the constants in an equation
function optimize_constants(
dataset::Dataset{T,L}, member::P, options::Options
)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
@@ -42,62 +24,57 @@ function dispatch_optimize_constants(
) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
nconst = count_constants(member.tree)
nconst == 0 && return (member, 0.0)
if T <: Complex
# TODO: Make this more general. Also, do we even need Newton here at all??
algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())#order=3))
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
elseif nconst == 1
if nconst == 1 && !(T <: Complex)
algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking())
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
else
if options.optimizer_algorithm == "NelderMead"
algorithm = Optim.NelderMead(; linesearch=LineSearches.BackTracking())
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
elseif options.optimizer_algorithm == "BFGS"
algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())#order=3))
return _optimize_constants(
dataset, member, options, algorithm, options.optimizer_options, idx
)
else
error("Optimization function not implemented.")
end
end
return _optimize_constants(
dataset,
member,
options,
options.optimizer_algorithm,
options.optimizer_options,
idx,
)
end

function _optimize_constants(
dataset, member::P, options, algorithm, optimizer_options, idx
)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}}
tree = member.tree
constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
x0 = [n.val::T for n in constant_nodes]
f(x) = opt_func(x, dataset, tree, constant_nodes, options, idx)
result = Optim.optimize(f, x0, algorithm, optimizer_options)
num_evals = 0.0
num_evals += result.f_calls
eval_fraction = options.batching ? (options.batch_size / dataset.n) : 1.0
f(t) = eval_loss(t, dataset, options; regularization=false, idx=idx)::L
baseline = f(tree)
result = Optim.optimize(f, tree, algorithm, optimizer_options)
num_evals = result.f_calls * eval_fraction
# Try other initial conditions:
for i in 1:(options.optimizer_nrestarts)
new_start = x0 .* (T(1) .+ T(1//2) * randn(T, size(x0, 1)))
tmpresult = Optim.optimize(f, new_start, algorithm, optimizer_options)
num_evals += tmpresult.f_calls
for _ in 1:(options.optimizer_nrestarts)
tmptree = copy(tree)
foreach(tmptree) do node
if node.degree == 0 && node.constant
node.val = (node.val) * (T(1) + T(1//2) * randn(T))
end
end
tmpresult = Optim.optimize(
f, tmptree, algorithm, optimizer_options; make_copy=false
)
num_evals += tmpresult.f_calls * eval_fraction

if tmpresult.minimum < result.minimum
result = tmpresult
end
end

if Optim.converged(result)
_set_constants!(result.minimizer, constant_nodes)
member.score, member.loss = score_func(dataset, member, options)
num_evals += 1
if result.minimum < baseline
member.tree = result.minimizer
member.loss = eval_loss(member.tree, dataset, options; regularization=true, idx=idx)
member.score = loss_to_score(
member.loss, dataset.use_baseline, dataset.baseline_loss, member, options
)
member.birth = get_birth_order(; deterministic=options.deterministic)
else
_set_constants!(x0, constant_nodes)
num_evals += eval_fraction
end

return member, num_evals
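As an illustration (not from the PR): the rewritten `_optimize_constants` depends on `Optim.optimize` accepting an expression tree directly, with the optimized tree returned as `result.minimizer` (provided by DynamicExpressions 0.16, per this diff). A self-contained sketch of that usage, with made-up operators and data:

```julia
# Sketch: optimize a tree's constants directly, mirroring the API used above.
using DynamicExpressions, Optim

operators = OperatorEnum(; binary_operators=[+, *])
x1 = Node{Float64}(; feature=1)
tree = Node(1, Node(; val=0.5), x1)  # operator 1 is `+`, so this is 0.5 + x1

X = randn(Float64, 1, 32)
y = X[1, :] .+ 2.0  # target: 2.0 + x1
f(t) = sum(abs2, first(eval_tree_array(t, X, operators)) .- y)

result = Optim.optimize(f, tree, Optim.BFGS(), Optim.Options())
result.minimizer  # a tree whose constant has moved toward 2.0
```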
src/DimensionalAnalysis.jl (2 changes: 1 addition & 1 deletion)
@@ -119,7 +119,7 @@ end
@inline function deg0_eval(
x::AbstractVector{T}, x_units::Vector{Q}, t::AbstractExpressionNode{T}
) where {T,R,Q<:AbstractQuantity{T,R}}
t.constant && return WildcardQuantity{Q}(Quantity(t.val::T, R), true, false)
t.constant && return WildcardQuantity{Q}(Quantity(t.val, R), true, false)
return WildcardQuantity{Q}(
(@inbounds x[t.feature]) * (@inbounds x_units[t.feature]), false, false
)
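An aside on the change above: dropping the `::T` assertion here, and likewise in src/MutationFunctions.jl below, is consistent with DynamicExpressions 0.16 storing `Node{T}.val` as a concrete `T` rather than a union with `nothing`; under that assumption the assertion adds nothing. A tiny sketch:

```julia
# Sketch (assumes `val` is concretely typed in DynamicExpressions 0.16):
using DynamicExpressions

node = Node{Float64}(; val=1.5)
node.val *= 2.0  # type-stable without a `node.val::Float64` assertion
```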
src/InterfaceDynamicExpressions.jl (16 changes: 9 additions & 7 deletions)
Expand Up @@ -3,6 +3,7 @@ module InterfaceDynamicExpressionsModule
using Printf: @sprintf
using DynamicExpressions: DynamicExpressions
using DynamicExpressions: OperatorEnum, GenericOperatorEnum, AbstractExpressionNode
using DynamicExpressions.StringsModule: needs_brackets
using DynamicQuantities: dimension, ustrip
using ..CoreModule: Options
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
@@ -54,7 +55,9 @@ which speed up evaluation significantly.
function eval_tree_array(
tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
return eval_tree_array(tree, X, options.operators; turbo=options.turbo, kws...)
return eval_tree_array(
tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
)
end

"""
@@ -155,8 +158,8 @@ Convert an equation to a string.
return string_tree(
tree,
options.operators;
f_variable=(feature, vnames) -> string_variable(feature, vnames, X_sym_units),
f_constant=val -> string_constant(val, vprecision, "[⋅]"),
f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
f_constant=(val,) -> string_constant(val, vprecision, "[⋅]"),
variable_names=display_variable_names,
kws...,
)
@@ -165,7 +168,7 @@
tree,
options.operators;
f_variable=string_variable,
f_constant=val -> string_constant(val, vprecision, ""),
f_constant=(val,) -> string_constant(val, vprecision, ""),
variable_names=display_variable_names,
kws...,
)
@@ -191,8 +194,7 @@ function string_variable(feature, variable_names, variable_units=nothing)
return base
end
function string_constant(val, ::Val{precision}, unit_placeholder) where {precision}
does_not_need_brackets = typeof(val) <: Real
if does_not_need_brackets
if typeof(val) <: Real
return sprint_precision(val, Val(precision)) * unit_placeholder
else
return "(" * string(val) * ")" * unit_placeholder
@@ -283,7 +285,7 @@ function define_alias_operators(operators)
end

function (tree::AbstractExpressionNode)(X, options::Options; kws...)
return tree(X, options.operators; turbo=options.turbo, kws...)
return tree(X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...)
end
function DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
tree::AbstractExpressionNode, X, options::Options; kws...
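For illustration only: with these changes, both the `eval_tree_array` wrapper and the call overload forward `options.bumper` alongside `options.turbo`. A short sketch of the two call paths, assuming the `bumper` keyword added in this PR and Bumper.jl loaded so the extension is active:

```julia
# Sketch: Options-level turbo/bumper flags now reach tree evaluation.
using Bumper, SymbolicRegression

options = Options(; binary_operators=[+, *], bumper=true)
tree = Node(1, Node(; val=0.5), Node{Float64}(; feature=1))  # 0.5 + x1
X = randn(Float64, 2, 16)

y, completed = eval_tree_array(tree, X, options)  # wrapper form
y2 = tree(X, options)                             # functor form, same flags
```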
src/MLJInterface.jl (1 change: 1 addition & 0 deletions)
@@ -1,6 +1,7 @@
module MLJInterfaceModule

using Optim: Optim
using LineSearches: LineSearches
using MLJModelInterface: MLJModelInterface as MMI
using DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node
using DynamicQuantities:
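A hedged sketch, not part of the diff: these imports match `dispatch_optimize_constants` above now passing `options.optimizer_algorithm` straight to `Optim.optimize`, which suggests the option holds an Optim algorithm object rather than a string name. Configuring that through the MLJ interface might look like:

```julia
# Sketch (assumption: `optimizer_algorithm` now accepts an Optim.jl object
# rather than a string such as "BFGS"):
using SymbolicRegression
using Optim: Optim
using LineSearches: LineSearches

model = SRRegressor(;
    optimizer_algorithm=Optim.BFGS(; linesearch=LineSearches.BackTracking()),
)
```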
src/MutationFunctions.jl (6 changes: 3 additions & 3 deletions)
@@ -68,13 +68,13 @@ function mutate_constant(
makeConstBigger = rand(Bool)

if makeConstBigger
node.val::T *= factor
node.val *= factor
else
node.val::T /= factor
node.val /= factor
end

if rand() > options.probability_negate_constant
node.val::T *= -1
node.val *= -1
end

return tree