Skip to content

Commit

Permalink
Merge pull request #720 from SciML/s/addmulpow
Browse files Browse the repository at this point in the history
WIP: upgrade to SymbolicUtils w/ fast terms
  • Loading branch information
YingboMa authored Jan 12, 2021
2 parents 17d02bb + 8f430f9 commit 135821e
Show file tree
Hide file tree
Showing 26 changed files with 281 additions and 189 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ RuntimeGeneratedFunctions = "0.4, 0.5"
SafeTestsets = "0.0.1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicUtils = "0.6"
SymbolicUtils = "0.7"
TreeViews = "0.3"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand Down
18 changes: 15 additions & 3 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
using RecursiveArrayTools

import SymbolicUtils
import SymbolicUtils: Term, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm

import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint

using LinearAlgebra: LU, BlasInt

Expand Down Expand Up @@ -72,13 +74,23 @@ Base.show(io::IO, n::Num) = show_numwrap[] ? print(io, :(Num($(value(n))))) : Ba

Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num
Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num
Base.getproperty(t::Term, f::Symbol) = f === :op ? operation(t) : f === :args ? arguments(t) : getfield(t, f)
function Base.getproperty(t::Union{Add, Mul, Pow, Term}, f::Symbol)
if f === :op
Base.depwarn("`x.op` is deprecated, use `operation(x)` instead", :getproperty, force=true)
operation(t)
elseif f === :args
Base.depwarn("`x.args` is deprecated, use `arguments(x)` instead", :getproperty, force=true)
arguments(t)
else
getfield(t, f)
end
end
<(s::Num, x) = value(s) <value(x)
<(s, x::Num) = value(s) <value(x)
<(s::Num, x::Num) = value(s) <value(x)

for T in (Integer, Rational)
@eval Base.:(^)(n::Num, i::$T) = Num(Term{symtype(n)}(^, [value(n),i]))
@eval Base.:(^)(n::Num, i::$T) = Num(value(n)^i)
end

macro num_method(f, expr, Ts=nothing)
Expand Down
10 changes: 5 additions & 5 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,28 +476,28 @@ end

get_varnumber(varop, vars::Vector) = findfirst(x->isequal(x,varop),vars)

function numbered_expr(O::Union{Term,Sym},args...;varordering = args[1],offset = 0,
function numbered_expr(O::Symbolic,args...;varordering = args[1],offset = 0,
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
O = value(O)
if O isa Sym || isa(O.op, Sym)
if O isa Sym || isa(operation(O), Sym)
for j in 1:length(args)
i = get_varnumber(O,args[j])
if i !== nothing
return :($(rhsnames[j])[$(i+offset)])
end
end
end
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(O.op),
return Expr(:call, O isa Sym ? tosymbol(O, escape=false) : Symbol(operation(O)),
[numbered_expr(x,args...;offset=offset,lhsname=lhsname,
rhsnames=rhsnames,varordering=varordering) for x in O.args]...)
rhsnames=rhsnames,varordering=varordering) for x in arguments(O)]...)
end

function numbered_expr(de::ModelingToolkit.Equation,args...;varordering = args[1],
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)

varordering = value.(args[1])
var = var_from_nested_derivative(de.lhs)[1]
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : x.op, escape=false), tosymbol(var, escape=false)),varordering)
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : operation(x), escape=false), tosymbol(var, escape=false)),varordering)
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,args...;offset=offset,
varordering = varordering,
lhsname = lhsname,
Expand Down
6 changes: 5 additions & 1 deletion src/context_dsl.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import SymbolicUtils: symtype
import SymbolicUtils: symtype, term
struct Parameter{T} end

isparameter(x) = false
isparameter(::Sym{<:Parameter}) = true
isparameter(::Sym{<:FnType{<:Any, <:Parameter}}) = true

SymbolicUtils.@number_methods(Sym{Parameter{Real}},
term(f, a),
term(f, a, b), skipbasics)

SymbolicUtils.symtype(s::Symbolic{Parameter{T}}) where T = T
SymbolicUtils.similarterm(t::Term{T}, f, args) where {T<:Parameter} = Term{T}(f, args)

Expand Down
70 changes: 36 additions & 34 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)

_isfalse(occ::Bool) = occ === false
_isfalse(occ::Term) = _isfalse(occ.op)
_isfalse(occ::Term) = _isfalse(operation(occ))

function occursin_info(x, expr::Term)
function occursin_info(x, expr)
!istree(expr) && return false
if isequal(x, expr)
true
else
args = map(a->occursin_info(x, a), expr.args)
args = map(a->occursin_info(x, a), arguments(expr))
if all(_isfalse, args)
return false
end
Expand All @@ -52,66 +53,65 @@ function occursin_info(x, expr::Sym)
isequal(x, expr)
end

hasderiv(O::Term) = O.op isa Differential || any(hasderiv, O.args)
hasderiv(O) = false

occursin_info(x, y) = false
function hasderiv(O)
istree(O) ? operation(O) isa Differential || any(hasderiv, arguments(O)) : false
end
"""
$(SIGNATURES)
TODO
"""
function expand_derivatives(O::Term, simplify=true; occurances=nothing)
if isa(O.op, Differential)
@assert length(O.args) == 1
arg = expand_derivatives(O.args[1], false)
function expand_derivatives(O::Symbolic, simplify=false; occurances=nothing)
if istree(O) && isa(operation(O), Differential)
@assert length(arguments(O)) == 1
arg = expand_derivatives(arguments(O)[1], false)

if occurances == nothing
occurances = occursin_info(O.op.x, arg)
occurances = occursin_info(operation(O).x, arg)
end

_isfalse(occurances) && return 0
occurances isa Bool && return 1 # means it's a `true`

(D, o) = (O.op, arg)
D = operation(O)

if !isa(o, Term)
return O # Cannot expand
elseif isa(o.op, Sym)
return O # Cannot expand
elseif isa(o.op, Differential)
if !istree(arg)
return D(arg) # Cannot expand
elseif isa(operation(arg), Sym)
return D(arg) # Cannot expand
elseif isa(operation(arg), Differential)
# The recursive expand_derivatives was not able to remove
# a nested Differential. We can attempt to differentiate the
# inner expression wrt to the outer iv. And leave the
# unexpandable Differential outside.
if isequal(o.op.x, D.x)
return O
if isequal(operation(arg).x, D.x)
return D(arg)
else
inner = expand_derivatives(D(o.args[1]), false)
inner = expand_derivatives(D(arguments(arg)[1]), false)
# if the inner expression is not expandable either, return
if inner isa Term && operation(inner) isa Differential
return O
if istree(inner) && operation(inner) isa Differential
return D(arg)
else
return expand_derivatives(o.op(inner), simplify)
return expand_derivatives(operation(arg)(inner), simplify)
end
end
end

l = length(o.args)
l = length(arguments(arg))
exprs = []
c = 0

for i in 1:l
t2 = expand_derivatives(D(o.args[i]),false, occurances=occurances.args[i])
t2 = expand_derivatives(D(arguments(arg)[i]),false, occurances=arguments(occurances)[i])

x = if _iszero(t2)
t2
elseif _isone(t2)
d = derivative_idx(o, i)
d isa NoDeriv ? D(o) : d
d = derivative_idx(arg, i)
d isa NoDeriv ? D(arg) : d
else
t1 = derivative_idx(o, i)
t1 = t1 isa NoDeriv ? D(o) : t1
t1 = derivative_idx(arg, i)
t1 = t1 isa NoDeriv ? D(arg) : t1
make_operation(*, [t1, t2])
end

Expand All @@ -136,8 +136,8 @@ function expand_derivatives(O::Term, simplify=true; occurances=nothing)
elseif !hasderiv(O)
return O
else
args = map(a->expand_derivatives(a, false), O.args)
O1 = make_operation(O.op, args)
args = map(a->expand_derivatives(a, false), arguments(O))
O1 = make_operation(operation(O), args)
return simplify ? SymbolicUtils.simplify(O1) : O1
end
end
Expand Down Expand Up @@ -176,7 +176,7 @@ chain rule is not applied:
julia> myop = sin(x) * y^2
sin(x()) * y() ^ 2
julia> typeof(myop.op) # Op is multiplication function
julia> typeof(operation(myop)) # Op is multiplication function
typeof(*)
julia> ModelingToolkit.derivative_idx(myop, 1) # wrt. sin(x)
Expand All @@ -187,7 +187,9 @@ sin(x())
```
"""
derivative_idx(O::Any, ::Any) = 0
derivative_idx(O::Term, idx) = derivative(O.op, (O.args...,), Val(idx))
function derivative_idx(O::Symbolic, idx)
istree(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0
end

# Indicate that no derivative is defined.
struct NoDeriv
Expand Down
86 changes: 60 additions & 26 deletions src/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ let

_scalar = one(TermCombination)

linearity_propagator = [
simterm(t, f, args) = Term{Any}(f, args)
linearity_rules = [
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
@rule (~f)(~x::(!isidx)) => _scalar
Expand All @@ -146,7 +147,8 @@ let
else
error("Function of unknown linearity used: ", ~f)
end
end] |> Rewriters.Chain |> Rewriters.Postwalk |> Rewriters.Fixpoint
end]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=simterm))

global hessian_sparsity

Expand All @@ -164,9 +166,9 @@ let
u = map(value, u)
idx(i) = TermCombination(Set([Dict(i=>1)]))
dict = Dict(u .=> idx.(1:length(u)))
found = []
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x)(f)
_sparse(linearity_propagator(f), length(u))
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=simterm)(f)
lp = linearity_propagator(f)
_sparse(lp, length(u))
end
end

Expand Down Expand Up @@ -213,28 +215,60 @@ function sparsehessian(O, vars::AbstractVector; simplify = true)
return H
end

function toexpr(O::Term)
if isa(O.op, Differential)
return :(derivative($(toexpr(O.args[1])),$(toexpr(O.op.x))))
elseif isa(O.op, Sym)
isempty(O.args) && return O.op.name
return Expr(:call, toexpr(O.op), toexpr.(O.args)...)
end
if O.op === (^)
if length(O.args) > 1 && O.args[2] isa Number && O.args[2] < 0
return Expr(:call, ^, Expr(:call, inv, toexpr(O.args[1])), -(O.args[2]))
end
end
return Expr(:call, O.op, toexpr.(O.args)...)
"""
toexpr(O::Union{Symbolics,Num,Equation,AbstractArray}; canonicalize=true) -> Expr
Convert `Symbolics` into `Expr`. If `canonicalize`, then we turn exprs like
`x^(-n)` into `inv(x)^n` to avoid type error when evaluating.
"""
function toexpr(O; canonicalize=true)
if canonicalize
canonical, O = canonicalexpr(O)
canonical && return O
else
!istree(O) && return O
end

op = operation(O)
args = arguments(O)
if op isa Differential
return :(derivative($(toexpr(args[1]; canonicalize=canonicalize)),$(toexpr(op.x; canonicalize=canonicalize))))
elseif op isa Sym
isempty(args) && return nameof(op)
return Expr(:call, toexpr(op; canonicalize=canonicalize), toexpr(args; canonicalize=canonicalize)...)
end
return Expr(:call, op, toexpr(args; canonicalize=canonicalize)...)
end
toexpr(s::Sym; kw...) = nameof(s)

"""
canonicalexpr(O) -> (canonical::Bool, expr)
Canonicalize `O`. Return `canonical` if `expr` is valid code to generate.
"""
function canonicalexpr(O)
!istree(O) && return true, O
op = operation(O)
args = arguments(O)
if op === (^)
if length(args) == 2 && args[2] isa Number && args[2] < 0
ex = toexpr(args[1])
if args[2] == -1
expr = Expr(:call, inv, ex)
else
expr = Expr(:call, ^, Expr(:call, inv, ex), -args[2])
end
return true, expr
end
end
return false, O
end
toexpr(s::Sym) = nameof(s)
toexpr(s) = s

function toexpr(eq::Equation)
Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs))
function toexpr(eq::Equation; kw...)
Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...))
end

toexpr(eq::AbstractArray) = toexpr.(eq)
toexpr(x::Integer) = x
toexpr(x::AbstractFloat) = x
toexpr(x::Num) = toexpr(value(x))
toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs)
toexpr(x::Integer; kw...) = x
toexpr(x::AbstractFloat; kw...) = x
toexpr(x::Num; kw...) = toexpr(value(x); kw...)
4 changes: 2 additions & 2 deletions src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...)
# that latexify can deal with

rhs = getfield.(eqs, :rhs)
rhs = prettify_expr.(toexpr.(rhs))
rhs = prettify_expr.(toexpr(rhs; canonicalize=false))
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs]

lhs = getfield.(eqs, :lhs)
lhs = prettify_expr.(toexpr.(lhs))
lhs = prettify_expr.(toexpr(lhs; canonicalize=false))
lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs]
Expand Down
8 changes: 8 additions & 0 deletions src/linearity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ function Base.:(==)(comb1::TermCombination, comb2::TermCombination)
end
=#

# to make Mul and Add work
Base.:*(::Number, comb::TermCombination) = comb
function Base.:^(comb::TermCombination, ::Number)
isone(comb) && return comb
iszero(comb) && return _scalar
return comb * comb
end

function Base.:+(comb1::TermCombination, comb2::TermCombination)
if isone(comb1) && !iszero(comb2)
return comb2
Expand Down
Loading

0 comments on commit 135821e

Please sign in to comment.