Skip to content

Commit

Permalink
Merge pull request #103 from JuliaDiffEq/hg/fix/cleanup
Browse files Browse the repository at this point in the history
Minor cleanup
  • Loading branch information
ChrisRackauckas authored Mar 8, 2019
2 parents 02edf04 + 6379485 commit ab45092
Show file tree
Hide file tree
Showing 15 changed files with 159 additions and 149 deletions.
44 changes: 22 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ to manipulate.
### Example: ODE

Let's build an ODE. First we define some variables. In a differential equation
system, we need to differentiate between our unknown (dependent) variables
system, we need to differentiate between our (dependent) variables
and parameters. Therefore we label them as follows:

```julia
using ModelingToolkit

# Define some variables
@Param t σ ρ β
@Unknown x(t) y(t) z(t)
@Deriv D'~t
@parameters t σ ρ β
@variables x(t) y(t) z(t)
@derivatives D'~t
```

Then we build the system:
Expand Down Expand Up @@ -78,11 +78,11 @@ f = ODEFunction(de)
We can also build nonlinear systems. Let's say we wanted to solve for the steady
state of the previous ODE. This is the nonlinear system defined by where the
derivatives are zero. We use unknown variables for our nonlinear system.
derivatives are zero. We use (unknown) variables for our nonlinear system.
```julia
@Unknown x y z
@Param σ ρ β
@variables x y z
@parameters σ ρ β

# Define a nonlinear system
eqs = [0 ~ σ*(y-x),
Expand Down Expand Up @@ -173,7 +173,7 @@ structure is as follows:
the system of equations.
- Name to subtype mappings: these describe how variable `subtype`s are mapped
to the contexts of the system. For example, for a differential equation,
the unknown variable corresponds to given subtypes and then the `eqs` can
the variable corresponds to given subtypes and then the `eqs` can
be analyzed knowing what the state variables are.
- Variable names which do not fall into one of the system's core subtypes are
treated as intermediates which can be used for holding subcalculations and
Expand Down Expand Up @@ -223,7 +223,7 @@ function via the dispatch:
```julia
# `N` arguments are accepted by the relevant method of `my_function`
ModelingToolkit.Derivative(::typeof(my_function), args::NTuple{N,Any}, ::Val{i})
ModelingToolkit.derivative(::typeof(my_function), args::NTuple{N,Any}, ::Val{i})
```
where `i` means that it's the derivative of the `i`th argument. `args` is the
Expand All @@ -233,7 +233,7 @@ You should return an `Operation` for the derivative of your function.
For example, `sin(t)`'s derivative (by `t`) is given by the following:
```julia
ModelingToolkit.Derivative(::typeof(sin), args::NTuple{1,Any}, ::Val{1}) = cos(args[1])
ModelingToolkit.derivative(::typeof(sin), args::NTuple{1,Any}, ::Val{1}) = cos(args[1])
```
### Macro-free Usage
Expand All @@ -243,31 +243,31 @@ is accessible via a function-based interface. This means that all macros are
syntactic sugar in some form. For example, the variable construction:
```julia
@Param t σ ρ β
@Unknown x(t) y(t) z(t)
@Deriv D'~t
@parameters t σ ρ β
@variables x(t) y(t) z(t)
@derivatives D'~t
```
is syntactic sugar for:
```julia
t = Parameter(:t)
x = Unknown(:x, [t])
y = Unknown(:y, [t])
z = Unknown(:z, [t])
t = Variable(:t; known = true)
x = Variable(:x, [t])
y = Variable(:y, [t])
z = Variable(:z, [t])
D = Differential(t)
σ = Parameter()
ρ = Parameter()
β = Parameter()
σ = Variable(; known = true)
ρ = Variable(; known = true)
β = Variable(; known = true)
```
### Intermediate Calculations
The system building functions can handle intermediate calculations. For example,
```julia
@Unknown x y z
@Param σ ρ β
@variables x y z
@parameters σ ρ β
a = y - x
eqs = [0 ~ σ*a,
0 ~ x*-z)-y,
Expand Down
10 changes: 5 additions & 5 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
module ModelingToolkit

export Operation, Expression
export calculate_jacobian, generate_jacobian, generate_function
export @register


using DiffEqBase
using StaticArrays, LinearAlgebra

Expand Down Expand Up @@ -30,9 +35,4 @@ include("function_registration.jl")
include("simplify.jl")
include("utils.jl")

export Operation, Expression, AbstractComponent
export calculate_jacobian, generate_jacobian, generate_function
export ArrayFunction, SArrayFunction
export @register

end # module
17 changes: 9 additions & 8 deletions src/differentials.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
export Differential, expand_derivatives, @derivatives


struct Differential <: Function
x::Expression
end
Expand All @@ -12,7 +15,7 @@ function (D::Differential)(x::Variable)
return Operation(D, Expression[x])
end
(::Differential)(::Any) = Constant(0)
Base.:(==)(D1::Differential, D2::Differential) = D1.x == D2.x
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)

function expand_derivatives(O::Operation)
@. O.args = expand_derivatives(O.args)
Expand All @@ -21,21 +24,21 @@ function expand_derivatives(O::Operation)
D = O.op
o = O.args[1]
isa(o, Operation) || return O
return simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
return simplify_constants(sum(i->derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
end

return O
end
expand_derivatives(x) = x

# Don't specialize on the function here
Derivative(O::Operation, idx) = Derivative(O.op, (O.args...,), Val(idx))
derivative(O::Operation, idx) = derivative(O.op, (O.args...,), Val(idx))

# Pre-defined derivatives
import DiffRules, SpecialFunctions, NaNMath
for (modu, fun, arity) DiffRules.diffrules()
for i 1:arity
@eval function Derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
@eval function derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
M, f = $(modu, fun)
partials = DiffRules.diffrule(M, f, args...)
dx = @static $arity == 1 ? partials : partials[$i]
Expand All @@ -60,7 +63,7 @@ function _differential_macro(x)
lhss = Symbol[]
x = flatten_expr!(x)
for di in x
@assert di isa Expr && di.args[1] == :~ "@Deriv expects a form that looks like `@Deriv D''~t E'~t`"
@assert di isa Expr && di.args[1] == :~ "@derivatives expects a form that looks like `@derivatives D''~t E'~t`"
lhs = di.args[2]
rhs = di.args[3]
order, lhs = count_order(lhs)
Expand All @@ -72,12 +75,10 @@ function _differential_macro(x)
ex
end

macro Deriv(x...)
macro derivatives(x...)
esc(_differential_macro(x))
end

function calculate_jacobian(eqs,vars)
Expression[Differential(vars[j])(eqs[i]) for i in 1:length(eqs), j in 1:length(vars)]
end

export Differential, expand_derivatives, @Deriv, calculate_jacobian
4 changes: 2 additions & 2 deletions src/equations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ struct Equation
lhs::Expression
rhs::Expression
end
Base.:(==)(a::Equation, b::Equation) = (a.lhs, a.rhs) == (b.lhs, b.rhs)
Base.:(==)(a::Equation, b::Equation) = isequal((a.lhs, a.rhs), (b.lhs, b.rhs))

Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)


_is_dependent(x::Variable) = !x.known && !isempty(x.dependents)
_is_parameter(iv) = x -> x.known && x iv
_is_parameter(iv) = x -> x.known && !isequal(x, iv)
_is_known(x::Variable) = x.known
_is_unknown(x::Variable) = !x.known

Expand Down
8 changes: 7 additions & 1 deletion src/function_registration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ for (M, f, arity) in DiffRules.diffrules()
@eval @register $sig
end

for fun = (:<, :>, :(==), :!, :&, :|, :div)
for fun [:!]
basefun = Expr(:., Base, QuoteNode(fun))
sig = :($basefun(x))
@eval @register $sig
end

for fun [:<, :>, :(==), :&, :|, :div]
basefun = Expr(:., Base, QuoteNode(fun))
sig = :($basefun(x,y))
@eval @register $sig
Expand Down
14 changes: 7 additions & 7 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ struct Operation <: Expression
end

# Recursive ==
function Base.:(==)(x::Operation,y::Operation)
function Base.isequal(x::Operation,y::Operation)
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
end
Base.:(==)(::Operation, ::Number ) = false
Base.:(==)(::Number , ::Operation) = false
Base.:(==)(::Operation, ::Variable ) = false
Base.:(==)(::Variable , ::Operation) = false
Base.:(==)(::Operation, ::Constant ) = false
Base.:(==)(::Constant , ::Operation) = false
Base.isequal(::Operation, ::Number ) = false
Base.isequal(::Number , ::Operation) = false
Base.isequal(::Operation, ::Variable ) = false
Base.isequal(::Variable , ::Operation) = false
Base.isequal(::Operation, ::Constant ) = false
Base.isequal(::Constant , ::Operation) = false

Base.convert(::Type{Expr}, O::Operation) =
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])
Expand Down
7 changes: 4 additions & 3 deletions src/simplify.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
export simplify_constants


function simplify_constants(O::Operation, shorten_tree)
while true
O′ = _simplify_constants(O, shorten_tree)
if is_operation(O′)
O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree))
end
O == O′ && return O
isequal(O, O′) && return O
O = O′
end
end
Expand Down Expand Up @@ -72,5 +75,3 @@ function _simplify_constants(O::Operation, shorten_tree)
end
_simplify_constants(x, shorten_tree) = x
_simplify_constants(x) = _simplify_constants(x, true)

export simplify_constants
6 changes: 3 additions & 3 deletions src/systems/diffeqs/diffeqsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function flatten_differential(O::Operation)
@assert is_derivative(O) "invalid differential: $O"
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
(x, t, order) = flatten_differential(O.args[1])
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
return (x, t, order + 1)
end

Expand All @@ -26,7 +26,7 @@ function Base.convert(::Type{DiffEq}, eq::Equation)
(x, t, n) = flatten_differential(eq.lhs)
return DiffEq(x, t, n, eq.rhs)
end
Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs)
Base.:(==)(a::DiffEq, b::DiffEq) = isequal((a.x, a.t, a.n, a.rhs), (b.x, b.t, b.n, b.rhs))
get_args(eq::DiffEq) = Expression[eq.x, eq.t, eq.rhs]

struct DiffEqSystem <: AbstractSystem
Expand Down Expand Up @@ -79,7 +79,7 @@ end
function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction)
jac = calculate_jacobian(sys)

gam = Parameter(:gam)
gam = Variable(:gam; known = true)

W = LinearAlgebra.I - gam*jac
W = SMatrix{size(W,1),size(W,2)}(W)
Expand Down
9 changes: 5 additions & 4 deletions src/systems/diffeqs/first_order_transform.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
export ode_order_lowering


function lower_varname(var::Variable, idv, order)
order == 0 && return var
name = Symbol(var.name, :_, string(idv.name)^order)
return Variable(name, var.known, var.dependents)
return Variable(name, var.dependents; known = var.known)
end

function ode_order_lowering(sys::DiffEqSystem)
Expand All @@ -17,7 +20,7 @@ function ode_order_lowering(eqs, iv)
var, maxorder = eq.x, eq.n
if maxorder > get(var_order, var, 0)
var_order[var] = maxorder
var vars || push!(vars, var)
any(isequal(var), vars) || push!(vars, var)
end
var′ = lower_varname(eq.x, eq.t, eq.n - 1)
rhs′ = rename(eq.rhs)
Expand Down Expand Up @@ -45,5 +48,3 @@ function rename(O::Expression)
end
return Operation(O.op, rename.(O.args))
end

export ode_order_lowering
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,4 @@ is_derivative(::Any) = false

has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
has_dependent(x::Variable, t::Variable) =
t x.dependents || any(has_dependent(t), x.dependents)
any(isequal(t), x.dependents) || any(has_dependent(t), x.dependents)
Loading

0 comments on commit ab45092

Please sign in to comment.