Skip to content

Commit

Permalink
Merge c619ea9 into d009ca8
Browse files Browse the repository at this point in the history
  • Loading branch information
0x0f0f0f authored Jun 26, 2024
2 parents d009ca8 + c619ea9 commit 79dbef7
Show file tree
Hide file tree
Showing 17 changed files with 132 additions and 215 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
TermInterface = "0.4"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
julia = "1.3"
Expand Down
16 changes: 15 additions & 1 deletion docs/src/manual/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym

## SymbolicUtils.jl only methods

`promote_symtype(f, arg_symtypes...)`
### `symtype(x)`

Returns the
[numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types)
of `x`. By default this is just `typeof(x)`.
Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules
specific to numbers (such as commutativity of multiplication). Or such
rules that may be implemented in the future.

### `issym(x)`

Returns `true` if `x` is a `Sym`. If `true`, `nameof` must be defined
on `x` and must return a `Symbol`.

### `promote_symtype(f, arg_symtypes...)`

Returns the appropriate output type of applying `f` on arguments of type `arg_symtypes`.
2 changes: 1 addition & 1 deletion docs/src/manual/representation.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Performance of symbolic simplification depends on the datastructures used to rep

The most basic term representation simply holds a function call and stores the function and the arguments it is called with. This is done by the `Term` type in SymbolicUtils. Functions that aren't commutative or associative, such as `sin` or `hypot` are stored as `Term`s. Commutative and associative operations like `+`, `*`, and their supporting operations like `-`, `/` and `^`, when used on terms of type `<:Number`, stand to gain from the use of more efficient datastrucutres.

All term representations must support `operation` and `arguments` functions. And they must define `istree` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.)
All term representations must support `operation` and `arguments` functions. And they must define `iscall` and `isexpr` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.)


### Preliminary representation of arithmetic
Expand Down
4 changes: 2 additions & 2 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ using SymbolicIndexingInterface
import Base: +, -, *, /, //, \, ^, ImmutableDict
using ConstructionBase
using TermInterface
import TermInterface: iscall, isexpr, issym, symtype, head, children,
operation, arguments, metadata, maketerm
import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments

const istree = iscall
Base.@deprecate_binding istree iscall
Expand Down
20 changes: 9 additions & 11 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
symtype, similarterm, sorted_arguments, metadata, isterm, term, maketerm
symtype, sorted_arguments, metadata, isterm, term, maketerm

##== state management ==##

Expand Down Expand Up @@ -115,7 +115,7 @@ function function_to_expr(op, O, st)
(get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing
name = nameof(op)
fun = GlobalRef(NaNMath, name)
args = map(Base.Fix2(toexpr, st), arguments(O))
args = map(Base.Fix2(toexpr, st), sorted_arguments(O))
expr = Expr(:call, fun)
append!(expr.args, args)
return expr
Expand All @@ -138,7 +138,7 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
end

function function_to_expr(::typeof(^), O, st)
args = arguments(O)
args = sorted_arguments(O)
if length(args) == 2 && args[2] isa Real && args[2] < 0
ex = args[1]
if args[2] == -1
Expand All @@ -151,7 +151,7 @@ function function_to_expr(::typeof(^), O, st)
end

function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
args = arguments(O)
args = sorted_arguments(O)
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
end

Expand Down Expand Up @@ -183,7 +183,7 @@ function toexpr(O, st)
return expr′
else
!iscall(O) && return O
args = arguments(O)
args = sorted_arguments(O)
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
end
end
Expand Down Expand Up @@ -693,8 +693,8 @@ end
function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
args = map(Base.Fix1(_cse!, mem), arguments(expr))
t = similarterm(expr, op, args)
args = map(Base.Fix1(_cse!, mem), sorted_arguments(expr))
t = maketerm(typeof(expr), op, args, nothing)

v, dict = mem
update! = let v=v, t=t
Expand All @@ -716,7 +716,7 @@ end

function _cse(exprs::AbstractArray)
letblock = cse(Term{Any}(tuple, vec(exprs)))
letblock.pairs, reshape(arguments(letblock.body), size(exprs))
letblock.pairs, reshape(sorted_arguments(letblock.body), size(exprs))
end

function cse(x::MakeArray)
Expand Down Expand Up @@ -763,9 +763,7 @@ function cse_block!(assignments, counter, names, name, state, x)
if isterm(x)
return term(operation(x), args...)
else
return maketerm(typeof(x), operation(x),
args, symtype(x),
metadata(x))
return maketerm(typeof(x), operation(x), args, metadata(x))
end
else
return x
Expand Down
84 changes: 0 additions & 84 deletions src/interface.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function matcher(segment::Segment)
end

function term_matcher(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
matchers = (matcher(operation(term)), map(matcher, sorted_arguments(term))...,)
function term_matcher(success, data, bindings)

!islist(data) && return nothing
Expand Down
2 changes: 1 addition & 1 deletion src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function <ₑ(a::BasicSymbolic, b::BasicSymbolic)
bw = monomial_lt(db, da)
if fw === bw && !isequal(a, b)
if _arglen(a) == _arglen(b)
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
return (operation(a), sorted_arguments(a)...,) <ₑ (operation(b), sorted_arguments(b)...,)
else
return _arglen(a) < _arglen(b)
end
Expand Down
27 changes: 12 additions & 15 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
end

op = operation(x)
args = arguments(x)
args = sorted_arguments(x)

local_polyize(y) = polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse)

Expand All @@ -121,7 +121,6 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
maketerm(typeof(x),
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args),
symtype(x),
metadata(x))
else
x
Expand Down Expand Up @@ -176,18 +175,18 @@ isexpr(x::PolyForm) = true
iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true

function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata)
basicsymbolic(t, f, args, symtype, metadata)
function maketerm(t::Type{<:PolyForm}, f, args, metadata)
# TODO: this looks uncovered.
basicsymbolic(f, args, nothing, metadata)
end
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)},
args, symtype, metadata)
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata)
f(args...)
end

head(::PolyForm) = PolyForm
operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+)

function arguments(x::PolyForm{T}) where {T}
function TermInterface.arguments(x::PolyForm{T}) where {T}

function is_var(v)
MP.nterms(v) == 1 &&
Expand Down Expand Up @@ -231,10 +230,7 @@ function arguments(x::PolyForm{T}) where {T}
PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts]
end
end

sorted_arguments(x::PolyForm) = arguments(x)

children(x::PolyForm) = [operation(x); arguments(x)]
children(x::PolyForm) = arguments(x)

Base.show(io::IO, x::PolyForm) = show_term(io, x)

Expand All @@ -255,7 +251,7 @@ function unpolyize(x)
# we need a special makterm here because the default one used in Postwalk will call
# promote_symtype to get the new type, but we just want to forward that in case
# promote_symtype is not defined for some of the expressions here.
Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x)
Postwalk(identity, maketerm=(T,f,args,m) -> maketerm(T, f, args, m))(x)
end

function toterm(x::PolyForm)
Expand Down Expand Up @@ -307,7 +303,8 @@ function add_divs(x, y)
end
end

function frac_maketerm(T, f, args, stype, metadata)
function frac_maketerm(T, f, args, metadata)
# TODO add stype to T?
if f in (*, /, \, +, -)
f(args...)
elseif f == (^)
Expand All @@ -317,7 +314,7 @@ function frac_maketerm(T, f, args, stype, metadata)
args[1]^args[2]
end
else
maketerm(T, f, args, stype, metadata)
maketerm(T, f, args, metadata)
end
end

Expand Down Expand Up @@ -394,7 +391,7 @@ function has_div(x)
end

flatten_pows(xs) = map(xs) do x
ispow(x) ? Iterators.repeated(arguments(x)...) : (x,)
ispow(x) ? Iterators.repeated(sorted_arguments(x)...) : (x,)
end |> Iterators.flatten |> a->collect(Any,a)

coefftype(x::PolyForm) = coefftype(x.p)
Expand Down
38 changes: 11 additions & 27 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,7 @@ end
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function
# that behaves like `similarterm` here, we use `compatmaker` to wrap
# maketerm-like input to do this, with a warning if similarterm provided
# we need this workaround to deprecate because similarterm takes value
# but maketerm only knows the type.
maketerm::F
end

function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded}
Expand All @@ -183,25 +179,13 @@ end

using .Threads

function compatmaker(similarterm, maketerm)
# XXX: delete this and only use maketerm in a future release.
if similarterm isa Nothing
function (x, f, args, type=_promote_symtype(f, args); metadata)
maketerm(typeof(x), f, args, type, metadata)
end
else
Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm)
similarterm
end
end
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)

function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm)
Walk{:post, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm)
end

function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm)
Walk{:pre, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm)
end

struct PassThrough{C}
Expand All @@ -220,8 +204,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
end

if iscall(x)
x = p.maketerm(x, operation(x), map(PassThrough(p),
arguments(x)), metadata=metadata(x))
x = p.maketerm(typeof(x), operation(x), map(PassThrough(p),
arguments(x)), metadata(x))
end

return ord === :post ? p.rw(x) : x
Expand All @@ -237,15 +221,15 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
x = p.rw(x)
end
if iscall(x)
_args = map(arguments(x)) do arg
_args = map(sorted_arguments(x)) do arg
if node_count(arg) > p.thread_cutoff
Threads.@spawn p(arg)
else
p(arg)
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = p.maketerm(x, operation(x), args, metadata=metadata(x))
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, sorted_arguments(x))
t = p.maketerm(typeof(x), operation(x), args, metadata(x))
end
return ord === :post ? p.rw(t) : t
else
Expand Down
Loading

0 comments on commit 79dbef7

Please sign in to comment.