Skip to content

Commit

Permalink
More type stability around simplify (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgoettgens authored Oct 7, 2024
1 parent 403feb3 commit ef7d3a3
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/DeformationBases/DeformBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ function normalize_default(m::DeformationMap{T}) where {T <: SmashProductLieElem
if nz_index === nothing
return m
end
lc = leading_coefficient(m[CartesianIndex(nz_index[2], nz_index[1])].alg_elem)
lc = leading_coefficient(data(m[CartesianIndex(nz_index[2], nz_index[1])]))
m = map(e -> 1 // lc * e, m)
end
1 change: 1 addition & 0 deletions src/PBWDeformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import Oscar.AbstractAlgebra: parent_type

import Oscar: base_lie_algebra
import Oscar: comm
import Oscar: data
import Oscar: edges
import Oscar: n_edges
import Oscar: neighbors
Expand Down
59 changes: 33 additions & 26 deletions src/SmashProductLie.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ underlying_algebra(
Sp::SmashProductLie{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}} = Sp.alg::free_associative_algebra_type(C)

data(e::SmashProductLieElem{C, LieC, LieT}) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}} =
e.alg_elem::elem_type(free_associative_algebra_type(C))

ngens(Sp::SmashProductLie) = ngens(underlying_algebra(Sp))
function ngens(Sp::SmashProductLie, part::Symbol)
part == :L && return dim(base_lie_algebra(Sp))
Expand Down Expand Up @@ -59,19 +62,19 @@ function zero(Sp::SmashProductLie)
end

function iszero(e::SmashProductLieElem)
return iszero(simplify(e).alg_elem)
return iszero(data(simplify(e)))
end

function one(Sp::SmashProductLie)
return Sp(one(underlying_algebra(Sp)))
end

function isone(e::SmashProductLieElem)
return isone(simplify(e).alg_elem)
return isone(data(simplify(e)))
end

function Base.deepcopy_internal(e::SmashProductLieElem, dict::IdDict)
return SmashProductLieElem(parent(e), deepcopy_internal(e.alg_elem, dict); simplified=e.simplified)
return SmashProductLieElem(parent(e), deepcopy_internal(data(e), dict); simplified=e.simplified)
end

function check_parent(
Expand All @@ -89,7 +92,7 @@ function AbstractAlgebra.promote_rule(
end

function change_base_ring(R::Ring, e::SmashProductLieElem{C}; parent::SmashProductLie=smash_product(R, base_lie_algebra(parent(e)), base_module(parent(e)))) where C
return parent(change_base_ring(R, e.alg_elem; parent=underlying_algebra(parent)))
return parent(change_base_ring(R, data(e); parent=underlying_algebra(parent)))
end

###############################################################################
Expand All @@ -112,7 +115,7 @@ end


function show(io::IO, e::SmashProductLieElem)
show(io, e.alg_elem)
show(io, data(e))
end


Expand Down Expand Up @@ -165,57 +168,57 @@ end
###############################################################################

function Base.:-(e::SmashProductLieElem)
return parent(e)(-e.alg_elem)
return parent(e)(-data(e))
end

function Base.:+(e1::SmashProductLieElem, e2::SmashProductLieElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem + e2.alg_elem)
return parent(e1)(data(e1) + data(e2))
end

function Base.:-(e1::SmashProductLieElem, e2::SmashProductLieElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem - e2.alg_elem)
return parent(e1)(data(e1) - data(e2))
end

function Base.:*(e1::SmashProductLieElem, e2::SmashProductLieElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem * e2.alg_elem)
return parent(e1)(data(e1) * data(e2))
end

function Base.:*(e::SmashProductLieElem{C}, c::C) where {C <: RingElem}
coefficient_ring(e) != parent(c) && error("Incompatible rings.")
return parent(e)(e.alg_elem * c)
return parent(e)(data(e) * c)
end

function Base.:*(e::SmashProductLieElem, c::U) where {U <: Union{Rational, IntegerUnion}}
return parent(e)(e.alg_elem * c)
return parent(e)(data(e) * c)
end

function Base.:*(e::SmashProductLieElem{ZZRingElem}, c::ZZRingElem)
return parent(e)(e.alg_elem * c)
return parent(e)(data(e) * c)
end

function Base.:*(c::C, e::SmashProductLieElem{C}) where {C <: RingElem}
coefficient_ring(e) != parent(c) && error("Incompatible rings.")
return parent(e)(c * e.alg_elem)
return parent(e)(c * data(e))
end

function Base.:*(c::U, e::SmashProductLieElem) where {U <: Union{Rational, IntegerUnion}}
return parent(e)(c * e.alg_elem)
return parent(e)(c * data(e))
end

function Base.:*(c::ZZRingElem, e::SmashProductLieElem{ZZRingElem})
return parent(e)(c * e.alg_elem)
return parent(e)(c * data(e))
end

function Base.:^(e::SmashProductLieElem, n::Int)
return parent(e)(e.alg_elem^n)
return parent(e)(data(e)^n)
end

function comm(e1::SmashProductLieElem, e2::SmashProductLieElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem * e2.alg_elem - e2.alg_elem * e1.alg_elem)
return parent(e1)(data(e1) * data(e2) - data(e2) * data(e1))
end

###############################################################################
Expand All @@ -228,14 +231,14 @@ function Base.:(==)(
e1::SmashProductLieElem{C, LieC, LieT},
e2::SmashProductLieElem{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}}
return parent(e1) === parent(e2) && simplify(e1).alg_elem == simplify(e2).alg_elem
return parent(e1) === parent(e2) && data(simplify(e1)) == data(simplify(e2))
end

function Base.hash(e::SmashProductLieElem, h::UInt)
e = simplify(e)
b = 0xdcc11ff793ca4ada % UInt
h = hash(parent(e), h)
h = hash(e.alg_elem, h)
h = hash(data(e), h)
return xor(h, b)
end

Expand All @@ -245,15 +248,17 @@ end
#
###############################################################################

function simplify(e::SmashProductLieElem)
function simplify(
e::SmashProductLieElem{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}}
e.simplified && return e
e.alg_elem = _normal_form(e.alg_elem, parent(e).rels)
e.alg_elem =
_normal_form(data(e), parent(e).rels::Matrix{Union{Nothing, elem_type(free_associative_algebra_type(C))}})
e.simplified = true
return e
end

function _normal_form(a::FreeAssAlgElem{C}, rels::Matrix{Union{Nothing, FreeAssAlgElem{C}}}) where {C <: RingElem}
a = deepcopy(a)
function _normal_form(a::F, rels::Matrix{Union{Nothing, F}}) where {C <: RingElem, F <: FreeAssAlgElem{C}}
result = zero(parent(a))
CR = coefficient_ring(a)
A = parent(a)
Expand All @@ -265,9 +270,11 @@ function _normal_form(a::FreeAssAlgElem{C}, rels::Matrix{Union{Nothing, FreeAssA

changed = false
for i in 1:length(exp)-1
if exp[i] > exp[i+1] && !isnothing(rels[exp[i], exp[i+1]])
exp[i] > exp[i+1] || continue
rel = rels[exp[i], exp[i+1]]
if !isnothing(rel)
changed = true
a += A([c], [exp[1:i-1]]) * rels[exp[i], exp[i+1]] * A([one(CR)], [exp[i+2:end]])
a += A([c], [exp[1:i-1]]) * rel * A([one(CR)], [exp[i+2:end]])
break
end
end
Expand Down Expand Up @@ -321,7 +328,7 @@ function smash_product(R::Ring, L::LieAlgebra{C}, V::LieAlgebraModule{C}) where
f_basisL = [gen(f_alg, i) for i in 1:dimL]
f_basisV = [gen(f_alg, dimL + i) for i in 1:dimV]

rels = Matrix{Union{Nothing, FreeAssAlgElem{elem_type(R)}}}(nothing, dimL + dimV, dimL + dimV)
rels = Matrix{Union{Nothing, elem_type(free_associative_algebra_type(R))}}(nothing, dimL + dimV, dimL + dimV)

for (i, xi) in enumerate(basis(L)), (j, xj) in enumerate(basis(L))
commutator =
Expand Down
60 changes: 34 additions & 26 deletions src/SmashProductLieDeform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ base_module(
D::SmashProductLieDeform{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}} = base_module(D.sp)::LieAlgebraModule{LieC}

underlying_algebra(D::SmashProductLieDeform) = underlying_algebra(D.sp) # TODO: create new algebra for D
underlying_algebra(
D::SmashProductLieDeform{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}} =
underlying_algebra(D.sp)::free_associative_algebra_type(C) # TODO: create new algebra for D

data(e::SmashProductLieDeformElem{C, LieC, LieT}) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}} =
e.alg_elem::elem_type(free_associative_algebra_type(C))

ngens(D::SmashProductLieDeform) = ngens(underlying_algebra(D))
function ngens(D::SmashProductLieDeform, part::Symbol)
Expand Down Expand Up @@ -57,19 +63,19 @@ function zero(D::SmashProductLieDeform)
end

function iszero(e::SmashProductLieDeformElem)
return iszero(simplify(e).alg_elem)
return iszero(data(simplify(e)))
end

function one(D::SmashProductLieDeform)
return D(one(underlying_algebra(D)))
end

function isone(e::SmashProductLieDeformElem)
return isone(simplify(e).alg_elem)
return isone(data(simplify(e)))
end

function Base.deepcopy_internal(e::SmashProductLieDeformElem, dict::IdDict)
return SmashProductLieDeformElem(parent(e), deepcopy_internal(e.alg_elem, dict); simplified=e.simplified)
return SmashProductLieDeformElem(parent(e), deepcopy_internal(data(e), dict); simplified=e.simplified)
end

function check_parent(
Expand Down Expand Up @@ -103,7 +109,7 @@ end


function show(io::IO, e::SmashProductLieDeformElem)
show(io, e.alg_elem)
show(io, data(e))
end


Expand Down Expand Up @@ -139,7 +145,7 @@ function (D::SmashProductLieDeform{C, LieC, LieT})(
e::SmashProductLieElem{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}}
@req parent(e) == D.sp "Incompatible smash products."
return D(e.alg_elem)
return D(data(e))
end

function (D::SmashProductLieDeform{C, LieC, LieT})(
Expand All @@ -163,57 +169,57 @@ end
###############################################################################

function Base.:-(e::SmashProductLieDeformElem)
return parent(e)(-e.alg_elem)
return parent(e)(-data(e))
end

function Base.:+(e1::SmashProductLieDeformElem, e2::SmashProductLieDeformElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem + e2.alg_elem)
return parent(e1)(data(e1) + data(e2))
end

function Base.:-(e1::SmashProductLieDeformElem, e2::SmashProductLieDeformElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem - e2.alg_elem)
return parent(e1)(data(e1) - data(e2))
end

function Base.:*(e1::SmashProductLieDeformElem, e2::SmashProductLieDeformElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem * e2.alg_elem)
return parent(e1)(data(e1) * data(e2))
end

function Base.:*(e::SmashProductLieDeformElem{C}, c::C) where {C <: RingElem}
coefficient_ring(e) != parent(c) && error("Incompatible rings.")
return parent(e)(e.alg_elem * c)
return parent(e)(data(e) * c)
end

function Base.:*(e::SmashProductLieDeformElem, c::U) where {U <: Union{Rational, IntegerUnion}}
return parent(e)(e.alg_elem * c)
return parent(e)(data(e) * c)
end

function Base.:*(e::SmashProductLieDeformElem{ZZRingElem}, c::ZZRingElem)
return parent(e)(e.alg_elem * c)
return parent(e)(data(e) * c)
end

function Base.:*(c::C, e::SmashProductLieDeformElem{C}) where {C <: RingElem}
coefficient_ring(e) != parent(c) && error("Incompatible rings.")
return parent(e)(c * e.alg_elem)
return parent(e)(c * data(e))
end

function Base.:*(c::U, e::SmashProductLieDeformElem) where {U <: Union{Rational, IntegerUnion}}
return parent(e)(c * e.alg_elem)
return parent(e)(c * data(e))
end

function Base.:*(c::ZZRingElem, e::SmashProductLieDeformElem{ZZRingElem})
return parent(e)(c * e.alg_elem)
return parent(e)(c * data(e))
end

function Base.:^(e::SmashProductLieDeformElem, n::Int)
return parent(e)(e.alg_elem^n)
return parent(e)(data(e)^n)
end

function comm(e1::SmashProductLieDeformElem, e2::SmashProductLieDeformElem)
check_parent(e1, e2)
return parent(e1)(e1.alg_elem * e2.alg_elem - e2.alg_elem * e1.alg_elem)
return parent(e1)(data(e1) * data(e2) - data(e2) * data(e1))
end

###############################################################################
Expand All @@ -226,14 +232,14 @@ function Base.:(==)(
e1::SmashProductLieDeformElem{C, LieC, LieT},
e2::SmashProductLieDeformElem{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}}
return parent(e1) === parent(e2) && simplify(e1).alg_elem == simplify(e2).alg_elem
return parent(e1) === parent(e2) && data(simplify(e1)) == data(simplify(e2))
end

function Base.hash(e::SmashProductLieDeformElem, h::UInt)
e = simplify(e)
b = 0x97eb07aa70e4a59c % UInt
h = hash(parent(e), h)
h = hash(e.alg_elem, h)
h = hash(data(e), h)
return xor(h, b)
end

Expand All @@ -243,9 +249,12 @@ end
#
###############################################################################

function simplify(e::SmashProductLieDeformElem)
function simplify(
e::SmashProductLieDeformElem{C, LieC, LieT},
) where {C <: RingElem, LieC <: FieldElem, LieT <: LieAlgebraElem{LieC}}
e.simplified && return e
e.alg_elem = _normal_form(e.alg_elem, parent(e).rels)
e.alg_elem =
_normal_form(data(e), parent(e).rels::Matrix{Union{Nothing, elem_type(free_associative_algebra_type(C))}})
e.simplified = true
return e
end
Expand Down Expand Up @@ -277,16 +286,15 @@ function deform(

for i in 1:dimV, j in 1:i
@req kappa[i, j] == -kappa[j, i] "kappa is not skew-symmetric."
@req all(<=(dimL), Iterators.flatten(exponent_words(kappa[i, j].alg_elem))) "kappa does not only take values in the hopf algebra"
@req all(<=(dimL), Iterators.flatten(exponent_words(kappa[j, i].alg_elem))) "kappa does not only take values in the hopf algebra"
@req all(<=(dimL), Iterators.flatten(exponent_words(data(kappa[i, j])))) "kappa does not only take values in the hopf algebra"
end

symmetric = true
rels = deepcopy(sp.rels)
rels = deepcopy(sp.rels::Matrix{Union{Nothing, elem_type(free_associative_algebra_type(C))}})
for i in 1:dimV, j in 1:dimV
# We have the commutator relation [v_i, v_j] = kappa[i,j]
# which is equivalent to v_i*v_j = v_j*v_i + kappa[i,j]
rels[dimL+i, dimL+j] = basisV[j] * basisV[i] + simplify(kappa[i, j]).alg_elem
rels[dimL+i, dimL+j] = basisV[j] * basisV[i] + data(simplify(kappa[i, j]))
symmetric &= iszero(kappa[i, j])
end

Expand Down
2 changes: 1 addition & 1 deletion src/SmashProductPBWDeformLie.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function all_pbwdeformations(
i = x[1]
a = x[2]
@vprintln :PBWDeformations 2 "Equations $(lpad(floor(Int, 100*i / neqs), 3))%, $(lpad(i, ndigits(neqs)))/$(neqs)"
coefficient_comparison(simplify(a).alg_elem)
coefficient_comparison(data(simplify(a)))
end,
enumerate(pbwdeform_eqs(d)),
),
Expand Down
Loading

0 comments on commit ef7d3a3

Please sign in to comment.