Skip to content

Commit

Permalink
overload mul! for gpu arryays (#167)
Browse files Browse the repository at this point in the history
* mul

* Update gpuarrays.jl

* Update gpu_tests.jl

* Update gpuarrays.jl

* Update gpuarrays.jl

* Update gpu_tests.jl

* Update gpu_tests.jl

* Update gpu_tests.jl

* Update gpu_tests.jl

* more test

* Update gpu_tests.jl

* Update gpu_tests.jl
  • Loading branch information
YichengDWu committed Oct 18, 2022
1 parent d96f434 commit 98720a9
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 0 deletions.
200 changes: 200 additions & 0 deletions src/compat/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax}
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}}

GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))

Expand Down Expand Up @@ -71,3 +74,200 @@ function ComponentArray(nt::NamedTuple{names,<:Tuple{Vararg{Union{GPUArrays.Abst
G = Base.typename(typeof(gpuarray)).wrapper # SciMLBase.parameterless_type(gpuarray)
return GPUArrays.adapt(G, ComponentArray(NamedTuple{names}(map(GPUArrays.adapt(Array{T}), nt))))
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, B::GPUComponentVecorMat,
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Number, b::Number)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end

function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::GPUComponentVecorMat,
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::GPUComponentVecorMat, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, B::GPUComponentVecorMat,
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real,
b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat},
a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
function LinearAlgebra.mul!(C::GPUComponentVecorMat,
A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
},
B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end
18 changes: 18 additions & 0 deletions test/gpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,22 @@ end
jlca3 = deepcopy(jlca)
@test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a=1:2, b=3:4))
end
@testset "mul!" begin
A = jlca .* jlca';
@test_nowarn mul!(deepcopy(A), A, A, 1, 2);
@test_nowarn mul!(deepcopy(A), A', A', 1, 2);
@test_nowarn mul!(deepcopy(A), A', A, 1, 2);
@test_nowarn mul!(deepcopy(A), A, A', 1, 2);
@test_nowarn mul!(deepcopy(A), A, getdata(A'), 1, 2);
@test_nowarn mul!(deepcopy(A), getdata(A'), A, 1, 2);
@test_nowarn mul!(deepcopy(A), getdata(A'), getdata(A'), 1, 2);
@test_nowarn mul!(deepcopy(A), transpose(A), A, 1, 2);
@test_nowarn mul!(deepcopy(A), A, transpose(A), 1, 2);
@test_nowarn mul!(deepcopy(A), transpose(A), transpose(A), 1, 2);
@test_nowarn mul!(deepcopy(A), transpose(getdata(A)), A, 1, 2);
@test_nowarn mul!(deepcopy(A), A, transpose(getdata(A)), 1, 2);
@test_nowarn mul!(deepcopy(A), transpose(getdata(A)), transpose(getdata(A)), 1, 2);
@test_nowarn mul!(deepcopy(A), transpose(A), A', 1, 2);
@test_nowarn mul!(deepcopy(A), A', transpose(A), 1, 2);
end
end

0 comments on commit 98720a9

Please sign in to comment.