Skip to content

Commit

Permalink
Multiplication of 3 or more AlgebraElement (#46)
Browse files Browse the repository at this point in the history
* Multiplication of 3 or more AlgebraElement

* No aggregate_constants

* Add test

* Fix

* Add test
  • Loading branch information
blegat authored Jun 14, 2024
1 parent ba9429b commit 8b26f89
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 44 deletions.
68 changes: 47 additions & 21 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
function _preallocate_output(X::AlgebraElement, a::Number, op)
T = MA.promote_operation(op, eltype(X), typeof(a))
return similar(X, T)
end
_coeff_type(X::AlgebraElement) = eltype(X)
_coeff_type(a) = typeof(a)

function _preallocate_output(X::AlgebraElement, Y::AlgebraElement, op)
T = MA.promote_operation(op, eltype(X), eltype(Y))
if coeffs(Y) isa DenseArray # what a hack :)
return similar(Y, T)
function _preallocate_output(op, args::Vararg{Any,N}) where {N}
T = MA.promote_operation(op, _coeff_type.(args)...)
if args[2] isa AlgebraElement && coeffs(args[2]) isa DenseArray # what a hack :)
return similar(args[2], T)
end
return similar(X, T)
return similar(args[1], T)
end

# module structure:
Expand All @@ -18,23 +16,26 @@ Base.:(/)(X::AlgebraElement, a::Number) = inv(a) * X
Base.:(//)(X::AlgebraElement, a::Number) = 1 // a * X

function Base.:-(X::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, -1, *), -, X)
return MA.operate_to!(_preallocate_output(*, X, -1), -, X)
end
function Base.:*(a::Number, X::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, a, *), *, X, a)
return MA.operate_to!(_preallocate_output(*, X, a), *, X, a)
end
function Base.:div(X::AlgebraElement, a::Number)
return MA.operate_to!(_preallocate_output(X, a, div), div, X, a)
return MA.operate_to!(_preallocate_output(div, X, a), div, X, a)
end

function Base.:+(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, Y, +), +, X, Y)
return MA.operate_to!(_preallocate_output(+, X, Y), +, X, Y)
end
function Base.:-(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, Y, -), -, X, Y)
return MA.operate_to!(_preallocate_output(-, X, Y), -, X, Y)
end
function Base.:*(X::AlgebraElement, Y::AlgebraElement)
return MA.operate_to!(_preallocate_output(X, Y, *), *, X, Y)
return MA.operate_to!(_preallocate_output(*, X, Y), *, X, Y)
end
function Base.:*(args::Vararg{AlgebraElement,N}) where {N}
return MA.operate_to!(_preallocate_output(*, args...), *, args...)
end
Base.:^(a::AlgebraElement, p::Integer) = Base.power_by_squaring(a, p)

Expand Down Expand Up @@ -99,11 +100,36 @@ end
function MA.operate_to!(
res::AlgebraElement,
::typeof(*),
X::AlgebraElement,
Y::AlgebraElement,
)
@assert parent(res) === parent(X) === parent(Y)
mstr = mstructure(basis(parent(res)))
MA.operate_to!(coeffs(res), mstr, coeffs(X), coeffs(Y))
args::Vararg{AlgebraElement,N},
) where {N}
for arg in args
if arg isa AlgebraElement
@assert parent(res) == parent(arg)
end
end
mstr = mstructure(basis(res))
MA.operate_to!(coeffs(res), mstr, coeffs.(args)...)
return res
end

function MA.operate!(
::UnsafeAddMul{typeof(*)},
res::AlgebraElement,
args::Vararg{AlgebraElement,N},
) where {N}
for arg in args
if arg isa AlgebraElement
@assert parent(res) == parent(arg)
end
end
mstr = mstructure(basis(res))
MA.operate!(UnsafeAddMul(mstr), coeffs(res), coeffs.(args)...)
return res
end

# TODO just push to internal vectors once canonical `does` not just
# call `dropzeros!` but also reorders
function unsafe_push!(a::SparseArrays.SparseVector, k, v)
a[k] = MA.add!!(a[k], v)
return a
end
3 changes: 1 addition & 2 deletions src/diracs_augmented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ function coeffs!(
MA.operate!(
UnsafeAddMul(*),
res,
v,
SparseCoefficients((target[Augmented(x)],), (1,)),
SparseCoefficients((target[Augmented(x)],), (v,)),
)
end
MA.operate!(canonical, res)
Expand Down
39 changes: 18 additions & 21 deletions src/mstructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,34 @@ struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}}
structure::M
end

function MA.operate_to!(res, ms::MultiplicativeStructure, v, w)
if res === v || res === w
function MA.operate_to!(res, ms::MultiplicativeStructure, args::Vararg{Any,N}) where {N}
if any(Base.Fix1(===, res), args)
throw(ArgumentError("No alias allowed"))
end
MA.operate!(zero, res)
MA.operate!(UnsafeAddMul(ms), res, v, w)
MA.operate!(UnsafeAddMul(ms), res, args...)
MA.operate!(canonical, res)
return res
end

function MA.operate!(
::UnsafeAddMul{typeof(*)},
mc::SparseCoefficients,
val,
c::AbstractCoefficients,
)
append!(mc.basis_elements, keys(c))
vals = values(c)
if vals isa AbstractVector
append!(mc.values, val .* vals)
else
append!(mc.values, val * collect(values(c)))
function MA.operate!(::UnsafeAddMul, res, c)
for (k, v) in nonzero_pairs(c)
unsafe_push!(res, k, v)
end
return mc
return res
end

function MA.operate!(ms::UnsafeAddMul, res, v, w)
for (kv, a) in nonzero_pairs(v)
for (kw, b) in nonzero_pairs(w)
c = ms.structure(kv, kw)
MA.operate!(UnsafeAddMul(*), res, a * b, c)
function MA.operate!(op::UnsafeAddMul, res, b, c, args::Vararg{Any, N}) where {N}
for (kb, vb) in nonzero_pairs(b)
for (kc, vc) in nonzero_pairs(c)
for (k, v) in nonzero_pairs(op.structure(kb, kc))
MA.operate!(
op,
res,
SparseCoefficients((_key(op.structure, k),), (vb * vc * v,)),
args...,
)
end
end
end
return res
Expand Down
3 changes: 3 additions & 0 deletions src/mtables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ function complete!(mt::MTable)
return mt
end

_key(_, k) = k
_key(mstr::MTable, k) = mstr[k]

function MA.operate!(
ms::UnsafeAddMul{<:MTable},
res::AbstractCoefficients,
Expand Down
6 changes: 6 additions & 0 deletions src/sparse_coeffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ function MA.operate!(::typeof(canonical), res::SparseCoefficients)
return MA.operate!(canonical, res, comparable(key_type(res)))
end

function unsafe_push!(res::SparseCoefficients, key, value)
push!(res.basis_elements, key)
push!(res.values, value)
return res
end

# `::C` is needed to force Julia specialize on the function type
# Otherwise, we get one allocation when we call `issorted`
# See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing
Expand Down
7 changes: 7 additions & 0 deletions test/monoid_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
@test iszero(zero(fRG))
@test zero(g) == zero(fRG)
@test iszero(0 * g)
@test isone(*(g, g, g))

@testset "Translations between bases" begin
Z = zero(RG)
Expand Down Expand Up @@ -131,6 +132,12 @@

@test @allocated(MA.operate_to!(d, *, a, 2)) == 0
@test d == 2a

MA.operate!(zero, d)
MA.operate!(SA.UnsafeAddMul(*), d, a, b, b)
MA.operate!(SA.canonical, SA.coeffs(d))
@test a * b^2 == *(a, b, b)
@test d == *(a, b, b)
end
end
end

0 comments on commit 8b26f89

Please sign in to comment.