Skip to content

Commit

Permalink
Change all unsorted_arguments to arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenszhu committed Jun 22, 2024
1 parent 4e26a2d commit 590083e
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ end
function cse_state!(state, t)
!iscall(t) && return t
state[t] = Base.get(state, t, 0) + 1
foreach(x->cse_state!(state, x), unsorted_arguments(t))
foreach(x->cse_state!(state, x), arguments(t))
end

function cse_block!(assignments, counter, names, name, state, x)
Expand All @@ -759,7 +759,7 @@ function cse_block!(assignments, counter, names, name, state, x)
return sym
end
elseif iscall(x)
args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x))
args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x))
if isterm(x)
return term(operation(x), args...)
else
Expand Down
10 changes: 5 additions & 5 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ is the function being called.
function operation end

"""
arguments(x)
sorted_arguments(x)
Get the arguments of `x`, must be defined if `iscall(x)` is `true`.
"""
function arguments end
function sorted_arguments end

"""
unsorted_arguments(x::T)
sorted_arguments(x::T)
If x is a term satisfying `iscall(x)` and your term type `T` provides
an optimized implementation for storing the arguments, this function can
be used to retrieve the arguments when the order of arguments does not matter
but the speed of the operation does.
"""
unsorted_arguments(x) = arguments(x)
arity(x) = length(unsorted_arguments(x))
function arguments end
arity(x) = length(arguments(x))

"""
metadata(x)
Expand Down
2 changes: 1 addition & 1 deletion src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function lexlt(degs1, degs2)
return false # they are equal
end

_arglen(a) = iscall(a) ? length(unsorted_arguments(a)) : 0
_arglen(a) = iscall(a) ? length(arguments(a)) : 0

function <(a::Tuple, b::Tuple)
for (x, y) in zip(a, b)
Expand Down
12 changes: 6 additions & 6 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ end

function add_with_div(x, flatten=true)
(!iscall(x) || operation(x) != (+)) && return x
aa = unsorted_arguments(x)
aa = arguments(x)
!any(a->isdiv(a), aa) && return x # no rewrite necessary

divs = filter(a->isdiv(a), aa)
Expand Down Expand Up @@ -385,12 +385,12 @@ end

function needs_div_rules(x)
(isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) ||
(iscall(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) ||
(iscall(x) && any(needs_div_rules, unsorted_arguments(x)))
(iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) ||
(iscall(x) && any(needs_div_rules, arguments(x)))
end

function has_div(x)
return isdiv(x) || (iscall(x) && any(has_div, unsorted_arguments(x)))
return isdiv(x) || (iscall(x) && any(has_div, arguments(x)))
end

flatten_pows(xs) = map(xs) do x
Expand Down Expand Up @@ -418,8 +418,8 @@ Has optimized processes for `Mul` and `Pow` terms.
function quick_cancel(d)
if ispow(d) && isdiv(d.base)
return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp))
elseif ismul(d) && any(isdiv, unsorted_arguments(d))
return prod(unsorted_arguments(d))
elseif ismul(d) && any(isdiv, arguments(d))
return prod(arguments(d))
elseif isdiv(d)
num, den = quick_cancel(d.num, d.den)
return Div(num, den)
Expand Down
2 changes: 1 addition & 1 deletion src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}

if iscall(x)
x = p.maketerm(x, operation(x), map(PassThrough(p),
unsorted_arguments(x)), metadata=metadata(x))
arguments(x)), metadata=metadata(x))
end

return ord === :post ? p.rw(x) : x
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ function (acr::ACRule)(term)
end

T = symtype(term)
args = unsorted_arguments(term)
args = arguments(term)

itr = acr.sets(eachindex(args), acr.arity)

Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ end

has_operation(x, op) = (iscall(x) && (operation(x) == op ||
any(a->has_operation(a, op),
unsorted_arguments(x))))
arguments(x))))

Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)
4 changes: 2 additions & 2 deletions src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ function substitute(expr, dict; fold=true)
op = substitute(operation(expr), dict; fold=fold)
if fold
canfold = !(op isa Symbolic)
args = map(unsorted_arguments(expr)) do x
args = map(arguments(expr)) do x
x′ = substitute(x, dict; fold=fold)
canfold = canfold && !(x′ isa Symbolic)
x′
end
canfold && return op(args...)
args
else
args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr))
args = map(x->substitute(x, dict, fold=fold), arguments(expr))
end

maketerm(typeof(expr),
Expand Down

0 comments on commit 590083e

Please sign in to comment.