Skip to content

Commit

Permalink
Braidingtensor improvements (#179)
Browse files Browse the repository at this point in the history
* Make braidingtensor behave

* fix `treebraider`

* expand BraidingTensor tests

* Remove duplicate copy

* rename `V1 -> W`

* Apply suggestions from code review

* Remove stackoverflow
  • Loading branch information
lkdvos authored Nov 20, 2024
1 parent f467a21 commit 8b38973
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 50 deletions.
41 changes: 7 additions & 34 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,19 @@ end
throw(SectorMismatch())
end
@inbounds begin
d = (dims(V2 V1, f₁.uncoupled)..., dims(V1 V2, f₂.uncoupled)...)
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
n1 = d[1] * d[2]
n2 = d[3] * d[4]
data = storagetype(b)(undef, (n1, n2))
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
fill!(data, zero(eltype(b)))
a1, a2 = f₂.uncoupled
if f₁.uncoupled == (a2, a1)
if f₁.uncoupled == reverse(f₂.uncoupled)
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
r = get(braiddict, f₁, zero(valtype(braiddict)))
si = 1 + d[1] * d[2] * d[3]
sj = d[1] + d[1] * d[2]
@inbounds for i in 1:d[1], j in 1:d[2]
data[(i - 1) * si + (j - 1) * sj + 1] = r
@inbounds for i in axes(data, 1), j in axes(data, 2)
data[i, j, j, i] = r
end
end
return sreshape(StridedView(data), d)
return data
end
end
@inline function Base.getindex(b::BraidingTensor, ::Nothing, ::Nothing)
Expand All @@ -104,31 +101,9 @@ end
# efficient copy constructor
Base.copy(b::BraidingTensor) = b

function Base.copy!(t::TensorMap, b::BraidingTensor)
space(t) == space(b) || throw(SectorMismatch())
fill!(t, zero(scalartype(t)))
for (f₁, f₂) in fusiontrees(t)
data = t[f₁, f₂]
if sectortype(t) == Trivial
r = one(scalartype(t))
else
a1, a2 = f₂.uncoupled
c = f₂.coupled
f₁.uncoupled == (a2, a1) || continue
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
r = convert(scalartype(t), get(braiddict, f₁, zero(valtype(braiddict))))
end
@inbounds for i in axes(data, 1), j in axes(data, 2)
data[i, j, j, i] = r
end
end
return t
end
TensorMap(b::BraidingTensor) = copy!(similar(b), b)
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)

# TODO: fix this!
# block(b::BraidingTensor, s::Sector) = block(TensorMap(b), s)
function block(b::BraidingTensor, s::Sector)
sectortype(b) == typeof(s) || throw(SectorMismatch())

Expand All @@ -141,7 +116,7 @@ function block(b::BraidingTensor, s::Sector)

data = fill!(data, zero(eltype(b)))

V1, V2 = domain(b)
V1, V2 = codomain(b)
if sectortype(b) === Trivial
d1, d2 = dim(V1), dim(V2)
subblock = sreshape(StridedView(data), (d1, d2, d2, d1))
Expand Down Expand Up @@ -174,8 +149,6 @@ function block(b::BraidingTensor, s::Sector)
return data
end

blocks(b::BraidingTensor) = blocks(TensorMap(b))

# Index manipulations
# -------------------
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
Expand Down
28 changes: 27 additions & 1 deletion src/tensors/treetransformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,34 @@ function TreeTransformer(transform::Function, Vsrc::HomSpace{S},
end
end

# braid is special because it has levels
const treebraidercache = LRU{Any,Any}(; maxsize=10^5)
const usetreebraidercache = Ref{Bool}(true)
@noinline function _get_treebraider(A, key)
d::A = get!(treebraidercache, key) do
return _treebraider(key)
end
return d
end
function _treebraider((Vdst, Vsrc, p, levels))
fusiontreebraider(f1, f2) = braid(f1, f2, levels..., p...)
return TreeTransformer(fusiontreebraider, Vsrc, Vdst)
end
function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p, levels)
return fusiontreetransform(f1, f2) = braid(f1, f2, levels..., p...)
end
function treebraider(tdst::TensorMap, tsrc::TensorMap, p, levels)
if usetreebraidercache[]
key = (space(tdst), space(tsrc), p, levels)
A = treetransformertype(space(tdst), space(tsrc))
return _get_treebraider(A, key)
else
return _treebraider((space(tdst), space(tsrc), p, levels))
end
end

for (transform, transformer) in
((:permute, :permuter), (:braid, :braider), (:transpose, :transposer))
((:permute, :permuter), (:transpose, :transposer))
treetransformcache = Symbol("tree", transformer, "cache")
usetreetransformcache = Symbol("usetree", transformer, "cache")
treetransformer = Symbol("tree", transformer)
Expand Down
67 changes: 52 additions & 15 deletions test/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,59 @@ function force_planar(tsrc::TensorMap{<:Any,<:GradedSpace})
return tdst
end

Vtr = (ℂ^3,
(ℂ^2)',
^5,
^6,
(ℂ^7)')
VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2),
ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1),
ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)',
ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3),
ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)')
VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2),
ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1),
ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)',
ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3),
ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)')
VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1),
ℂ[FermionSpin](0 => 2, 1 => 1),
ℂ[FermionSpin](1 // 2 => 1, 1 => 1)',
ℂ[FermionSpin](0 => 2, 1 // 2 => 2),
ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)')
Vfib = (Vect[FibonacciAnyon](:I => 1, => 2),
Vect[FibonacciAnyon](:I => 2, => 1),
Vect[FibonacciAnyon](:I => 1, => 1),
Vect[FibonacciAnyon](:I => 1, => 1),
Vect[FibonacciAnyon](:I => 1, => 1))
@testset "Braiding tensor" begin
V1 =^2 ^3 ^3 ^2
t1 = @constinferred BraidingTensor(V1)
@test space(t1) == V1
@test codomain(t1) == codomain(V1)
@test domain(t1) == domain(V1)
@test scalartype(t1) == Float64
@test storagetype(t1) == Vector{Float64}
t2 = @constinferred BraidingTensor{ComplexF64}(V1)
@test scalartype(t2) == ComplexF64
@test storagetype(t2) == Vector{ComplexF64}

V2 =^2 ^3 ^2 ^3
@test_throws SpaceMismatch BraidingTensor(V2)

@test adjoint(t1) isa BraidingTensor
for V in (Vtr, VU₁, VfU₁, VfSU₂, Vfib)
W = V[1] V[2] V[2] V[1]
t1 = @constinferred BraidingTensor(W)
@test space(t1) == W
@test codomain(t1) == codomain(W)
@test domain(t1) == domain(W)
@test scalartype(t1) == (isreal(sectortype(W)) ? Float64 : ComplexF64)
@test storagetype(t1) == Vector{scalartype(t1)}
t2 = @constinferred BraidingTensor{ComplexF64}(W)
@test scalartype(t2) == ComplexF64
@test storagetype(t2) == Vector{ComplexF64}

W2 = reverse(codomain(W)) domain(W)
@test_throws SpaceMismatch BraidingTensor(W2)

@test adjoint(t1) isa BraidingTensor

t3 = @inferred TensorMap(t2)
t4 = braid(id(storagetype(t2), domain(t2)), ((2, 1), (3, 4)), (1, 2, 3, 4))
@test t1 t4
for (c, b) in blocks(t1)
@test block(t1, c) b block(t3, c)
end
for (f1, f2) in fusiontrees(t1)
@test t1[f1, f2] t3[f1, f2]
end
end
end

@testset "planar methods" verbose = true begin
Expand Down

0 comments on commit 8b38973

Please sign in to comment.