Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

constructor-level simplification #154

Merged
merged 29 commits into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
313bc7f
begin constructor-level simplification
shashi Jan 4, 2021
d1d7da3
fixes
shashi Jan 4, 2021
c99d25a
div
shashi Jan 4, 2021
e23eba9
whitespace
shashi Jan 4, 2021
a8ee55b
Add `Number - SN` overload and eagerly evaluate coeff when adding a n…
YingboMa Jan 4, 2021
febf9fe
enable tree interface on the fast terms
shashi Jan 5, 2021
25e735f
Fix stuff
shashi Jan 5, 2021
b7928ec
Bring back better printing
shashi Jan 5, 2021
fe5a5f6
fixes 3
shashi Jan 5, 2021
f3b67b5
Merge remote-tracking branch 'origin/s/fast-terms' into s/fast-terms
shashi Jan 5, 2021
e22db85
fix tests and fix printing
shashi Jan 5, 2021
7f7e04f
Delete some more printing code
shashi Jan 5, 2021
f442971
fuzz: print problem
shashi Jan 5, 2021
acdf077
fix printing with Rational and Complex
shashi Jan 5, 2021
6ed0e9d
Cache sorted arguments in Add and Mul
shashi Jan 6, 2021
0f89ab1
fix arguments on Mul copy-paste
shashi Jan 6, 2021
7278252
fix (a+b)-a
shashi Jan 6, 2021
2ec353e
updates for MTK
shashi Jan 6, 2021
f621220
add 1-arg *
shashi Jan 6, 2021
20c63c4
show function of Term so that Differential(t) is visible
shashi Jan 6, 2021
acaf8e3
Fix overload ambiguity
YingboMa Jan 6, 2021
35a0bfa
print fix
shashi Jan 6, 2021
dafe2e6
Fix and test `-(::Add)`
YingboMa Jan 6, 2021
e5cb78d
configurable similarterm in Walk
shashi Jan 8, 2021
4247bd1
proper type promotion for Add
shashi Jan 9, 2021
277d6ea
proper type promotion for Mul and Pow
shashi Jan 9, 2021
e2a8e4c
fix up promotion and tests
shashi Jan 9, 2021
e2517f5
Move Add Mul Pow into types.jl
shashi Jan 9, 2021
c62beb7
add some docs
shashi Jan 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ export @syms, term, showraw
# Sym, Term and other types
include("types.jl")

# Add, Mul and Pow
include("fast-terms.jl")

# Methods on symbolic objects
using SpecialFunctions, NaNMath
export cond
Expand Down
308 changes: 308 additions & 0 deletions src/fast-terms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
import Base: +, -, *, /, \, ^

const SN = Symbolic{<:Number}
"""
Add(coeff, dict)

Represents coeff + (key1 * val1) + (key2 * val2) + ...

where coeff and the vals are non-symbolic numbers.

"""
struct Add{X, T<:Number, D} <: Symbolic{X}
coeff::T
dict::D
end

function Add(coeff, dict)
if isempty(dict)
return coeff
elseif _iszero(coeff) && length(dict) == 1
k,v = first(dict)
return _isone(v) ? k : makemul(1, v, k)
end
Add{Number, typeof(coeff), typeof(dict)}(coeff,dict)
end

symtype(a::Add{X}) where {X} = X


istree(a::Add) = true

operation(a::Add) = +

function arguments(a::Add)
args = sort!([v*k for (k,v) in a.dict], lt=<ₑ)
iszero(a.coeff) ? args : vcat(a.coeff, args)
end

Base.hash(a::Add, u::UInt64) = hash(a.coeff, hash(a.dict, u))

Base.isequal(a::Add, b::Add) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)

function Base.show(io::IO, a::Add)
print_coeff = !iszero(a.coeff)
print_coeff && print(io, a.coeff)

for (i, (k, v)) in enumerate(a.dict)
if (i == 1 && print_coeff) || i != 1
print(io, " + ")
end
if isone(v)
print(io, k)
else
print(io, v, k)
end
end
end

"""
makeadd(sign, coeff::Number, xs...)

Any Muls inside an Add should always have a coeff of 1
and the key (in Add) should instead be used to store the actual coefficient
"""
function makeadd(sign, coeff, xs...)
d = Dict{Any, Number}()
for x in xs
if x isa Number
coeff += x
continue
end
if x isa Mul
k = Mul(1, x.dict)
v = sign * x.coeff + get(d, k, 0)
else
k = x
v = sign + get(d, x, 0)
end
if iszero(v)
delete!(d, k)
else
d[k] = v
end
end
Add(coeff, d)
end

function +(a::SN, b::SN)
if a isa Add
c = makeadd(1, 0, b)
return c isa Add ? a + c : Add(a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1)))
elseif b isa Add
return b + a
end
makeadd(1, 0, a, b)
end

+(a::Number, b::SN) = makeadd(1, a, b)

+(a::SN, b::Number) = makeadd(1, b, a)

+(a::SN) = a

+(a::Add, b::Add) = Add(a.coeff + b.coeff, _merge(+, a.dict, b.dict, filter=_iszero))

+(a::Number, b::Add) = iszero(a) ? b : Add(a + b.coeff, b.dict)

+(b::Add, a::Number) = iszero(a) ? b : Add(a + b.coeff, b.dict)

-(a::Add) = Add(-a.coeff, mapvalues(-, a.dict))

-(a::SN) = makeadd(-1, 0, a)

-(a::Add, b::Add) = Add(a.coeff - b.coeff, _merge(-, a.dict, b.dict, filter=_iszero))

-(a::SN, b::SN) = a + (-b)

-(a::Number, b::SN) = a + (-b)

-(a::SN, b::Number) = a + (-b)

"""
Mul(coeff, dict)

Represents coeff * (key1 ^ val1) * (key2 ^ val2) * ....

where coeff is a non-symbolic number.
"""
struct Mul{X, T<:Number, D} <: Symbolic{X}
coeff::T
dict::D
end

function Mul(a,b)
isempty(b) && return a
if _isone(a) && length(b) == 1
pair = first(b)
if _isone(last(pair)) # first value
return first(pair)
else
return Pow(first(pair), last(pair))
end
else
Mul{Number, typeof(a), typeof(b)}(a,b)
end
end

symtype(a::Mul{X}) where {X} = X

istree(a::Mul) = true

operation(a::Mul) = *

function arguments(a::Mul)
args = sort!([k^v for (k,v) in a.dict], lt=<ₑ)
isone(a.coeff) ? args : vcat(a.coeff, args)
end

Base.hash(m::Mul, u::UInt64) = hash(m.coeff, hash(m.dict, u))

Base.isequal(a::Mul, b::Mul) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)

function Base.show(io::IO, a::Mul)
print_coeff = !isone(a.coeff)
print_coeff && print(io, a.coeff)

for (i, v) in enumerate(arguments(a))
i == 1 && continue
i > 2 && print(io, "*")
print(io, v)
end
end

"""
makemul(xs...)
"""
function makemul(sign, coeff, xs...; d=Dict{Any, Number}())
for x in xs
if x isa Pow && x.exp isa Number
d[x.base] = sign * x.exp + get(d, x.base, 0)
elseif x isa Mul
coeff *= x.coeff
dict = isone(sign) ? x.dict : mapvalues((_,v)->sign*v, x.dict)
d = _merge(+, d, dict, filter=_iszero)
else
k = x
v = sign + get(d, x, 0)
if _iszero(v)
delete!(d, k)
else
d[k] = v
end
end
end
Mul(coeff, d)
end

*(a::SN, b::SN) = makemul(1, 1, a, b)

*(a::Mul, b::Mul) = Mul(a.coeff * b.coeff, _merge(+, a.dict, b.dict, filter=_iszero))

*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : makemul(1,a, b)

*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : makemul(1,a, b)

function /(a::Union{SN,Number}, b::SN)
a * makemul(-1, 1, b)
end

\(a::SN, b::Union{Number, SN}) = b / a

\(a::Number, b::SN) = b / a

/(a::SN, b::Number) = inv(b) * a

"""
Pow(base, exp)

Represents base^exp, a lighter version of Mul(1, Dict(base=>exp))
"""
struct Pow{X, B, E} <: Symbolic{X}
base::B
exp::E
end

function Pow(a,b)
_iszero(b) && return 1
_isone(b) && return a
Pow{Number, typeof(a), typeof(b)}(a,b)
end

symtype(a::Pow{X}) where {X} = X

istree(a::Pow) = true

operation(a::Pow) = ^

arguments(a::Pow) = [a.base, a.exp]

Base.hash(p::Pow, u::UInt) = hash(p.exp, hash(p.base, u))

Base.isequal(p::Pow, b::Pow) = isequal(p.base, b.base) && isequal(p.exp, b.exp)

function Base.show(io::IO, p::Pow)
k, v = p.base, p.exp
if !(k isa Sym)
print(io, "(", k, ")^", v)
else
print(io, k, "^", v)
end
end

^(a::SN, b) = Pow(a, b)

^(a::SN, b::SN) = Pow(a, b)

^(a::Number, b::SN) = Pow(a, b)

function ^(a::Mul, b::Number)
Mul(a.coeff ^ b, mapvalues((k, v) -> b*v, a.dict))
end

function *(a::Mul, b::Pow)
if b.exp isa Number
Mul(a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=_iszero))
else
Mul(a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1), filter=_iszero))
end
end

*(a::Pow, b::Mul) = b * a

function _merge(f, d, others...; filter=x->false)
acc = copy(d)
for other in others
for (k, v) in other
if haskey(acc, k)
v = f(acc[k], v)
end
if filter(v)
delete!(acc, k)
else
acc[k] = v
end
end
end
acc
end

function mapvalues(f, d1::Dict)
d = copy(d1)
for (k, v) in d
d[k] = f(k, v)
end
d
end

function similarterm(p::Union{Mul, Add, Pow}, f, args)
if f === (+)
makeadd(1, 0, args...)
elseif f == (*)
makemul(1, 1, args...)
elseif f == (^)
Pow(args...)
else
f(args...)
end
end
31 changes: 10 additions & 21 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
besselj1, bessely0, bessely1, besselj, bessely, besseli,
besselk, hankelh1, hankelh2, polygamma, beta, logbeta

const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
const monadic = [deg2rad, rad2deg, transpose, conj, asind, log1p, acsch,
acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh,
abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, acotd,
asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10,
Expand All @@ -14,7 +14,7 @@ const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
airybiprime, besselj0, besselj1, bessely0, bessely1]

const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign,
const diadic = [max, min, hypot, atan, mod, rem, copysign,
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
polygamma, beta, logbeta]

Expand Down Expand Up @@ -71,12 +71,20 @@ end

@number_methods(Sym, term(f, a), term(f, a, b))
@number_methods(Term, term(f, a), term(f, a, b))
@number_methods(Symbolic{<:Number}, term(f, a), term(f, a, b))

for f in diadic
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
end

for f in [+, *, \, /, ^]
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
end

promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
Base.rem2pi(x::Symbolic, mode::Base.RoundingMode) = term(rem2pi, x, mode)

Expand All @@ -93,25 +101,6 @@ rec_promote_symtype(f, x) = promote_symtype(f, x)
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)

# Variadic methods
for f in [+, *]

@eval (::$(typeof(f)))(x::Symbolic) = x

# single arg
@eval function (::$(typeof(f)))(x::Symbolic, w::Number...)
term($f, x,w...,
type=rec_promote_symtype($f, map(symtype, (x,w...))...))
end
@eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
end

Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ function (acr::ACRule)(term)
itr = acr.sets(eachindex(args), acr.arity)

for inds in itr
result = r(similarterm(term, f, @views args[inds]))
result = r(Term{T}(f, @views args[inds]))
if !isnothing(result)
# Assumption: inds are unique
length(args) == length(inds) && return result
Expand Down
Loading