From 432227721765ca98642950ddecddf37e329d80fe Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 11 Jan 2021 12:14:40 -0500 Subject: [PATCH] Add `canonicalize` kw and fix latexify tests --- src/direct.jl | 43 +++++++++++++++++++++++++++-------------- src/latexify_recipes.jl | 4 ++-- test/latexify.jl | 4 ++-- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/src/direct.jl b/src/direct.jl index dd996787bb..ebc31103d3 100644 --- a/src/direct.jl +++ b/src/direct.jl @@ -215,21 +215,31 @@ function sparsehessian(O, vars::AbstractVector; simplify = true) return H end -function toexpr(O) - canonical, O = canonicalexpr(O) - canonical && return O +""" + toexpr(O::Union{Symbolics,Num,Equation,AbstractArray}; canonicalize=true) -> Expr + +Convert `Symbolics` into `Expr`. If `canonicalize`, then we turn exprs like +`x^(-n)` into `inv(x)^n` to avoid type error when evaluating. +""" +function toexpr(O; canonicalize=true) + if canonicalize + canonical, O = canonicalexpr(O) + canonical && return O + else + !istree(O) && return O + end op = operation(O) args = arguments(O) if op isa Differential - return :(derivative($(toexpr(args[1])),$(toexpr(op.x)))) + return :(derivative($(toexpr(args[1]; canonicalize=canonicalize)),$(toexpr(op.x; canonicalize=canonicalize)))) elseif op isa Sym isempty(args) && return nameof(op) - return Expr(:call, toexpr(op), toexpr.(args)...) + return Expr(:call, toexpr(op; canonicalize=canonicalize), toexpr(args; canonicalize=canonicalize)...) end - return Expr(:call, op, toexpr.(args)...) + return Expr(:call, op, toexpr(args; canonicalize=canonicalize)...) end -toexpr(s::Sym) = nameof(s) +toexpr(s::Sym; kw...) = nameof(s) """ canonicalexpr(O) -> (canonical::Bool, expr) @@ -242,18 +252,23 @@ function canonicalexpr(O) args = arguments(O) if op === (^) if length(args) == 2 && args[2] isa Number && args[2] < 0 - expr = Expr(:call, ^, Expr(:call, inv, toexpr(args[1])), -args[2]) + ex = toexpr(args[1]) + if args[2] == -1 + expr = Expr(:call, inv, ex) + else + expr = Expr(:call, ^, Expr(:call, inv, ex), -args[2]) + end return true, expr end end return false, O end -function toexpr(eq::Equation) - Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs)) +function toexpr(eq::Equation; kw...) + Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...)) end -toexpr(eq::AbstractArray) = toexpr.(eq) -toexpr(x::Integer) = x -toexpr(x::AbstractFloat) = x -toexpr(x::Num) = toexpr(value(x)) +toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs) +toexpr(x::Integer; kw...) = x +toexpr(x::AbstractFloat; kw...) = x +toexpr(x::Num; kw...) = toexpr(value(x); kw...) diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 463c2836e5..6aefdf09ad 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...) # that latexify can deal with rhs = getfield.(eqs, :rhs) - rhs = prettify_expr.(toexpr.(rhs)) + rhs = prettify_expr.(toexpr(rhs; canonicalize=false)) rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs] rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs] rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs] lhs = getfield.(eqs, :lhs) - lhs = prettify_expr.(toexpr.(lhs)) + lhs = prettify_expr.(toexpr(lhs; canonicalize=false)) lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs] lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs] lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs] diff --git a/test/latexify.jl b/test/latexify.jl index c6fbbbb51e..307a3f89b4 100644 --- a/test/latexify.jl +++ b/test/latexify.jl @@ -30,7 +30,7 @@ eqs = [D(x) ~ σ*(y-x)*D(x-y)/D(z), # Latexify.@generate_test latexify(eqs) @test latexify(eqs) == replace( raw"\begin{align} -\frac{dx(t)}{dt} =& \frac{d\left(x\left( t \right) -1 \cdot y\left( t \right)\right)}{dt} \left( \mathrm{inv}\left( \frac{dz(t)}{dt} \right) \right)^{1} \sigma \left( y\left( t \right) -1 x\left( t \right) \right) \\ +\frac{dx(t)}{dt} =& \frac{d\left(x\left( t \right) -1 \cdot y\left( t \right)\right)}{dt} \left( \frac{dz(t)}{dt} \right)^{-1} \sigma \left( y\left( t \right) -1 x\left( t \right) \right) \\ 0 =& -1 y\left( t \right) + 0.1 x\left( t \right) \sigma \left( -1 z\left( t \right) + \rho \right) \\ \frac{dz(t)}{dt} =& x\left( t \right) \left( y\left( t \right) \right)^{\frac{2}{3}} -1 z\left( t \right) \beta \end{align} @@ -71,6 +71,6 @@ eqs = [D(x) ~ (1+cos(t))/(1+2*x)] @test latexify(eqs) == replace( raw"\begin{align} -\frac{dx(t)}{dt} =& \left( 1 + \cos\left( t \right) \right) \left( \mathrm{inv}\left( 1 + 2 x\left( t \right) \right) \right)^{1} +\frac{dx(t)}{dt} =& \left( 1 + \cos\left( t \right) \right) \left( 1 + 2 x\left( t \right) \right)^{-1} \end{align} ", "\r\n"=>"\n")