Skip to content

Commit

Permalink
Fix converting empty tensors to array (#180)
Browse files Browse the repository at this point in the history
* Fix conversion of empty tensor

* Add tests for empty tensor conversion

* Add `zero(::ElementarySpace)`

* Fix failing tests

* Add explicit test for issue #178

* Fix wrong type

* `eltype` should determine precision

* Add entry to docs

* small fix

* small fix attempt II

Fixes #178
  • Loading branch information
lkdvos authored Nov 21, 2024
1 parent 8b38973 commit f7eabe5
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/src/lib/spaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ dual
conj
flip
zero(::ElementarySpace)
oneunit
supremum
infimum
Expand Down
1 change: 1 addition & 0 deletions src/spaces/cartesianspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ sectors(V::CartesianSpace) = OneOrNoneIterator(dim(V) != 0, Trivial())
sectortype(::Type{CartesianSpace}) = Trivial

Base.oneunit(::Type{CartesianSpace}) = CartesianSpace(1)
Base.zero(::Type{CartesianSpace}) = CartesianSpace(0)
(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d + V₂.d)
fuse(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d * V₂.d)
flip(V::CartesianSpace) = V
Expand Down
1 change: 1 addition & 0 deletions src/spaces/complexspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ sectortype(::Type{ComplexSpace}) = Trivial
Base.conj(V::ComplexSpace) = ComplexSpace(dim(V), !isdual(V))

Base.oneunit(::Type{ComplexSpace}) = ComplexSpace(1)
Base.zero(::Type{ComplexSpace}) = ComplexSpace(0)
function (V₁::ComplexSpace, V₂::ComplexSpace)
return isdual(V₁) == isdual(V₂) ?
ComplexSpace(dim(V₁) + dim(V₂), isdual(V₁)) :
Expand Down
3 changes: 3 additions & 0 deletions src/spaces/generalspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ sectortype(::Type{<:GeneralSpace}) = Trivial
field(::Type{GeneralSpace{𝔽}}) where {𝔽} = 𝔽
InnerProductStyle(::Type{<:GeneralSpace}) = NoInnerProduct()

Base.oneunit(::Type{GeneralSpace{𝔽}}) where {𝔽} = GeneralSpace{𝔽}(1, false, false)
Base.zero(::Type{GeneralSpace{𝔽}}) where {𝔽} = GeneralSpace{𝔽}(0, false, false)

dual(V::GeneralSpace{𝔽}) where {𝔽} = GeneralSpace{𝔽}(dim(V), !isdual(V), isconj(V))
Base.conj(V::GeneralSpace{𝔽}) where {𝔽} = GeneralSpace{𝔽}(dim(V), isdual(V), !isconj(V))

Expand Down
1 change: 1 addition & 0 deletions src/spaces/gradedspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ function Base.axes(V::GradedSpace{I}, c::I) where {I<:Sector}
end

Base.oneunit(S::Type{<:GradedSpace{I}}) where {I<:Sector} = S(one(I) => 1)
Base.zero(S::Type{<:GradedSpace{I}}) where {I<:Sector} = S(one(I) => 0)

# TODO: the following methods can probably be implemented more efficiently for
# `FiniteGradedSpace`, but we don't expect them to be used often in hot loops, so
Expand Down
8 changes: 8 additions & 0 deletions src/spaces/vectorspaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ that this is different from `one(V::S)`, which returns the empty product space
"""
Base.oneunit(V::ElementarySpace) = oneunit(typeof(V))

"""
zero(V::S) where {S<:ElementarySpace} -> S
Return the corresponding vector space of type `S` that represents the zero-dimensional or empty space.
This is, with a slight abuse of notation, the zero element of the direct sum of vector spaces.
"""
Base.zero(V::ElementarySpace) = zero(typeof(V))

"""
⊕(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S
oplus(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S
Expand Down
16 changes: 4 additions & 12 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,21 +497,13 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
else
cod = codomain(t)
dom = domain(t)
local A
T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
A = zeros(T, dims(cod)..., dims(dom)...)
for (f₁, f₂) in fusiontrees(t)
F = convert(Array, (f₁, f₂))
if !(@isdefined A)
if eltype(F) <: Complex
T = complex(float(scalartype(t)))
elseif eltype(F) <: Integer
T = scalartype(t)
else
T = float(scalartype(t))
end
A = fill(zero(T), (dims(cod)..., dims(dom)...))
end
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
axpy!(1, StridedView(_kron(convert(Array, t[f₁, f₂]), F)), Aslice)
add!(Aslice, StridedView(_kron(convert(Array, t[f₁, f₂]), F)))
end
return A
end
Expand Down
7 changes: 7 additions & 0 deletions test/bugfixes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,11 @@
@test w == v
@test scalartype(w) == Float64
end

# https://github.com/Jutho/TensorKit.jl/issues/178
@testset "Issue #178" begin
t = rand(U1Space(1 => 1) U1Space(1 => 1)')
a = convert(Array, t)
@test a == zeros(size(a))
end
end
16 changes: 11 additions & 5 deletions test/spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ println("------------------------------------")
@test length(sectors(V)) == 1
@test @constinferred(TensorKit.hassector(V, Trivial()))
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
@test dim(@constinferred(typeof(V)())) == 0
@test (sectors(typeof(V)())...,) == ()
@test dim(@constinferred(zero(V))) == 0
@test (sectors(zero(V))...,) == ()
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
@test^d == ℝ[](d) == CartesianSpace(d) == typeof(V)(d)
W = @constinferred^1
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
@test @constinferred(zero(V)) ==^0 == zero(typeof(V))
@test @constinferred((V, zero(V))) == V
@test @constinferred((V, V)) ==^(2d)
@test @constinferred((V, oneunit(V))) ==^(d + 1)
@test @constinferred((V, V, V, V)) ==^(4d)
Expand Down Expand Up @@ -111,12 +113,14 @@ println("------------------------------------")
@test length(sectors(V)) == 1
@test @constinferred(TensorKit.hassector(V, Trivial()))
@test @constinferred(dim(V)) == d == @constinferred(dim(V, Trivial()))
@test dim(@constinferred(typeof(V)())) == 0
@test (sectors(typeof(V)())...,) == ()
@test dim(@constinferred(zero(V))) == 0
@test (sectors(zero(V))...,) == ()
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(d)
@test^d == Vect[Trivial](d) == Vect[](Trivial() => d) == ℂ[](d) == typeof(V)(d)
W = @constinferred^1
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
@test @constinferred(zero(V)) ==^0 == zero(typeof(V))
@test @constinferred((V, zero(V))) == V
@test @constinferred((V, V)) ==^(2d)
@test_throws SpaceMismatch ((V, V'))
# promote_except = ErrorException("promotion of types $(typeof(ℝ^d)) and " *
Expand Down Expand Up @@ -200,11 +204,12 @@ println("------------------------------------")
@test eval(Meta.parse(sprint(show, V))) == V
@test eval(Meta.parse(sprint(show, typeof(V)))) == typeof(V)
# space with no sectors
@test dim(@constinferred(typeof(V)())) == 0
@test dim(@constinferred(zero(V))) == 0
# space with a single sector
W = @constinferred GradedSpace(one(I) => 1)
@test W == GradedSpace(one(I) => 1, randsector(I) => 0)
@test @constinferred(oneunit(V)) == W == oneunit(typeof(V))
@test @constinferred(zero(V)) == GradedSpace(one(I) => 0)
# randsector never returns trivial sector, so this cannot error
@test_throws ArgumentError GradedSpace(one(I) => 1, randsector(I) => 0, one(I) => 3)
@test eval(Meta.parse(sprint(show, W))) == W
Expand All @@ -226,6 +231,7 @@ println("------------------------------------")
if hasfusiontensor(I)
@test @constinferred(TensorKit.axes(V)) == Base.OneTo(dim(V))
end
@test @constinferred((V, zero(V))) == V
@test @constinferred((V, V)) == Vect[I](c => 2dim(V, c) for c in sectors(V))
@test @constinferred((V, V, V, V)) == Vect[I](c => 4dim(V, c) for c in sectors(V))
@test @constinferred((V, oneunit(V))) ==
Expand Down
7 changes: 6 additions & 1 deletion test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ for V in spacelist
@test t === @constinferred TensorMap(t.data, W)
end
end
for T in (Int, Float32, ComplexF64)
t = randn(T, V1 V2 zero(V1))
a = convert(Array, t)
@test norm(a) == 0
end
end
end
@timedtestset "Basic linear algebra" begin
Expand Down Expand Up @@ -466,7 +471,7 @@ for V in spacelist
end
end
@testset "empty tensor" begin
t = randn(T, V1 V2, typeof(V1)())
t = randn(T, V1 V2, zero(V1))
@testset "leftorth with $alg" for alg in
(TensorKit.QR(), TensorKit.QRpos(),
TensorKit.QL(), TensorKit.QLpos(),
Expand Down

0 comments on commit f7eabe5

Please sign in to comment.