Skip to content

Commit

Permalink
Add canonicalize kw and fix latexify tests
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Jan 11, 2021
1 parent 0f9baea commit 4322277
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 18 deletions.
43 changes: 29 additions & 14 deletions src/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,21 +215,31 @@ function sparsehessian(O, vars::AbstractVector; simplify = true)
return H
end

function toexpr(O)
canonical, O = canonicalexpr(O)
canonical && return O
"""
toexpr(O::Union{Symbolics,Num,Equation,AbstractArray}; canonicalize=true) -> Expr
Convert `Symbolics` into `Expr`. If `canonicalize`, then we turn exprs like
`x^(-n)` into `inv(x)^n` to avoid type error when evaluating.
"""
function toexpr(O; canonicalize=true)
if canonicalize
canonical, O = canonicalexpr(O)
canonical && return O
else
!istree(O) && return O
end

op = operation(O)
args = arguments(O)
if op isa Differential
return :(derivative($(toexpr(args[1])),$(toexpr(op.x))))
return :(derivative($(toexpr(args[1]; canonicalize=canonicalize)),$(toexpr(op.x; canonicalize=canonicalize))))
elseif op isa Sym
isempty(args) && return nameof(op)
return Expr(:call, toexpr(op), toexpr.(args)...)
return Expr(:call, toexpr(op; canonicalize=canonicalize), toexpr(args; canonicalize=canonicalize)...)
end
return Expr(:call, op, toexpr.(args)...)
return Expr(:call, op, toexpr(args; canonicalize=canonicalize)...)
end
toexpr(s::Sym) = nameof(s)
toexpr(s::Sym; kw...) = nameof(s)

"""
canonicalexpr(O) -> (canonical::Bool, expr)
Expand All @@ -242,18 +252,23 @@ function canonicalexpr(O)
args = arguments(O)
if op === (^)
if length(args) == 2 && args[2] isa Number && args[2] < 0
expr = Expr(:call, ^, Expr(:call, inv, toexpr(args[1])), -args[2])
ex = toexpr(args[1])
if args[2] == -1
expr = Expr(:call, inv, ex)
else
expr = Expr(:call, ^, Expr(:call, inv, ex), -args[2])
end
return true, expr
end
end
return false, O
end

function toexpr(eq::Equation)
Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs))
function toexpr(eq::Equation; kw...)
Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...))
end

toexpr(eq::AbstractArray) = toexpr.(eq)
toexpr(x::Integer) = x
toexpr(x::AbstractFloat) = x
toexpr(x::Num) = toexpr(value(x))
toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs)
toexpr(x::Integer; kw...) = x
toexpr(x::AbstractFloat; kw...) = x
toexpr(x::Num; kw...) = toexpr(value(x); kw...)
4 changes: 2 additions & 2 deletions src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...)
# that latexify can deal with

rhs = getfield.(eqs, :rhs)
rhs = prettify_expr.(toexpr.(rhs))
rhs = prettify_expr.(toexpr(rhs; canonicalize=false))
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs]

lhs = getfield.(eqs, :lhs)
lhs = prettify_expr.(toexpr.(lhs))
lhs = prettify_expr.(toexpr(lhs; canonicalize=false))
lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs]
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs]
Expand Down
4 changes: 2 additions & 2 deletions test/latexify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ eqs = [D(x) ~ σ*(y-x)*D(x-y)/D(z),
# Latexify.@generate_test latexify(eqs)
@test latexify(eqs) == replace(
raw"\begin{align}
\frac{dx(t)}{dt} =& \frac{d\left(x\left( t \right) -1 \cdot y\left( t \right)\right)}{dt} \left( \mathrm{inv}\left( \frac{dz(t)}{dt} \right) \right)^{1} \sigma \left( y\left( t \right) -1 x\left( t \right) \right) \\
\frac{dx(t)}{dt} =& \frac{d\left(x\left( t \right) -1 \cdot y\left( t \right)\right)}{dt} \left( \frac{dz(t)}{dt} \right)^{-1} \sigma \left( y\left( t \right) -1 x\left( t \right) \right) \\
0 =& -1 y\left( t \right) + 0.1 x\left( t \right) \sigma \left( -1 z\left( t \right) + \rho \right) \\
\frac{dz(t)}{dt} =& x\left( t \right) \left( y\left( t \right) \right)^{\frac{2}{3}} -1 z\left( t \right) \beta
\end{align}
Expand Down Expand Up @@ -71,6 +71,6 @@ eqs = [D(x) ~ (1+cos(t))/(1+2*x)]

@test latexify(eqs) == replace(
raw"\begin{align}
\frac{dx(t)}{dt} =& \left( 1 + \cos\left( t \right) \right) \left( \mathrm{inv}\left( 1 + 2 x\left( t \right) \right) \right)^{1}
\frac{dx(t)}{dt} =& \left( 1 + \cos\left( t \right) \right) \left( 1 + 2 x\left( t \right) \right)^{-1}
\end{align}
", "\r\n"=>"\n")

0 comments on commit 4322277

Please sign in to comment.