Skip to content

Commit

Permalink
Use Ref instead of arrays in CUTENSOR wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Sep 8, 2020
1 parent e51c745 commit 100ed9f
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions lib/cutensor/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ function elementwiseTrinary!(
modeC = collect(Cint, Cinds)
modeD = modeC
cutensorElementwiseTrinary(handle(),
T[alpha], A, descA, modeA,
T[beta], B, descB, modeB,
T[gamma], C, descC, modeC,
Ref{T}(alpha), A, descA, modeA,
Ref{T}(beta), B, descB, modeB,
Ref{T}(gamma), C, descC, modeC,
D, descD, modeD,
opAB, opABC, T, stream)
return D
Expand Down Expand Up @@ -95,9 +95,9 @@ function elementwiseTrinary!(
modeC = collect(Cint, Cinds)
modeD = modeC
cutensorElementwiseTrinary(handle(),
T[alpha], A, descA, modeA,
T[beta], B, descB, modeB,
T[gamma], C, descC, modeC,
Ref{T}(alpha), A, descA, modeA,
Ref{T}(beta), B, descB, modeB,
Ref{T}(gamma), C, descC, modeC,
D, descD, modeD,
opAB, opABC, T, stream)
return D
Expand All @@ -120,8 +120,8 @@ function elementwiseBinary!(
modeC = collect(Cint, Cinds)
modeD = modeC
cutensorElementwiseBinary(handle(),
T[alpha], A, descA, modeA,
T[gamma], C, descC, modeC,
Ref{T}(alpha), A, descA, modeA,
Ref{T}(gamma), C, descC, modeC,
D, descD, modeD,
opAC, T, stream)
return D
Expand All @@ -144,8 +144,8 @@ function elementwiseBinary!(
modeC = collect(Cint, Cinds)
modeD = modeC
cutensorElementwiseBinary(handle(),
T[alpha], A, descA, modeA,
T[gamma], C, descC, modeC,
Ref{T}(alpha), A, descA, modeA,
Ref{T}(gamma), C, descC, modeC,
D, descD, modeD,
opAC, T, stream)
return D
Expand All @@ -162,8 +162,8 @@ function elementwiseBinary!(
@assert size(C) == size(D) && strides(C) == strides(D)
descD = descC # must currently be identical
cutensorElementwiseBinary(handle(),
T[alpha], A.data, descA, A.inds,
T[gamma], C.data, descC, C.inds,
Ref{T}(alpha), A.data, descA, A.inds,
Ref{T}(gamma), C.data, descC, C.inds,
D.data, descD, C.inds,
opAC, T, stream)
return D
Expand All @@ -177,7 +177,7 @@ function permutation!(alpha::Number, A::CuArray, Ainds::ModeType,
T = eltype(B)
modeA = collect(Cint, Ainds)
modeB = collect(Cint, Binds)
cutensorPermutation(handle(), T[alpha], A, descA, modeA, B, descB, modeB, T,
cutensorPermutation(handle(), Ref{T}(alpha), A, descA, modeA, B, descB, modeB, T,
stream)
return B
end
Expand All @@ -189,7 +189,7 @@ function permutation!(alpha::Number, A::Array, Ainds::ModeType,
T = eltype(B)
modeA = collect(Cint, Ainds)
modeB = collect(Cint, Binds)
cutensorPermutation(handle(), T[alpha], A, descA, modeA, B, descB, modeB, T,
cutensorPermutation(handle(), Ref{T}(alpha), A, descA, modeA, B, descB, modeB, T,
stream)
return B
end
Expand Down Expand Up @@ -324,8 +324,8 @@ function reduction!(
out(Ref{UInt64}(C_NULL)))
)[] workspace->begin
cutensorReduction(handle(),
T[alpha], A, descA, modeA,
T[beta], C, descC, modeC,
Ref{T}(alpha), A, descA, modeA,
Ref{T}(beta), C, descC, modeC,
C, descC, modeC,
opReduce, typeCompute,
workspace, sizeof(workspace), stream)
Expand Down

0 comments on commit 100ed9f

Please sign in to comment.