diff --git a/src/mstructures.jl b/src/mstructures.jl index 41267fd..88fb870 100644 --- a/src/mstructures.jl +++ b/src/mstructures.jl @@ -28,6 +28,10 @@ When the product is not representable faithfully, """ abstract type MultiplicativeStructure end +struct UnsafeAddMul{M<:MultiplicativeStructure} + structure::M +end + function MA.operate_to!( res::SparseCoefficients, ms::MultiplicativeStructure, @@ -35,20 +39,20 @@ function MA.operate_to!( w::AbstractCoefficients, ) MA.operate!(zero, res) - res = fmac!(ms, res, v, w) + res = MA.operate_to!(res, UnsafeAddMul(ms), v, w) __canonicalize!(res) return res end -function fmac!( - ms::MultiplicativeStructure, +function MA.operate_to!( res::SparseCoefficients, + ms::UnsafeAddMul, v::AbstractCoefficients, w::AbstractCoefficients, ) for (kv, a) in pairs(v) for (kw, b) in pairs(w) - c = ms(kv, kw) # ::AbstractCoefficients + c = ms.structure(kv, kw) # ::AbstractCoefficients unsafe_append!(res, c => a * b) end end @@ -61,8 +65,12 @@ function MA.operate_to!( X::AbstractVector, Y::AbstractVector, ) - res = (res === X || res === Y) ? zero(res) : (res .= zero(eltype(res))) - fmac!(ms, res, X, Y) + if res === X || res === Y + res = zero(res) + else + MA.operate!(zero, res) + end + MA.operate_to!(res, UnsafeAddMul(ms), X, Y) res = __canonicalize!(res) return res end diff --git a/src/mtables.jl b/src/mtables.jl index b470847..e3359e4 100644 --- a/src/mtables.jl +++ b/src/mtables.jl @@ -81,9 +81,9 @@ function complete!(mt::MTable) return mt end -function fmac!( - ms::MTable, +function MA.operate_to!( res::AbstractSparseVector, + ms::UnsafeAddMul{<:MTable}, v::AbstractVector, w::AbstractVector, ) @@ -93,9 +93,9 @@ function fmac!( for (kv, a) in _nzpairs(v) for (kw, b) in _nzpairs(w) - c = ms(kv, kw) + c = ms.structure(kv, kw) for (k, v) in pairs(c) - push!(idcs, ms[k]) + push!(idcs, ms.structure[k]) push!(vals, v * a * b) end end