diff --git a/lib/cutensor/wrappers.jl b/lib/cutensor/wrappers.jl index 68237bc960..e1b1aa99f0 100644 --- a/lib/cutensor/wrappers.jl +++ b/lib/cutensor/wrappers.jl @@ -65,9 +65,9 @@ function elementwiseTrinary!( modeC = collect(Cint, Cinds) modeD = modeC cutensorElementwiseTrinary(handle(), - T[alpha], A, descA, modeA, - T[beta], B, descB, modeB, - T[gamma], C, descC, modeC, + Ref{T}(alpha), A, descA, modeA, + Ref{T}(beta), B, descB, modeB, + Ref{T}(gamma), C, descC, modeC, D, descD, modeD, opAB, opABC, T, stream) return D @@ -95,9 +95,9 @@ function elementwiseTrinary!( modeC = collect(Cint, Cinds) modeD = modeC cutensorElementwiseTrinary(handle(), - T[alpha], A, descA, modeA, - T[beta], B, descB, modeB, - T[gamma], C, descC, modeC, + Ref{T}(alpha), A, descA, modeA, + Ref{T}(beta), B, descB, modeB, + Ref{T}(gamma), C, descC, modeC, D, descD, modeD, opAB, opABC, T, stream) return D @@ -120,8 +120,8 @@ function elementwiseBinary!( modeC = collect(Cint, Cinds) modeD = modeC cutensorElementwiseBinary(handle(), - T[alpha], A, descA, modeA, - T[gamma], C, descC, modeC, + Ref{T}(alpha), A, descA, modeA, + Ref{T}(gamma), C, descC, modeC, D, descD, modeD, opAC, T, stream) return D @@ -144,8 +144,8 @@ function elementwiseBinary!( modeC = collect(Cint, Cinds) modeD = modeC cutensorElementwiseBinary(handle(), - T[alpha], A, descA, modeA, - T[gamma], C, descC, modeC, + Ref{T}(alpha), A, descA, modeA, + Ref{T}(gamma), C, descC, modeC, D, descD, modeD, opAC, T, stream) return D @@ -162,8 +162,8 @@ function elementwiseBinary!( @assert size(C) == size(D) && strides(C) == strides(D) descD = descC # must currently be identical cutensorElementwiseBinary(handle(), - T[alpha], A.data, descA, A.inds, - T[gamma], C.data, descC, C.inds, + Ref{T}(alpha), A.data, descA, A.inds, + Ref{T}(gamma), C.data, descC, C.inds, D.data, descD, C.inds, opAC, T, stream) return D @@ -177,7 +177,7 @@ function permutation!(alpha::Number, A::CuArray, Ainds::ModeType, T = eltype(B) modeA = collect(Cint, Ainds) modeB = collect(Cint, Binds) - cutensorPermutation(handle(), T[alpha], A, descA, modeA, B, descB, modeB, T, + cutensorPermutation(handle(), Ref{T}(alpha), A, descA, modeA, B, descB, modeB, T, stream) return B end @@ -189,7 +189,7 @@ function permutation!(alpha::Number, A::Array, Ainds::ModeType, T = eltype(B) modeA = collect(Cint, Ainds) modeB = collect(Cint, Binds) - cutensorPermutation(handle(), T[alpha], A, descA, modeA, B, descB, modeB, T, + cutensorPermutation(handle(), Ref{T}(alpha), A, descA, modeA, B, descB, modeB, T, stream) return B end @@ -324,8 +324,8 @@ function reduction!( out(Ref{UInt64}(C_NULL))) )[] workspace->begin cutensorReduction(handle(), - T[alpha], A, descA, modeA, - T[beta], C, descC, modeC, + Ref{T}(alpha), A, descA, modeA, + Ref{T}(beta), C, descC, modeC, C, descC, modeC, opReduce, typeCompute, workspace, sizeof(workspace), stream)