Skip to content

Commit

Permalink
Move the loop into vstorea!.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jun 29, 2023
1 parent 609920f commit 27903f6
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/layout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ using Base.Cartesian: @ntuple

struct Vec{N, T} end

@inline @generated function vloada(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, i::Integer = 1) where {N, T, AS}
@inline @generated function vloada(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS},
i::Integer = 1) where {N, T, AS}
alignment = sizeof(T) * N

return quote
Expand All @@ -27,13 +28,20 @@ struct Vec{N, T} end
end
end

@inline @generated function vstorea!(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, x, i::Integer = 1) where {N, T, AS}
@inline @generated function vstorea!(::Type{Vec{N, T}}, ptr::Core.LLVMPtr{T, AS}, x,
i::Integer = 1) where {N, T, AS}
alignment = sizeof(T) * N

return quote
y = @ntuple $N i -> VecElement{T}(x[i].value)
vec_ptr = Base.bitcast(Core.LLVMPtr{NTuple{N, VecElement{T}}, AS}, ptr)
return unsafe_store!(vec_ptr, y, (i-1) ÷ N + 1, Val($alignment))
# we may be storing more values than we can using a single vectorized operation
# (e.g., when types mismatch, storing 8 Float16s in a Float32 shared memory layout)
@loopinfo unroll for offset = 1:$N:length(x)
y = @ntuple $N j -> VecElement{T}(x[j+offset-1].value)
vec_ptr = Base.bitcast(Core.LLVMPtr{NTuple{N, VecElement{T}}, AS}, ptr)
unsafe_store!(vec_ptr, y, (i+offset-2) ÷ N + 1, Val($alignment))
end

return
end
end

Expand Down Expand Up @@ -81,21 +89,13 @@ abstract type AlignedColMajor{T} <: LayoutBase{T} end
return vloada(Vec{N, T}, pointer(workspace), linear_base + linear_offset - 1)
end

@inline @generated function store!(::Type{<:AlignedColMajor{T}}, workspace, value, tile::Tile{size}) where {T, size}
@inline function store!(::Type{<:AlignedColMajor{T}}, workspace, value, tile::Tile{size}) where {T, size}
N = 16 ÷ sizeof(T)

quote
linear_base = linearise(tile.base, Base.size(workspace))
linear_offset = linearise(tile.offset, Base.size(workspace))
linear_base = linearise(tile.base, Base.size(workspace))
linear_offset = linearise(tile.offset, Base.size(workspace))

# we may be storing more values than we can using a single vectorized operation
# (e.g., when types mismatch, storing 8 Float16s in a Float32 shared memory layout)
@loopinfo unroll for value_offset = 1:$N:length(value)
x = @ntuple $N i -> (value[value_offset+i-1])
vstorea!(Vec{$N, T}, pointer(workspace), x,
linear_base + linear_offset + value_offset - 2)
end
end
vstorea!(Vec{N, T}, pointer(workspace), value, linear_base + linear_offset - 1)
end

# --------
Expand Down

0 comments on commit 27903f6

Please sign in to comment.