Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: upgrade to SymbolicUtils w/ fast terms #720

Merged
merged 27 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2b6317c
updates for Add Mul Pow
shashi Jan 6, 2021
f8279d8
.op -> operation; .args -> arguments
shashi Jan 6, 2021
39506eb
Fix calculate_tgrad
YingboMa Jan 6, 2021
ba22d8e
Add `TermCombination` overloads
YingboMa Jan 6, 2021
f358076
Fix vars
YingboMa Jan 6, 2021
b8370a4
Fix first order transform
YingboMa Jan 6, 2021
9db1674
Update mass matrix test
YingboMa Jan 6, 2021
9b7c580
fix diff tests
shashi Jan 6, 2021
3fc1e28
fix toexpr
shashi Jan 6, 2021
addfdf6
Fix a typo
YingboMa Jan 6, 2021
f15696a
Use istree in `diff2term`
YingboMa Jan 7, 2021
a9f3a99
fix pow on termcombination and use similarterm to avoid defining <_e
shashi Jan 8, 2021
aa4e3e2
Merge remote-tracking branch 'origin/s/addmulpow' into s/addmulpow
shashi Jan 8, 2021
388691a
fix typo
shashi Jan 8, 2021
896fdb7
approx couple of tests
shashi Jan 8, 2021
39b057f
pass only RHS in generate_affect_function
shashi Jan 8, 2021
0d70c3b
Fix https://github.com/SciML/ModelingToolkit.jl/issues/609
YingboMa Jan 8, 2021
b2a5f9d
Fix LaTeX printing
YingboMa Jan 8, 2021
828d461
some more genericness
shashi Jan 9, 2021
4a6e57a
Add canonicalexpr to generate safe expressions for evaluations
YingboMa Jan 9, 2021
27c923c
fix more
shashi Jan 9, 2021
d50d987
more fixes
shashi Jan 9, 2021
bd22a92
Update build_targets tests
YingboMa Jan 8, 2021
ad372b0
bump Symutils
shashi Jan 9, 2021
0f9baea
Fix latexify test
YingboMa Jan 11, 2021
4473574
Add `canonicalize` kw and fix latexify tests
YingboMa Jan 11, 2021
8f430f9
Comment out Latexify tests
YingboMa Jan 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
using RecursiveArrayTools

import SymbolicUtils
import SymbolicUtils: Term, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm

import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint

using LinearAlgebra: LU, BlasInt

Expand Down Expand Up @@ -72,13 +74,23 @@ Base.show(io::IO, n::Num) = show_numwrap[] ? print(io, :(Num($(value(n))))) : Ba

Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num
Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num
Base.getproperty(t::Term, f::Symbol) = f === :op ? operation(t) : f === :args ? arguments(t) : getfield(t, f)
function Base.getproperty(t::Union{Add, Mul, Pow, Term}, f::Symbol)
if f === :op
Base.depwarn("`x.op` is deprecated, use `operation(x)` instead", :getproperty, force=true)
operation(t)
elseif f === :args
Base.depwarn("`x.args` is deprecated, use `arguments(x)` instead", :getproperty, force=true)
arguments(t)
else
getfield(t, f)
end
end
<ₑ(s::Num, x) = value(s) <ₑ value(x)
<ₑ(s, x::Num) = value(s) <ₑ value(x)
<ₑ(s::Num, x::Num) = value(s) <ₑ value(x)

for T in (Integer, Rational)
@eval Base.:(^)(n::Num, i::$T) = Num(Term{symtype(n)}(^, [value(n),i]))
@eval Base.:(^)(n::Num, i::$T) = Num(value(n)^i)
end

macro num_method(f, expr, Ts=nothing)
Expand Down
8 changes: 4 additions & 4 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,25 +479,25 @@ get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)
function numbered_expr(O::Union{Term,Sym},args...;varordering = args[1],offset = 0,
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
O = value(O)
if O isa Sym || isa(O.op, Sym)
if O isa Sym || isa(operation(O), Sym)
for j in 1:length(args)
i = get_varnumber(O,args[j])
if i !== nothing
return :($(rhsnames[j])[$(i+offset)])
end
end
end
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(O.op),
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(operation(O)),
[numbered_expr(x,args...;offset=offset,lhsname=lhsname,
rhsnames=rhsnames,varordering=varordering) for x in O.args]...)
rhsnames=rhsnames,varordering=varordering) for x in arguments(O)]...)
end

function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)

varordering = value.(args[1])
var = var_from_nested_derivative(de.lhs)[1]
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : x.op, escape=false), tosymbol(var, escape=false)),varordering)
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : operation(x), escape=false), tosymbol(var, escape=false)),varordering)
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
varordering = varordering,
lhsname = lhsname,
Expand Down
6 changes: 5 additions & 1 deletion src/context_dsl.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import SymbolicUtils: symtype
import SymbolicUtils: symtype, term
struct Parameter{T} end

isparameter(x) = false
isparameter(::Sym{<:Parameter}) = true
isparameter(::Sym{<:FnType{<:Any, <:Parameter}}) = true

SymbolicUtils.@number_methods(Sym{Parameter{Real}},
term(f, a),
term(f, a, b), skipbasics)

SymbolicUtils.symtype(s::Symbolic{Parameter{T}}) where T = T
SymbolicUtils.similarterm(t::Term{T}, f, args) where {T<:Parameter} = Term{T}(f, args)

Expand Down
70 changes: 36 additions & 34 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)

_isfalse(occ::Bool) = occ === false
_isfalse(occ::Term) = _isfalse(occ.op)
_isfalse(occ::Term) = _isfalse(operation(occ))

function occursin_info(x, expr::Term)
function occursin_info(x, expr)
!istree(expr) && return false
if isequal(x, expr)
true
else
args = map(a->occursin_info(x, a), expr.args)
args = map(a->occursin_info(x, a), arguments(expr))
if all(_isfalse, args)
return false
end
Expand All @@ -52,66 +53,65 @@ function occursin_info(x, expr::Sym)
isequal(x, expr)
end

hasderiv(O::Term) = O.op isa Differential || any(hasderiv, O.args)
hasderiv(O) = false

occursin_info(x, y) = false
function hasderiv(O)
istree(O) ? operation(O) isa Differential || any(hasderiv, arguments(O)) : false
end
"""
$(SIGNATURES)

TODO
"""
function expand_derivatives(O::Term, simplify=true; occurances=nothing)
if isa(O.op, Differential)
@assert length(O.args) == 1
arg = expand_derivatives(O.args[1], false)
function expand_derivatives(O::Symbolic, simplify=false; occurances=nothing)
if istree(O) && isa(operation(O), Differential)
@assert length(arguments(O)) == 1
arg = expand_derivatives(arguments(O)[1], false)

if occurances == nothing
occurances = occursin_info(O.op.x, arg)
occurances = occursin_info(operation(O).x, arg)
end

_isfalse(occurances) && return 0
occurances isa Bool && return 1 # means it's a `true`

(D, o) = (O.op, arg)
D = operation(O)

if !isa(o, Term)
return O # Cannot expand
elseif isa(o.op, Sym)
return O # Cannot expand
elseif isa(o.op, Differential)
if !istree(arg)
return D(arg) # Cannot expand
elseif isa(operation(arg), Sym)
return D(arg) # Cannot expand
elseif isa(operation(arg), Differential)
# The recursive expand_derivatives was not able to remove
# a nested Differential. We can attempt to differentiate the
# inner expression wrt to the outer iv. And leave the
# unexpandable Differential outside.
if isequal(o.op.x, D.x)
return O
if isequal(operation(arg).x, D.x)
return D(arg)
else
inner = expand_derivatives(D(o.args[1]), false)
inner = expand_derivatives(D(arguments(arg)[1]), false)
# if the inner expression is not expandable either, return
if inner isa Term && operation(inner) isa Differential
return O
if istree(inner) && operation(inner) isa Differential
return D(arg)
else
return expand_derivatives(o.op(inner), simplify)
return expand_derivatives(operation(arg)(inner), simplify)
end
end
end

l = length(o.args)
l = length(arguments(arg))
exprs = []
c = 0

for i in 1:l
t2 = expand_derivatives(D(o.args[i]),false, occurances=occurances.args[i])
t2 = expand_derivatives(D(arguments(arg)[i]),false, occurances=arguments(occurances)[i])

x = if _iszero(t2)
t2
elseif _isone(t2)
d = derivative_idx(o, i)
d isa NoDeriv ? D(o) : d
d = derivative_idx(arg, i)
d isa NoDeriv ? D(arg) : d
else
t1 = derivative_idx(o, i)
t1 = t1 isa NoDeriv ? D(o) : t1
t1 = derivative_idx(arg, i)
t1 = t1 isa NoDeriv ? D(arg) : t1
make_operation(*, [t1, t2])
end

Expand All @@ -136,8 +136,8 @@ function expand_derivatives(O::Term, simplify=true; occurances=nothing)
elseif !hasderiv(O)
return O
else
args = map(a->expand_derivatives(a, false), O.args)
O1 = make_operation(O.op, args)
args = map(a->expand_derivatives(a, false), arguments(O))
O1 = make_operation(operation(O), args)
return simplify ? SymbolicUtils.simplify(O1) : O1
end
end
Expand Down Expand Up @@ -176,7 +176,7 @@ chain rule is not applied:
julia> myop = sin(x) * y^2
sin(x()) * y() ^ 2

julia> typeof(myop.op) # Op is multiplication function
julia> typeof(operation(myop)) # Op is multiplication function
typeof(*)

julia> ModelingToolkit.derivative_idx(myop, 1) # wrt. sin(x)
Expand All @@ -187,7 +187,9 @@ sin(x())
```
"""
derivative_idx(O::Any, ::Any) = 0
derivative_idx(O::Term, idx) = derivative(O.op, (O.args...,), Val(idx))
function derivative_idx(O::Symbolic, idx)
istree(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0
end

# Indicate that no derivative is defined.
struct NoDeriv
Expand Down
40 changes: 21 additions & 19 deletions src/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ let

_scalar = one(TermCombination)

linearity_propagator = [
simterm(t, f, args) = Term{Any}(f, args)
linearity_rules = [
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
@rule (~f)(~x::(!isidx)) => _scalar
Expand All @@ -146,7 +147,8 @@ let
else
error("Function of unknown linearity used: ", ~f)
end
end] |> Rewriters.Chain |> Rewriters.Postwalk |> Rewriters.Fixpoint
end]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=simterm))

global hessian_sparsity

Expand All @@ -164,9 +166,9 @@ let
u = map(value, u)
idx(i) = TermCombination(Set([Dict(i=>1)]))
dict = Dict(u .=> idx.(1:length(u)))
found = []
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x)(f)
_sparse(linearity_propagator(f), length(u))
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=simterm)(f)
lp = linearity_propagator(f)
_sparse(lp, length(u))
end
end

Expand Down Expand Up @@ -213,22 +215,22 @@ function sparsehessian(O, vars::AbstractVector; simplify = true)
return H
end

function toexpr(O::Term)
if isa(O.op, Differential)
return :(derivative($(toexpr(O.args[1])),$(toexpr(O.op.x))))
elseif isa(O.op, Sym)
isempty(O.args) && return O.op.name
return Expr(:call, toexpr(O.op), toexpr.(O.args)...)
end
if O.op === (^)
if length(O.args) > 1 && O.args[2] isa Number && O.args[2] < 0
return Expr(:call, ^, Expr(:call, inv, toexpr(O.args[1])), -(O.args[2]))
end
end
return Expr(:call, O.op, toexpr.(O.args)...)
function toexpr(O)
!istree(O) && return O
if isa(operation(O), Differential)
return :(derivative($(toexpr(arguments(O)[1])),$(toexpr(operation(O).x))))
elseif isa(operation(O), Sym)
isempty(arguments(O)) && return operation(O).name
return Expr(:call, toexpr(operation(O)), toexpr.(arguments(O))...)
end
if operation(O) === (^)
if length(arguments(O)) > 1 && arguments(O)[2] isa Number && arguments(O)[2] < 0
return Expr(:call, ^, Expr(:call, inv, toexpr(arguments(O)[1])), -(arguments(O)[2]))
end
end
return Expr(:call, operation(O), toexpr.(arguments(O))...)
end
toexpr(s::Sym) = nameof(s)
toexpr(s) = s

function toexpr(eq::Equation)
Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs))
Expand Down
12 changes: 6 additions & 6 deletions src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...)

rhs = getfield.(eqs, :rhs)
rhs = prettify_expr.(toexpr.(rhs))
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && length(arguments(x)) == 1 ? arguments(x)[1] : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && arguments(x)[1] == :derivative && length(arguments(x)[2].args) == 2 ? :($(Symbol(:d, arguments(x)[2]))/($(Symbol(:d, arguments(x)[2].args[2])))) : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && arguments(x)[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(arguments(x)[2]))\\right)}{d$(Latexify.latexraw(arguments(x)[3]))}" : x, eq) for eq in rhs]

lhs = getfield.(eqs, :lhs)
lhs = prettify_expr.(toexpr.(lhs))
lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && length(arguments(x)) == 1 ? arguments(x)[1] : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && arguments(x)[1] == :derivative && length(arguments(x)[2].args) == 2 ? :($(Symbol(:d, arguments(x)[2]))/($(Symbol(:d, arguments(x)[2].args[2])))) : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && arguments(x)[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(arguments(x)[2]))\\right)}{d$(Latexify.latexraw(arguments(x)[3]))}" : x, eq) for eq in lhs]

return lhs, rhs
end
Expand Down
8 changes: 8 additions & 0 deletions src/linearity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ function Base.:(==)(comb1::TermCombination, comb2::TermCombination)
end
=#

# to make Mul and Add work
Base.:*(::Number, comb::TermCombination) = comb
function Base.:^(comb::TermCombination, ::Number)
isone(comb) && return comb
iszero(comb) && return _scalar
return comb * comb
end

function Base.:+(comb1::TermCombination, comb2::TermCombination)
if isone(comb1) && !iszero(comb2)
return comb2
Expand Down
18 changes: 9 additions & 9 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Generate a function to evaluate the system's equations.
function generate_function end

getname(x::Sym) = nameof(x)
getname(t::Term) = t.op isa Sym ? getname(t.op) : error("Cannot get name of $t")
getname(t::Term) = operation(t) isa Sym ? getname(operation(t)) : error("Cannot get name of $t")

function Base.getproperty(sys::AbstractSystem, name::Symbol)

Expand Down Expand Up @@ -161,7 +161,7 @@ function renamespace(namespace, x::Sym)
end

function renamespace(namespace, x::Term)
renamespace(namespace, x.op)(x.args...)
renamespace(namespace, operation(x))(arguments(x)...)
end

function namespace_variables(sys::AbstractSystem)
Expand Down Expand Up @@ -189,10 +189,10 @@ function namespace_expr(O::Sym,name,ivname)
end

function namespace_expr(O::Term{T},name,ivname) where {T}
if O.op isa Sym
Term{T}(rename(O.op,renamespace(name,O.op.name)),namespace_expr.(O.args,name,ivname))
if operation(O) isa Sym
Term{T}(rename(operation(O),renamespace(name,operation(O).name)),namespace_expr.(arguments(O),name,ivname))
else
Term{T}(O.op,namespace_expr.(O.args,name,ivname))
Term{T}(operation(O),namespace_expr.(arguments(O),name,ivname))
end
end
namespace_expr(O,name,ivname) = O
Expand Down Expand Up @@ -278,11 +278,11 @@ struct AbstractSysToExpr
end
AbstractSysToExpr(sys) = AbstractSysToExpr(sys,states(sys))
function (f::AbstractSysToExpr)(O::Term)
any(isequal(O), f.states) && return O.op.name # variables
if isa(O.op, Sym)
return build_expr(:call, Any[O.op.name; f.(O.args)])
any(isequal(O), f.states) && return operation(O).name # variables
if isa(operation(O), Sym)
return build_expr(:call, Any[operation(O).name; f.(arguments(O))])
end
return build_expr(:call, Any[O.op; f.(O.args)])
return build_expr(:call, Any[operation(O); f.(arguments(O))])
end
(f::AbstractSysToExpr)(x) = toexpr(x)

Loading