Commit db7acc9: update cuda, fix BigFloat, bump minor version

Jutho committed May 7, 2021 (1 parent: 756dfad)
Showing 5 changed files with 34 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "TensorOperations"
 uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
 authors = ["Jutho Haegeman"]
-version = "3.1.0"
+version = "3.2.0"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -16,7 +16,7 @@ Strided = "1"
 TupleTools = "1.1"
 LRUCache = "1"
 Requires = "0.5,1"
-CUDA = "1,2"
+CUDA = "1,2,3"
 julia = "1.4"
 
 [extras]
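For context, Pkg's compat notation treats each comma-separated entry as a caret specifier, so widening `CUDA = "1,2"` to `"1,2,3"` admits every CUDA release in [1.0.0, 4.0.0). A minimal sketch of the semantics in a stock Julia session (the `Pkg.add` call is illustrative, not part of the commit):

    # Caret semantics of the compat entry (sketch, not package code):
    #   CUDA = "1"      -> [1.0.0, 2.0.0)
    #   CUDA = "1,2"    -> [1.0.0, 3.0.0)
    #   CUDA = "1,2,3"  -> [1.0.0, 4.0.0)
    using Pkg
    Pkg.add(name = "CUDA")   # a CUDA 3.x release can now resolve next to TensorOperations 3.2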
12 changes: 10 additions & 2 deletions src/implementation/tensorcache.jl
@@ -6,15 +6,23 @@ memsize(a::Any) = Base.summarysize(a)
 
 # generic definitions, should be overwritten if your array/tensor type does not support
 # Base.similar(object, eltype, structure)
+function similar_from_structure(A, T, structure)
+    if isbits(T)
+        similar(A, T, structure)
+    else
+        fill!(similar(A, T, structure), zero(T)) # this fixes BigFloat issues
+    end
+end
+
 function similar_from_indices(T::Type, p1::IndexTuple, p2::IndexTuple, A, CA::Symbol)
     structure = similarstructure_from_indices(T, p1, p2, A, CA)
-    similar(A, T, structure)
+    similar_from_structure(A, T, structure)
 end
 function similar_from_indices(T::Type, poA::IndexTuple, poB::IndexTuple,
                               p1::IndexTuple, p2::IndexTuple,
                               A, B, CA::Symbol, CB::Symbol)
     structure = similarstructure_from_indices(T, poA, poB, p1, p2, A, B, CA, CB)
-    similar(A, T, structure)
+    similar_from_structure(A, T, structure)
 end
 
 # should work generically but can be overwritten
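Why the `fill!` branch matters (a sketch, not the package's code): `BigFloat` is not a bits type, so `similar` allocates an array of undefined references, and any read before a write throws an `UndefRefError`; zero-filling the fresh output makes it safe to accumulate into. (As a side note, `isbitstype(T)` is the Julia 1.x spelling of the type-level query; `isbits` expects a value.)

    A = zeros(BigFloat, 2, 2)
    B = similar(A, BigFloat, (2, 2))   # non-isbits eltype: entries start out #undef
    isassigned(B, 1)                   # false -- reading B[1] would throw UndefRefError
    C = fill!(similar(A, BigFloat, (2, 2)), zero(BigFloat))
    isassigned(C, 1)                   # true -- safe to read and accumulate into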
12 changes: 6 additions & 6 deletions src/indexnotation/parser.jl
@@ -5,7 +5,7 @@ mutable struct TensorParser
 
     postprocessors::Vector{Any}
     function TensorParser()
-        preprocessors = [ex->replaceindices(normalizeindex, ex),
+        preprocessors = [normalizeindices,
                          expandconj,
                          nconindexcompletion,
                          extracttensorobjects]
@@ -19,19 +19,19 @@ mutable struct TensorParser
     end
 end
 
-function (parser::TensorParser)(ex)
+function (parser::TensorParser)(ex::Expr)
     if ex isa Expr && ex.head == :function
         return Expr(:function, ex.args[1], parser(ex.args[2]))
     end
     for p in parser.preprocessors
-        ex = p(ex)
+        ex = p(ex)::Expr
     end
     treebuilder = parser.contractiontreebuilder
     treesorter = parser.contractiontreesorter
-    ex = processcontractions(ex, treebuilder, treesorter)
-    ex = tensorify(ex)
+    ex = processcontractions(ex, treebuilder, treesorter)::Expr
+    ex = tensorify(ex)::Expr
     for p in parser.postprocessors
-        ex = p(ex)
+        ex = p(ex)::Expr
     end
     return ex
 end
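The new `::Expr` assertions are the usual trick for keeping inference intact when calling functions stored in a `Vector{Any}`: each call would otherwise be inferred as `Any`, and the assertion narrows the result back to `Expr` for the rest of the pipeline. A minimal sketch of the pattern (names are illustrative, not the package's API):

    passes = Any[ex -> ex, ex -> ex]      # untyped pipeline, like parser.preprocessors
    function runpasses(ex::Expr, passes::Vector{Any})
        for p in passes
            ex = p(ex)::Expr              # without the assertion, ex infers as Any
        end
        return ex
    end
    runpasses(:(a + b), passes)           # returns :(a + b)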
8 changes: 5 additions & 3 deletions src/indexnotation/preprocessors.jl
@@ -1,5 +1,5 @@
 # replace all indices by a function of that index
-function replaceindices(f, ex::Expr)
+function replaceindices((@nospecialize f), ex::Expr)
     if istensor(ex)
         if ex.head == :ref || ex.head == :typed_hcat
             if length(ex.args) == 1
@@ -28,7 +28,7 @@ function replaceindices(f, ex::Expr)
         return Expr(ex.head, (replaceindices(f, e) for e in ex.args)...)
     end
 end
-replaceindices(f, ex) = ex
+replaceindices((@nospecialize f), ex) = ex
 
 function normalizeindex(ex)
     if isa(ex, Symbol) || isa(ex, Int)
@@ -40,6 +40,8 @@ function normalizeindex(ex)
     end
 end
 
+normalizeindices(ex::Expr) = replaceindices(normalizeindex, ex)
+
 # replace all tensor objects by a function of that object
 function replacetensorobjects(f, ex::Expr)
     # first try to replace ex completely
@@ -97,7 +99,7 @@ explicitscalar(ex) = ex
 
 # extracttensorobjects: replace all tensor objects with newly generated symbols, and assign
 # them before the expression and after the expression as necessary
-function extracttensorobjects(ex)
+function extracttensorobjects(ex::Expr)
     inputtensors = getinputtensorobjects(ex)
     outputtensors = getoutputtensorobjects(ex)
     newtensors = getnewtensorobjects(ex)
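`@nospecialize` on the function-valued argument stops the compiler from generating a fresh specialization of `replaceindices` for every distinct closure passed in, which trims compile-time latency during macro expansion at negligible runtime cost in this kind of code. A minimal sketch of the annotation (illustrative, not the package's code):

    # One compiled instance of walk serves every f; without the
    # annotation, each new closure would trigger recompilation.
    function walk(@nospecialize(f), ex)
        ex isa Expr || return f(ex)
        return Expr(ex.head, (walk(f, a) for a in ex.args)...)
    end
    walk(x -> x, :(a[i] * b[j]))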
22 changes: 11 additions & 11 deletions test/tensor.jl
@@ -217,13 +217,13 @@ withcache = TensorOperations.use_cache() ? "with" : "without"
 
     # Simple function example
     @tensor function f(A, b)
-        w[x] := A[x,y]*b[y]
+        w[x] := (1//2)*A[x,y]*b[y]
         return w
     end
     for T in (Float32, Float64, ComplexF32, ComplexF64, BigFloat)
         A = rand(T, 10, 10)
         b = rand(T, 10)
-        @test f(A,b) ≈ A*b
+        @test f(A,b) ≈ (1//2)*A*b
     end
 
     # Example from README.md
@@ -241,14 +241,14 @@ withcache = TensorOperations.use_cache() ? "with" : "without"
     t0 = time()
 
     # Some tensor network examples
-    @testset for T in (Float32, Float64, ComplexF32, ComplexF64)
+    @testset for T in (Float32, Float64, ComplexF32, ComplexF64, BigFloat)
         D1, D2, D3 = 30, 40, 20
         d1, d2 = 2, 3
-        A1 = randn(T, D1, d1, D2)
-        A2 = randn(T, D2, d2, D3)
-        rhoL = randn(T, D1, D1)
-        rhoR = randn(T, D3, D3)
-        H = randn(T, d1, d2, d1, d2)
+        A1 = rand(T, D1, d1, D2) .- 1//2
+        A2 = rand(T, D2, d2, D3) .- 1//2
+        rhoL = rand(T, D1, D1) .- 1//2
+        rhoR = rand(T, D3, D3) .- 1//2
+        H = rand(T, d1, d2, d1, d2) .- 1//2
         A12 = reshape(reshape(A1, D1 * d1, D2) * reshape(A2, D2, d2 * D3), (D1, d1, d2, D3))
         rA12 = reshape(reshape(rhoL * reshape(A12, (D1, d1*d2*D3)), (D1*d1*d2, D3)) * rhoR, (D1, d1, d2, D3))
         HrA12 = permutedims(reshape(reshape(H, (d1 * d2, d1*d2)) * reshape(permutedims(rA12, (2,3,1,4)), (d1 * d2, D1 * D3)), (d1, d2, D1, D3)), (3,1,2,4))
@@ -304,11 +304,11 @@ withcache = TensorOperations.use_cache() ? "with" : "without"
     op2 = randn(2, 2)
     op3 = randn(2, 2)
 
-    f(op,op3) = @ncon((op, op3), ([-1 3], [3 -3]))
+    f83(op,op3) = @ncon((op, op3), ([-1 3], [3 -3]))
 
-    b = f(op1,op3)
+    b = f83(op1,op3)
     bcopy = deepcopy(b)
-    c = f(op2,op3)
+    c = f83(op2,op3)
     @test b == bcopy
     @test b != c
 end
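The move from `randn` to `rand(T, ...) .- 1//2` is what lets `BigFloat` join the element-type loop: at the time of this commit Julia shipped a uniform `rand` sampler for `BigFloat` but no `randn` method, and subtracting the exact rational `1//2` recenters the samples around zero for every element type. The `f` to `f83` rename presumably just avoids colliding with the `f` defined earlier in the same test file. A quick sketch (assuming a stock Julia 1.6-era session):

    rand(BigFloat)                      # works: uniform in [0, 1) at BigFloat precision
    # randn(BigFloat)                   # MethodError: no normal sampler for BigFloat
    A1 = rand(BigFloat, 3, 3) .- 1//2   # entries recentered around zero, exactly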
