Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use functions returning NaN on branch cuts instead of abs (issue #109) #123

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import .OperatorsModule:
square,
cube,
pow,
pow_abs,
safe_pow,
div,
log_abs,
log2_abs,
log10_abs,
log1p_abs,
sqrt_abs,
acosh_abs,
safe_log,
safe_log2,
safe_log10,
safe_log1p,
safe_sqrt,
safe_acosh,
neg,
greater,
greater,
Expand Down
26 changes: 21 additions & 5 deletions src/Equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,25 +201,40 @@ function copy_node(tree::Node{T})::Node{T} where {T}
end
end

const OP_NAMES = Dict(
"safe_log" => "log",
"safe_log2" => "log2",
"safe_log10" => "log10",
"safe_log1p" => "log1p",
"safe_acosh" => "acosh",
"safe_sqrt" => "sqrt",
"safe_pow" => "^",
)

function get_op_name(op::String)
return get(OP_NAMES, op, op)
end

function string_op(
op::F,
tree::Node,
options::Options;
bracketed::Bool=false,
varMap::Union{Array{String,1},Nothing}=nothing,
)::String where {F}
if op in [+, -, *, /, ^]
op_name = get_op_name(string(op))
if op_name in ["+", "-", "*", "/", "^"]
l = string_tree(tree.l, options; bracketed=false, varMap=varMap)
r = string_tree(tree.r, options; bracketed=false, varMap=varMap)
if bracketed
return "$l $(string(op)) $r"
return "$l $op_name $r"
else
return "($l $(string(op)) $r)"
return "($l $op_name $r)"
end
else
l = string_tree(tree.l, options; bracketed=true, varMap=varMap)
r = string_tree(tree.r, options; bracketed=true, varMap=varMap)
return "$(string(op))($l, $r)"
return "$op_name($l, $r)"
end
end

Expand Down Expand Up @@ -250,7 +265,8 @@ function string_tree(
end
end
elseif tree.degree == 1
return "$(options.unaops[tree.op])($(string_tree(tree.l, options, bracketed=true, varMap=varMap)))"
op_name = get_op_name(string(options.unaops[tree.op]))
return "$(op_name)($(string_tree(tree.l, options, bracketed=true, varMap=varMap)))"
else
return string_op(
options.binops[tree.op], tree, options; bracketed=bracketed, varMap=varMap
Expand Down
57 changes: 35 additions & 22 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,41 @@ end
function cube(x::T)::T where {T<:Real}
return x^3
end
function pow_abs(x::T, y::T)::T where {T<:Real}
return abs(x)^y
function safe_pow(x::T, y::T)::T where {T<:Real}
if isinteger(y)
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
y < T(0) && x == T(0) && return T(NaN)
else
y > T(0) && x < T(0) && return T(NaN)
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
y < T(0) && x <= T(0) && return T(NaN)
end
return x^y
end
function div(x::T, y::T)::T where {T<:Real}
return x / y
end
function log_abs(x::T)::T where {T<:Real}
return log(abs(x) + convert(T, 1//100000000))
function safe_log(x::T)::T where {T<:Real}
x <= T(0) && return T(NaN)
return log(x)
end
function log2_abs(x::T)::T where {T<:Real}
return log2(abs(x) + convert(T, 1//100000000))
function safe_log2(x::T)::T where {T<:Real}
x <= T(0) && return T(NaN)
return log2(x)
end
function log10_abs(x::T)::T where {T<:Real}
return log10(abs(x) + convert(T, 1//100000000))
function safe_log10(x::T)::T where {T<:Real}
x <= T(0) && return T(NaN)
return log10(x)
end
function log1p_abs(x::T)::T where {T<:Real}
return log(abs(x) + convert(T, 1))
function safe_log1p(x::T)::T where {T<:Real}
x <= T(-1) && return T(NaN)
return log1p(x)
end
function acosh_abs(x::T)::T where {T<:Real}
return acosh(abs(x) + convert(T, 1))
function safe_acosh(x::T)::T where {T<:Real}
x < T(1) && return T(NaN)
return acosh(x)
end
function safe_sqrt(x::T)::T where {T<:Real}
x < T(0) && return T(NaN)
return sqrt(x)
end

# Generics:
Expand All @@ -65,17 +80,15 @@ cube(x) = x * x * x
plus(x, y) = x + y
sub(x, y) = x - y
mult(x, y) = x * y
pow_abs(x, y) = abs(x)^y
safe_pow(x, y) = x^y
div(x, y) = x / y
log_abs(x) = log(abs(x) + 1//100000000)
log2_abs(x) = log2(abs(x) + 1//100000000)
log10_abs(x) = log10(abs(x) + 1//100000000)
log1p_abs(x) = log(abs(x) + 1)
acosh_abs(x) = acosh(abs(x) + 1)
safe_log(x) = log(x)
safe_log2(x) = log2(x)
safe_log10(x) = log10(x)
safe_log1p(x) = log1p(x)
safe_acosh(x) = acosh(x)
safe_sqrt(x) = sqrt(x)

function sqrt_abs(x::T)::T where {T}
return sqrt(abs(x))
end
function neg(x::T)::T where {T}
return -x
end
Expand All @@ -100,6 +113,6 @@ function logical_and(x::T, y::T)::T where {T}
end

# Deprecated operations:
@deprecate pow pow_abs
@deprecate pow safe_pow

end
42 changes: 21 additions & 21 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import Zygote: gradient
import ..OperatorsModule:
plus,
pow,
pow_abs,
safe_pow,
mult,
sub,
div,
log_abs,
log10_abs,
log2_abs,
log1p_abs,
sqrt_abs,
acosh_abs,
safe_log,
safe_log10,
safe_log2,
safe_log1p,
safe_sqrt,
safe_acosh,
atanh_clip
import ..EquationModule: Node, string_tree
import ..OptionsStructModule: Options, ComplexityMapping
Expand Down Expand Up @@ -93,26 +93,26 @@ function binopmap(op)
elseif op == div
return /
elseif op == ^
return pow_abs
return safe_pow
elseif op == pow
return pow_abs
return safe_pow
end
return op
end

function unaopmap(op)
if op == log
return log_abs
return safe_log
elseif op == log10
return log10_abs
return safe_log10
elseif op == log2
return log2_abs
return safe_log2
elseif op == log1p
return log1p_abs
return safe_log1p
elseif op == sqrt
return sqrt_abs
return safe_sqrt
elseif op == acosh
return acosh_abs
return safe_acosh
elseif op == atanh
return atanh_clip
end
Expand All @@ -127,11 +127,11 @@ The current arguments have been tuned using the median values from
https://github.com/MilesCranmer/PySR/discussions/115.

# Arguments
- `binary_operators`: Tuple of binary
operators to use. Each operator should be defined for two input scalars,
and one output scalar. All operators need to be defined over the entire
real line (excluding infinity - these are stopped before they are input).
Thus, `log` should be replaced with `log_abs`, etc.
- `binary_operators`: Tuple of binary operators to use. Each operator should
be defined for two input scalars, and one output scalar. All operators
need to be defined over the entire real line (excluding infinity - these
are stopped before they are input), or return `NaN` where not defined.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for adding this!

Thus, `log` should be replaced with `safe_log`, etc.
For speed, define it so it takes two reals
of the same type as input, and outputs the same type. For the SymbolicUtils
simplification backend, you will need to define a generic method of the
Expand Down Expand Up @@ -532,7 +532,7 @@ function Options(;
end

for (op, f) in enumerate(map(Symbol, binary_operators))
_f = if f in [Symbol(pow), Symbol(pow_abs)]
_f = if f in [Symbol(pow), Symbol(safe_pow)]
Symbol(^)
else
f
Expand Down
28 changes: 14 additions & 14 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ export Population,
square,
cube,
pow,
pow_abs,
safe_pow,
div,
log_abs,
log2_abs,
log10_abs,
log1p_abs,
acosh_abs,
sqrt_abs,
safe_log,
safe_log2,
safe_log10,
safe_log1p,
safe_acosh,
safe_sqrt,
neg,
greater,
relu,
Expand Down Expand Up @@ -123,14 +123,14 @@ import .CoreModule:
square,
cube,
pow,
pow_abs,
safe_pow,
div,
log_abs,
log2_abs,
log10_abs,
log1p_abs,
sqrt_abs,
acosh_abs,
safe_log,
safe_log2,
safe_log10,
safe_log1p,
safe_sqrt,
safe_acosh,
neg,
greater,
greater,
Expand Down
31 changes: 21 additions & 10 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ using SymbolicRegression:
mult,
square,
cube,
safe_pow,
div,
log_abs,
log2_abs,
log10_abs,
sqrt_abs,
acosh_abs,
safe_log,
safe_log2,
safe_log10,
safe_sqrt,
safe_acosh,
neg,
greater,
greater,
Expand All @@ -26,17 +27,27 @@ types_to_test = [Float16, Float32, Float64, BigFloat]
for T in types_to_test
val = T(0.5)
val2 = T(3.2)
@test sqrt_abs(val) == sqrt_abs(-val)
@test abs(log_abs(-val) - log(val)) < 1e-6
@test abs(log2_abs(-val) - log2(val)) < 1e-6
@test abs(log10_abs(-val) - log10(val)) < 1e-6
@test abs(safe_log(val) - log(val)) < 1e-6
@test isnan(safe_log(-val))
@test abs(safe_log2(val) - log2(val)) < 1e-6
@test isnan(safe_log2(-val))
@test abs(safe_log10(val) - log10(val)) < 1e-6
@test isnan(safe_log10(-val))
@test abs(safe_acosh(val2) - acosh(val2)) < 1e-6
@test isnan(safe_acosh(-val2))
@test neg(-val) == val
@test sqrt_abs(val) == sqrt(val)
@test safe_sqrt(val) == sqrt(val)
@test isnan(safe_sqrt(-val))
@test mult(val, val2) == val * val2
@test plus(val, val2) == val + val2
@test sub(val, val2) == val - val2
@test square(val) == val * val
@test cube(val) == val * val * val
@test isnan(safe_pow(T(0.0), -T(1.0)))
@test isnan(safe_pow(-val, val2))
@test all(isnan.([safe_pow(-val, -val2), safe_pow(T(0.0), -val2)]))
@test abs(safe_pow(val, val2) - val^val2) < 1e-6
@test abs(safe_pow(val, -val2) - val^(-val2)) < 1e-6
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
@test div(val, val2) == val / val2
@test greater(val, val2) == T(0.0)
@test greater(val2, val) == T(1.0)
Expand Down
16 changes: 16 additions & 0 deletions test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,19 @@ EquationSearch(
s = repr(tree)
true_s = "((sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0)"
@test s == true_s

for unaop in [safe_log, safe_log2, safe_log10, safe_log1p, safe_sqrt, safe_acosh]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for adding!

opts = Options(;
default_params..., binary_operators=(+, *, /, -), unary_operators=(unaop,)
)
minitree = Node(1, Node("x1"))
@test string_tree(minitree, opts) == replace(string(unaop), "safe_" => "") * "(x1)"
end

for binop in [safe_pow, ^]
opts = Options(;
default_params..., binary_operators=(+, *, /, -, binop), unary_operators=(cos,)
)
minitree = Node(5, Node("x1"), Node("x2"))
@test string_tree(minitree, opts) == "(x1 ^ x2)"
end
Loading