diff --git a/src/layout.jl b/src/layout.jl index 336b7d34..2dbe3c03 100644 --- a/src/layout.jl +++ b/src/layout.jl @@ -4,6 +4,7 @@ module Layout using CUDA using LLVMLoopInfo: @loopinfo using GemmKernels.Tiling +using Base.Cartesian: @ntuple # --------------------- # Customise computation @@ -17,7 +18,8 @@ using GemmKernels.Tiling struct Vec{N, T} end -@inline @generated function vloada(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, i::Integer = 1) where {N, T, AS} +@inline @generated function vloada(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, + i::Integer = 1) where {N, T, AS} alignment = sizeof(T) * N return quote @@ -26,13 +28,23 @@ struct Vec{N, T} end end end -@inline @generated function vstorea!(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, x, i::Integer = 1) where {N, T, AS} +@inline @generated function vstorea!(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, + x::NTuple{M,<:Any}, i::Integer = 1) where {N, T, AS, M} alignment = sizeof(T) * N - return quote - vec_ptr = Base.bitcast(Core.LLVMPtr{NTuple{N, VecElement{T}}, AS}, ptr) - return unsafe_store!(vec_ptr, x, (i-1) ÷ N + 1, Val($alignment)) + ex = quote end + + # we may be storing more values than we can using a single vectorized operation + # (e.g., when types mismatch, storing 8 Float16s in a Float32 shared memory layout) + for offset = 0:N:M-1 + append!(ex.args, (quote + y = @ntuple $N j -> VecElement{T}(x[j+$offset].value) + vec_ptr = Base.bitcast(Core.LLVMPtr{NTuple{N, VecElement{T}}, AS}, ptr) + unsafe_store!(vec_ptr, y, (i+$offset-1) ÷ N + 1, Val($alignment)) + end).args) end + + return ex end # -----------