Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dimensional constraints (v4) #228

Merged
merged 138 commits into from
Jul 21, 2023
Merged

Dimensional constraints (v4) #228

merged 138 commits into from
Jul 21, 2023

Conversation

MilesCranmer
Copy link
Owner

@MilesCranmer MilesCranmer commented Jul 6, 2023

Continuation of #220. This PR also merges the MLJ update and tries to integrate units into the interface somehow.

Here is the old PR message:

This adds what is probably the most-requested feature of all time: dimensional analysis.

https://en.wikipedia.org/wiki/Dimensional_analysis

This PR implements this by creating a Unitful.jl DynamicQuantities.jl extension which defines violates_dimensional_constraints. This function checks if a given equation is dimensionally consistent, given any choice of units for the free constants. (It does this by creating a "wildcard" unit that propagates through an equation until dimensionless input is needed).

By default the loss penalty for violating these constraints is 1000, but it is user configurable (Options.dimensional_constraint_penalty).

Example: Newton's law

using DynamicQuantities
using SymbolicRegression

M = 1e25 .* rand(100) .* u"kg"
m = 100 .* rand(100) .* u"kg"
r = 10_000 .* rand(100) .* u"km"
G = 6.6743e-11u"m^3 * kg^−1 * s^-2"

F = @. (G * M * m / r^2)

X = ustrip.(cat(M, m, r; dims=2))'
y = ustrip.(F)

X_units = dimension.(first.([M, m, r]))  # ["kg", "kg", "m"]
y_units = dimension(first(F))  # "m kg s^-2"

function loss_fnc(prediction, target)
    # Useful loss for large dynamic range
    scatter_loss = abs(log((abs(prediction)+1e-20) / (abs(target)+1e-20)))
    sign_loss = 10 * (sign(prediction) - sign(target))^2
    return scatter_loss + sign_loss
end

options = Options(;
    binary_operators=[+, -, *, /],
    unary_operators=[cos, exp],
    maxsize=30,
    elementwise_loss=loss_fnc,
    adaptive_parsimony_scaling=1000.0,
    npopulations=Threads.nthreads() * 2,
    complexity_of_constants=2,
)

equation_search(
    X,
    y;
    niterations=1000,
    options=options,
    variable_names=["M", "m", "r"],
    X_units=X_units,
    y_units=y_units,
)
pysr_demo_newton_2x_trimmed.mov

Right now the use of this is quite slow. I think this is due to the fact that Unitful.jl defines every unit as a different type, which causes the type inference to encounter a lot of instability when evaluating.

Fixed by switching to DynamicQuantities.jl


TODO:

  • Improve performance.
    • It seems like most of the slowness was from using a try-catch by default to check whether an operator could handle dimension-full input. Now I first check with hasmethod which is significantly faster. Perhaps I will just need to advise users who use dimensional constraints to avoid passing generic operators (unless they want their operators to manipulate dimensions).
    • Probably want to try this trick: https://discourse.julialang.org/t/performance-of-hasmethod-vs-try-catch-on-methoderror/99827/12
  • @nospecialize some functions. Because we can encounter arbitrary equations, we can encounter arbitrary quantity types! This makes multiple dispatch explode.
    • Is this hopeless? Maybe we need a fork of Unitful.jl that has a single generic Quantity{T}, with a value for the dimensions?
  • Switch to using DynamicQuantities.jl
  • Refactor score_func_batch so that batched evaluation can use it too.
  • Move DynamicQuantities.jl into main package (not an extension)
  • Print character indicating units in constants?
  • Scale X by the upreferred re-scalings. e.g., u"cm" should be scaled relative to u"m".
    • Just raising a warning instead
  • Require desired output units. (Need some API specification for this)
  • Update to DynamicQuantities.jl 0.5
  • Update to DynamicQuantities.jl 0.6
  • Make the printout read out the original units the user passes, as right now it is unclear.
    • Create SymbolicDimensions for DynamicQuantities.jl
  • Make the "[.]" not show up when no units are used.
  • New: get units working with MLJ
  • Print warning message that the default dimensional penalty is set. Let user know they can provide a custom value if they wish.
  • Enable easier conversion to SymbolicUtils (or make it obvious where the options can be found)
  • Add dimensional analysis section to docs
  • Print "y_1 [...] = " as well
  • Reach 100% diff coverage
  • Precompilation for MLJ interface
  • Variable level of precompilation (so can be avoided in PySR)

@MilesCranmer MilesCranmer force-pushed the units4 branch 2 times, most recently from 060accd to 0dedf17 Compare July 21, 2023 00:51
src/MLJInterface.jl Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant