diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 17406b6a..01c0f94c 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -78,22 +78,19 @@ end throw(SectorMismatch()) end @inbounds begin - d = (dims(V2 ⊗ V1, f₁.uncoupled)..., dims(V1 ⊗ V2, f₂.uncoupled)...) + d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...) n1 = d[1] * d[2] n2 = d[3] * d[4] - data = storagetype(b)(undef, (n1, n2)) + data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d) fill!(data, zero(eltype(b))) - a1, a2 = f₂.uncoupled - if f₁.uncoupled == (a2, a1) + if f₁.uncoupled == reverse(f₂.uncoupled) braiddict = artin_braid(f₂, 1; inv=b.adjoint) r = get(braiddict, f₁, zero(valtype(braiddict))) - si = 1 + d[1] * d[2] * d[3] - sj = d[1] + d[1] * d[2] - @inbounds for i in 1:d[1], j in 1:d[2] - data[(i - 1) * si + (j - 1) * sj + 1] = r + @inbounds for i in axes(data, 1), j in axes(data, 2) + data[i, j, j, i] = r end end - return sreshape(StridedView(data), d) + return data end end @inline function Base.getindex(b::BraidingTensor, ::Nothing, ::Nothing) @@ -104,31 +101,9 @@ end # efficient copy constructor Base.copy(b::BraidingTensor) = b -function Base.copy!(t::TensorMap, b::BraidingTensor) - space(t) == space(b) || throw(SectorMismatch()) - fill!(t, zero(scalartype(t))) - for (f₁, f₂) in fusiontrees(t) - data = t[f₁, f₂] - if sectortype(t) == Trivial - r = one(scalartype(t)) - else - a1, a2 = f₂.uncoupled - c = f₂.coupled - f₁.uncoupled == (a2, a1) || continue - braiddict = artin_braid(f₂, 1; inv=b.adjoint) - r = convert(scalartype(t), get(braiddict, f₁, zero(valtype(braiddict)))) - end - @inbounds for i in axes(data, 1), j in axes(data, 2) - data[i, j, j, i] = r - end - end - return t -end TensorMap(b::BraidingTensor) = copy!(similar(b), b) Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b) -# TODO: fix this! -# block(b::BraidingTensor, s::Sector) = block(TensorMap(b), s) function block(b::BraidingTensor, s::Sector) sectortype(b) == typeof(s) || throw(SectorMismatch()) @@ -141,7 +116,7 @@ function block(b::BraidingTensor, s::Sector) data = fill!(data, zero(eltype(b))) - V1, V2 = domain(b) + V1, V2 = codomain(b) if sectortype(b) === Trivial d1, d2 = dim(V1), dim(V2) subblock = sreshape(StridedView(data), (d1, d2, d2, d1)) @@ -174,8 +149,6 @@ function block(b::BraidingTensor, s::Sector) return data end -blocks(b::BraidingTensor) = blocks(TensorMap(b)) - # Index manipulations # ------------------- has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 6caef06d..afa94e66 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -69,8 +69,34 @@ function TreeTransformer(transform::Function, Vsrc::HomSpace{S}, end end +# braid is special because it has levels +const treebraidercache = LRU{Any,Any}(; maxsize=10^5) +const usetreebraidercache = Ref{Bool}(true) +@noinline function _get_treebraider(A, key) + d::A = get!(treebraidercache, key) do + return _treebraider(key) + end + return d +end +function _treebraider((Vdst, Vsrc, p, levels)) + fusiontreebraider(f1, f2) = braid(f1, f2, levels..., p...) + return TreeTransformer(fusiontreebraider, Vsrc, Vdst) +end +function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p, levels) + return fusiontreetransform(f1, f2) = braid(f1, f2, levels..., p...) +end +function treebraider(tdst::TensorMap, tsrc::TensorMap, p, levels) + if usetreebraidercache[] + key = (space(tdst), space(tsrc), p, levels) + A = treetransformertype(space(tdst), space(tsrc)) + return _get_treebraider(A, key) + else + return _treebraider((space(tdst), space(tsrc), p, levels)) + end +end + for (transform, transformer) in - ((:permute, :permuter), (:braid, :braider), (:transpose, :transposer)) + ((:permute, :permuter), (:transpose, :transposer)) treetransformcache = Symbol("tree", transformer, "cache") usetreetransformcache = Symbol("usetree", transformer, "cache") treetransformer = Symbol("tree", transformer) diff --git a/test/planar.jl b/test/planar.jl index 0487404f..c9e1c3c4 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -30,22 +30,59 @@ function force_planar(tsrc::TensorMap{<:Any,<:GradedSpace}) return tdst end +Vtr = (ℂ^3, + (ℂ^2)', + ℂ^5, + ℂ^6, + (ℂ^7)') +VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), + ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), + ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3), + ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)') +VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2), + ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1), + ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)', + ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3), + ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)') +VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1), + ℂ[FermionSpin](0 => 2, 1 => 1), + ℂ[FermionSpin](1 // 2 => 1, 1 => 1)', + ℂ[FermionSpin](0 => 2, 1 // 2 => 2), + ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') +Vfib = (Vect[FibonacciAnyon](:I => 1, :τ => 2), + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 1)) @testset "Braiding tensor" begin - V1 = ℂ^2 ⊗ ℂ^3 ← ℂ^3 ⊗ ℂ^2 - t1 = @constinferred BraidingTensor(V1) - @test space(t1) == V1 - @test codomain(t1) == codomain(V1) - @test domain(t1) == domain(V1) - @test scalartype(t1) == Float64 - @test storagetype(t1) == Vector{Float64} - t2 = @constinferred BraidingTensor{ComplexF64}(V1) - @test scalartype(t2) == ComplexF64 - @test storagetype(t2) == Vector{ComplexF64} - - V2 = ℂ^2 ⊗ ℂ^3 ← ℂ^2 ⊗ ℂ^3 - @test_throws SpaceMismatch BraidingTensor(V2) - - @test adjoint(t1) isa BraidingTensor + for V in (Vtr, VU₁, VfU₁, VfSU₂, Vfib) + W = V[1] ⊗ V[2] ← V[2] ⊗ V[1] + t1 = @constinferred BraidingTensor(W) + @test space(t1) == W + @test codomain(t1) == codomain(W) + @test domain(t1) == domain(W) + @test scalartype(t1) == (isreal(sectortype(W)) ? Float64 : ComplexF64) + @test storagetype(t1) == Vector{scalartype(t1)} + t2 = @constinferred BraidingTensor{ComplexF64}(W) + @test scalartype(t2) == ComplexF64 + @test storagetype(t2) == Vector{ComplexF64} + + W2 = reverse(codomain(W)) ← domain(W) + @test_throws SpaceMismatch BraidingTensor(W2) + + @test adjoint(t1) isa BraidingTensor + + t3 = @inferred TensorMap(t2) + t4 = braid(id(storagetype(t2), domain(t2)), ((2, 1), (3, 4)), (1, 2, 3, 4)) + @test t1 ≈ t4 + for (c, b) in blocks(t1) + @test block(t1, c) ≈ b ≈ block(t3, c) + end + for (f1, f2) in fusiontrees(t1) + @test t1[f1, f2] ≈ t3[f1, f2] + end + end end @testset "planar methods" verbose = true begin