Skip to content

Commit

Permalink
Clean up units integration
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 21, 2023
1 parent a07d3f6 commit 0dedf17
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 23 deletions.
9 changes: 4 additions & 5 deletions src/Dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ end
weights::Union{AbstractVector{T}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
X_units::Union{AbstractVector, Nothing}=nothing,
y_units=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type=Nothing,
X_units::Union{AbstractVector, Nothing}=nothing,
y_units=nothing,
)
Construct a dataset to pass between internal functions.
Expand All @@ -116,12 +116,12 @@ function Dataset(
nfeatures = size(X, FEATURE_DIM)
weighted = weights !== nothing
(variable_names, pretty_variable_names) = if variable_names === nothing
["x$(i)" for i in 1:nfeatures], ["x$(subscriptify(i))" for i in 1:nfeatures]
(["x$(i)" for i in 1:nfeatures], ["x$(subscriptify(i))" for i in 1:nfeatures])
else
(variable_names, variable_names)
end
y_variable_name = if y_variable_name === nothing
"y" variable_names ? "y" : "target"
("y" variable_names) ? "y" : "target"
else
y_variable_name
end
Expand Down Expand Up @@ -154,7 +154,6 @@ function Dataset(
end
X_sym_units = let _X = get_units(T, SD, X_units, sym_uparse)
if _X === nothing && y_sym_units !== nothing
# Make units for X:
get_units(T, SD, [one(T) for _ in 1:nfeatures], sym_uparse)
else
_X
Expand Down
3 changes: 1 addition & 2 deletions src/LossFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ function dimensional_regularization(
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
if !violates_dimensional_constraints(tree, dataset, options)
return zero(L)
end
if options.dimensional_constraint_penalty === nothing
elseif options.dimensional_constraint_penalty === nothing
return L(1000)
else
return L(options.dimensional_constraint_penalty::Float32)
Expand Down
51 changes: 39 additions & 12 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ module MLJInterfaceModule
using Optim: Optim
import MLJModelInterface as MMI
import DynamicExpressions: eval_tree_array, string_tree, Node
import DynamicQuantities as DQ
import DynamicQuantities: AbstractDimensions, DEFAULT_DIM_BASE_TYPE, ustrip, dimension
import DynamicQuantities:
AbstractDimensions,
SymbolicDimensions,
Quantity,
DEFAULT_DIM_BASE_TYPE,
ustrip,
dimension
import LossFunctions: SupervisedLoss
import Compat: allequal, stack
import ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE
Expand Down Expand Up @@ -33,7 +38,7 @@ function modelexpr(model_name::Symbol)
runtests::Bool = true
loss_type::L = Nothing
selection_method::F = choose_best
dimensions_type::Type{D} = DQ.SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
precompiling::Bool = false
end)
# TODO: store `procs` from initial run if parallelism is `:multiprocessing`
Expand Down Expand Up @@ -239,11 +244,11 @@ end

# TODO: Test whether this conversion poses any issues in data normalization...
function dimension_fallback(
q::Union{<:DQ.Quantity{T,<:AbstractDimensions}}, ::Type{D}
q::Union{<:Quantity{T,<:AbstractDimensions}}, ::Type{D}
) where {T,D}
return DQ.dimension(convert(DQ.Quantity{T,D}, q))::D
return dimension(convert(Quantity{T,D}, q))::D
end
dimension_fallback(q::Union{<:DQ.Quantity{T,D}}, ::Type{D}) where {T,D} = DQ.dimension(q)::D
dimension_fallback(q::Union{<:Quantity{T,D}}, ::Type{D}) where {T,D} = dimension(q)::D
dimension_fallback(_, ::Type{D}) where {D} = D()

function unwrap_units_single(A::AbstractMatrix, ::Type{D}) where {D}
Expand All @@ -253,12 +258,12 @@ function unwrap_units_single(A::AbstractMatrix, ::Type{D}) where {D}
error("Inconsistent units in feature $i of matrix.")
end
dims = map(Base.Fix2(dimension_fallback, D) first, eachrow(A))
return stack([DQ.ustrip.(row) for row in eachrow(A)]; dims=1), dims
return stack([ustrip.(row) for row in eachrow(A)]; dims=1), dims
end
function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D}
allequal(Base.Fix2(dimension_fallback, D).(v)) || error("Inconsistent units in vector.")
dims = dimension_fallback(first(v), D)
v = DQ.ustrip(v)
v = ustrip(v)
return v, dims
end

Expand Down Expand Up @@ -473,10 +478,12 @@ eval(
- `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
`Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
expressions will be taken from the column names of `X`, if available.
expressions will be taken from the column names of `X`, if available. Units in columns
of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.
- `y` is the target, which can be any `AbstractVector` whose element scitype is
`Continuous`; check the scitype with `scitype(y)`.
`Continuous`; check the scitype with `scitype(y)`. Units in `y` (use `DynamicQuantities`
for units) will trigger dimensional analysis to be used.
- `w` is the observation weights which can either be `nothing` (default) or an
`AbstractVector` whoose element scitype is `Count` or `Continuous`.
Expand Down Expand Up @@ -543,6 +550,24 @@ eval(
println("Equation used:", r.equation_strings[r.best_idx])
```
With units and variable names:
```julia
using MLJ
using DynamicQuantities
SRegressor = @load SRRegressor pkg=SymbolicRegression
X = (; x1=rand(32) .* us"km/h", x2=rand(32) .* us"km")
y = @. X.x2 / X.x1 + 0.5us"h"
model = SRRegressor(binary_operators=[+, -, *, /])
mach = machine(model, X, y)
fit!(mach)
y_hat = predict(mach, X)
# View the equation used:
r = report(mach)
println("Equation used:", r.equation_strings[r.best_idx])
```
See also [`MultitargetSRRegressor`](@ref).
""",
r"^ " => "",
Expand Down Expand Up @@ -574,10 +599,12 @@ eval(
- `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
`Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
expressions will be taken from the column names of `X`, if available.
expressions will be taken from the column names of `X`, if available. Units in columns
of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.
- `y` is the target, which can be any table of target variables whose element
scitype is `Continuous`; check the scitype with `schema(y)`.
scitype is `Continuous`; check the scitype with `schema(y)`. Units in columns of
`y` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.
- `w` is the observation weights which can either be `nothing` (default) or an
`AbstractVector` whoose element scitype is `Count` or `Continuous`. The same
Expand Down
9 changes: 5 additions & 4 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,8 @@ which is useful for debugging and profiling.
to be used for dimensional constraints. For example, if `X_units=["kg", "m"]`,
then the first feature will have units of kilograms, and the second will
have units of meters.
- `y_units=nothing`:
The units of the output, to be used for dimensional constraints. If
`y` is a matrix, then this can be a vector of units, in which case
- `y_units=nothing`: The units of the output, to be used for dimensional constraints.
If `y` is a matrix, then this can be a vector of units, in which case
each element corresponds to each output feature.
# Returns
Expand Down Expand Up @@ -349,6 +348,7 @@ function equation_search(
variable_names = deprecate_varmap(variable_names, varMap, :equation_search)

if weights !== nothing
@assert length(weights) == length(y)
weights = reshape(weights, size(y))
end
if T <: Complex && loss_type == Nothing
Expand Down Expand Up @@ -431,6 +431,7 @@ function equation_search(
"`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
)

# TODO: Still not type stable. Should be able to pass `Val{return_state}`.
should_return_state = if options.return_state === nothing
return_state === nothing ? false : return_state
else
Expand Down Expand Up @@ -487,7 +488,7 @@ function _equation_search(
end
end
if any(d -> d.X_units !== nothing || d.y_units !== nothing, datasets)
if options.dimensional_constraint_penalty === nothing
if options.dimensional_constraint_penalty === nothing && saved_state === nothing
@warn "You are using dimensional constraints, but `dimensional_constraint_penalty` was not set. The default penalty of `1000.0` will be used."
end
end
Expand Down
7 changes: 7 additions & 0 deletions test/test_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ end
report = MLJ.report(mach)
@test minimum(report.losses[1]) < 1e-7
@test minimum(report.losses[2]) < 1e-7

# Repeat with second run:
mach.model.niterations = 0
MLJ.fit!(mach)
report = MLJ.report(mach)
@test minimum(report.losses[1]) < 1e-7
@test minimum(report.losses[2]) < 1e-7
end
end

Expand Down

0 comments on commit 0dedf17

Please sign in to comment.