Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

constructor-level simplification #154

Merged
merged 29 commits into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
313bc7f
begin constructor-level simplification
shashi Jan 4, 2021
d1d7da3
fixes
shashi Jan 4, 2021
c99d25a
div
shashi Jan 4, 2021
e23eba9
whitespace
shashi Jan 4, 2021
a8ee55b
Add `Number - SN` overload and eagerly evaluate coeff when adding a n…
YingboMa Jan 4, 2021
febf9fe
enable tree interface on the fast terms
shashi Jan 5, 2021
25e735f
Fix stuff
shashi Jan 5, 2021
b7928ec
Bring back better printing
shashi Jan 5, 2021
fe5a5f6
fixes 3
shashi Jan 5, 2021
f3b67b5
Merge remote-tracking branch 'origin/s/fast-terms' into s/fast-terms
shashi Jan 5, 2021
e22db85
fix tests and fix printing
shashi Jan 5, 2021
7f7e04f
Delete some more printing code
shashi Jan 5, 2021
f442971
fuzz: print problem
shashi Jan 5, 2021
acdf077
fix printing with Rational and Complex
shashi Jan 5, 2021
6ed0e9d
Cache sorted arguments in Add and Mul
shashi Jan 6, 2021
0f89ab1
fix arguments on Mul copy-paste
shashi Jan 6, 2021
7278252
fix (a+b)-a
shashi Jan 6, 2021
2ec353e
updates for MTK
shashi Jan 6, 2021
f621220
add 1-arg *
shashi Jan 6, 2021
20c63c4
show function of Term so that Differential(t) is visible
shashi Jan 6, 2021
acaf8e3
Fix overload ambiguity
YingboMa Jan 6, 2021
35a0bfa
print fix
shashi Jan 6, 2021
dafe2e6
Fix and test `-(::Add)`
YingboMa Jan 6, 2021
e5cb78d
configurable similarterm in Walk
shashi Jan 8, 2021
4247bd1
proper type promotion for Add
shashi Jan 9, 2021
277d6ea
proper type promotion for Mul and Pow
shashi Jan 9, 2021
e2a8e4c
fix up promotion and tests
shashi Jan 9, 2021
e2517f5
Move Add Mul Pow into types.jl
shashi Jan 9, 2021
c62beb7
add some docs
shashi Jan 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ using SymbolicUtils # hide

{{doc Term Term type}}

{{doc Add Add type}}

{{doc Mul Mul type}}

{{doc Pow Pow type}}

{{doc promote_symtype promote_symtype fn}}

## Interfacing
Expand Down
37 changes: 21 additions & 16 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ where appropriate -->

The main features are:

- Symbols (`Sym`s) carry type information. ([read more](#symbolic_expressions))
- Compound expressions composed of `Sym`s propagate type information. ([read more](#symbolic_expressions))
- A flexible [rule-based rewriting language](#rule-based_rewriting) allowing liberal use of user defined matchers and rewriters.
- Fast expressions
- A [combinator library](#composing-rewriters) for making rewriters.
- A [rule-based rewriting language](#rule-based_rewriting).
- Type promotion:
- Symbols (`Sym`s) carry type information. ([read more](#symbolic_expressions))
- Compound expressions composed of `Sym`s propagate type information. ([read more](#symbolic_expressions))
- Set of [simplification rules](#simplification). These can be remixed and extended for special purposes.


## Table of contents

\tableofcontents <!-- you can use \toc as well -->

## Symbolic expressions
## `Sym`s

First, let's use the `@syms` macro to create a few symbols.

Expand Down Expand Up @@ -66,17 +68,6 @@ expr1 + expr2
```
\out{expr}

### Simplified printing

Tip: you can set `SymbolicUtils.show_simplified[] = true` to enable simplification on printing, or call `SymbolicUtils.showraw(expr)` to display an expression without simplification.
In the REPL, if an expression was successfully simplified before printing, it will appear in yellow rather than white, as a visual cue that what you are looking at is not the exact datastructure.

```julia:showraw
using SymbolicUtils: showraw

showraw(expr1 + expr2)
```
\out{showraw}

**Function-like symbols**

Expand Down Expand Up @@ -106,6 +97,20 @@ g(2//5, g(1, β))

This works because `g` "returns" a `Real`.


## Expression interface

Symbolic expressions are of type `Term{T}`, `Add{T}`, `Mul{T}` or `Pow{T}` and denote some function call where one or more arguments are themselves such expressions or `Sym`s.

All the expression types support the following:

- `istree(x)` -- always returns `true` denoting, `x` is not a leaf node like Sym or a literal.
- `operation(x)` -- the function being called
- `arguments(x)` -- a vector of arguments
- `symtype(x)` -- the "inferred" type (`T`)

See more on the interface [here](/interface)

## Rule-based rewriting

Rewrite rules match and transform an expression. A rule is written using either the `@rule` macro or the `@acrule` macro.
Expand Down Expand Up @@ -151,7 +156,7 @@ Notice that there is a subexpression `(2 * w) + (2 * w)` that could be simplifie

### Predicates for matching

Matcher pattern may contain slot variables with attached predicates, written as `~x::f` where `f` is a function that takes a matched expression (a `Term` object a `Sym` or any Julia value that is in the expression tree) and returns a boolean value. Such a slot will be considered a match only if `f` returns true.
Matcher pattern may contain slot variables with attached predicates, written as `~x::f` where `f` is a function that takes a matched expression and returns a boolean value. Such a slot will be considered a match only if `f` returns true.

Similarly `~~x::g` is a way of attaching a predicate `g` to a segment variable. In the case of segment variables `g` gets a vector of 0 or more expressions and must return a boolean value. If the same slot or segment variable appears twice in the matcher pattern, then at most one of the occurance should have a predicate.

Expand Down
6 changes: 4 additions & 2 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ module SymbolicUtils

export @syms, term, showraw

# Sym, Term and other types
# Sym, Term,
# Add, Mul and Pow
using DataStructures
import Base: +, -, *, /, \, ^
include("types.jl")

# Methods on symbolic objects
Expand Down Expand Up @@ -32,7 +35,6 @@ include("matchers.jl")
# Convert to an efficient multi-variate polynomial representation
import AbstractAlgebra.Generic: MPoly, PolynomialRing, ZZ, exponent_vector
using AbstractAlgebra: ismonomial, symbols
using DataStructures
include("abstractalgebra.jl")

# Term ordering
Expand Down
55 changes: 26 additions & 29 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
besselj1, bessely0, bessely1, besselj, bessely, besseli,
besselk, hankelh1, hankelh2, polygamma, beta, logbeta

const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
const monadic = [deg2rad, rad2deg, transpose, conj, asind, log1p, acsch,
acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh,
abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, acotd,
asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10,
Expand All @@ -14,10 +14,9 @@ const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
airybiprime, besselj0, besselj1, bessely0, bessely1]

const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign,
const diadic = [max, min, hypot, atan, mod, rem, copysign,
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
polygamma, beta, logbeta]

const previously_declared_for = Set([])

# TODO: it's not possible to dispatch on the symtype! (only problem is Parameter{})
Expand All @@ -32,13 +31,17 @@ end
islike(a, T) = symtype(a) <: T

# TODO: keep domains tighter than this
function number_methods(T, rhs1, rhs2)
function number_methods(T, rhs1, rhs2, options=nothing)
exprs = []

skip_basics = !isnothing(options) ? options == :skipbasics : false
basic_monadic = [-, +]
basic_diadic = [+, -, *, /, \, ^]

rhs2 = :($assert_like(f, Number, a, b); $rhs2)
rhs1 = :($assert_like(f, Number, a); $rhs1)

for f in diadic
for f in (skip_basics ? diadic : vcat(basic_diadic, diadic))
for S in previously_declared_for
push!(exprs, quote
(f::$(typeof(f)))(a::$T, b::$S) = $rhs2
Expand All @@ -58,25 +61,38 @@ function number_methods(T, rhs1, rhs2)
push!(exprs, expr)
end

for f in monadic
for f in (skip_basics ? monadic : vcat(basic_monadic, monadic))
push!(exprs, :((f::$(typeof(f)))(a::$T) = $rhs1))
end
push!(exprs, :(push!($previously_declared_for, $T)))
Expr(:block, exprs...)
end

macro number_methods(T, rhs1, rhs2)
number_methods(T, rhs1, rhs2) |> esc
macro number_methods(T, rhs1, rhs2, options=nothing)
number_methods(T, rhs1, rhs2, options) |> esc
end

@number_methods(Sym, term(f, a), term(f, a, b))
@number_methods(Term, term(f, a), term(f, a, b))
@number_methods(Sym, term(f, a), term(f, a, b), skipbasics)
@number_methods(Term, term(f, a), term(f, a, b), skipbasics)
@number_methods(Add, term(f, a), term(f, a, b), skipbasics)
@number_methods(Mul, term(f, a), term(f, a, b), skipbasics)
@number_methods(Pow, term(f, a), term(f, a, b), skipbasics)

for f in diadic
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
end

for f in [+, -, *, \, /, ^]
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
end
for f in [+, -, *]
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T
end

promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
Base.rem2pi(x::Symbolic, mode::Base.RoundingMode) = term(rem2pi, x, mode)

Expand All @@ -93,25 +109,6 @@ rec_promote_symtype(f, x) = promote_symtype(f, x)
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)

# Variadic methods
for f in [+, *]

@eval (::$(typeof(f)))(x::Symbolic) = x

# single arg
@eval function (::$(typeof(f)))(x::Symbolic, w::Number...)
term($f, x,w...,
type=rec_promote_symtype($f, map(symtype, (x,w...))...))
end
@eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
end

Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
Expand Down
19 changes: 10 additions & 9 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,20 @@ function (rw::Fixpoint)(x)
return x
end

struct Walk{ord, C, threaded}
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
similarterm::F
end

using .Threads

function Postwalk(rw; threaded::Bool=false, thread_cutoff=100)
Walk{:post, typeof(rw), threaded}(rw, thread_cutoff)
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
end

function Prewalk(rw; threaded::Bool=false, thread_cutoff=100)
Walk{:pre, typeof(rw), threaded}(rw, thread_cutoff)
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
end

struct PassThrough{C}
Expand All @@ -128,22 +129,22 @@ end
(p::PassThrough)(x) = (y=p.rw(x); isnothing(y) ? x : y)

passthrough(x, default) = isnothing(x) ? default : x
function (p::Walk{ord, C, false})(x) where {ord, C}
function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
@assert ord === :pre || ord === :post
if istree(x)
if ord === :pre
x = p.rw(x)
end
if istree(x)
x = similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
x = p.similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
end
return ord === :post ? p.rw(x) : x
else
return p.rw(x)
end
end

function (p::Walk{ord, C, true})(x) where {ord, C}
function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
@assert ord === :pre || ord === :post
if istree(x)
if ord === :pre
Expand All @@ -158,7 +159,7 @@ function (p::Walk{ord, C, true})(x) where {ord, C}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = similarterm(x, operation(x), args)
t = p.similarterm(x, operation(x), args)
end
return ord === :post ? p.rw(t) : t
else
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ function (acr::ACRule)(term)
itr = acr.sets(eachindex(args), acr.arity)

for inds in itr
result = r(similarterm(term, f, @views args[inds]))
result = r(Term{T}(f, @views args[inds]))
if !isnothing(result)
# Assumption: inds are unique
length(args) == length(inds) && return result
Expand Down
Loading