Skip to content

Commit

Permalink
Replace StaticArrays with a simple immutable array type (re-land #83)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jul 14, 2022
1 parent 1088dd0 commit 23ad275
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 32 deletions.
2 changes: 0 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ steps:
cap: "recent"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
soft_fail:
- exit_status: 1

- label: "Julia nightly"
plugins:
Expand Down
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
CUDA = "3.5"
ForwardDiff = "0.10"
KernelAbstractions = "0.7, 0.8"
LLVM = "3, 4"
StaticArrays = "0.12, 1"
julia = "1.6"
1 change: 1 addition & 0 deletions src/GemmKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("tiling.jl")

include("config.jl")
include("epilogue.jl")
include("array.jl")
include("kernel.jl")
include("layout.jl")
include("operator.jl")
Expand Down
43 changes: 43 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# a simple immutable array type backed by stack memory
#
# similar to StaticArrays, but immutable to prevent optimization bugs (JuliaLang/julia#41800)

# An immutable, fixed-size array backed by an NTuple (i.e. stack/register memory).
# Type parameters: `S` is a Tuple type encoding the static dimensions,
# `T` the element type, `N` the number of dimensions, and `L` the total
# element count (presumably `prod` of the dims in `S` — enforced only by the
# `LocalArray{S,T}` convenience constructor, not here).
struct LocalArray{S <: Tuple, T, N, L} <: AbstractArray{T,N}
    data::NTuple{L,T}

    # Uninitialized constructor: legal for an immutable struct; callers are
    # expected to populate elements via `Base.setindex` (which returns a new
    # instance) before reading them.
    LocalArray{S,T,N,L}(::UndefInitializer) where {S,T,N,L} = new{S,T,N,L}()
    # Wrap an existing tuple of exactly `L` elements of type `T`.
    LocalArray{S,T,N,L}(data::NTuple{L,T}) where {S,T,N,L} = new{S,T,N,L}(data)
end

# Convenience constructor that derives the rank `N` and length `L` from the
# dimension tuple type `S` at compile time, then forwards `args...` to the
# full `LocalArray{S,T,N,L}` constructor.
@inline @generated function LocalArray{S,T}(args...) where {S,T}
    # only isbits element types make sense for a register-backed array
    @assert isbitstype(T)
    shape = tuple(S.parameters...)
    return :(LocalArray{S, T, $(length(shape)), $(prod(shape))}(args...))
end

# AbstractArray interface: linear indexing, with the size recovered from the
# static dimension tuple type `S`.
Base.IndexStyle(::Type{<:LocalArray}) = IndexLinear()
Base.size(::LocalArray{S}) where {S} = tuple(S.parameters...)

# indexing
# Linear read access; bounds checking is elided when the caller uses @inbounds.
Base.@propagate_inbounds function Base.getindex(v::LocalArray, i::Int)
    @boundscheck checkbounds(v, i)
    return @inbounds v.data[i]
end
# Non-mutating "set": return a fresh LocalArray whose `i`-th element is `val`.
# (The array is immutable, so this is the only way to update it.)
Base.@propagate_inbounds function Base.setindex(v::LocalArray{S,T,N,L}, val, i::Int) where {S,T,N,L}
    @boundscheck checkbounds(v, i)
    updated = Base.setindex(v.data, val, i)
    return LocalArray{S,T,N,L}(updated)
end
## XXX: Base's setindex doesn't have a ND version
# Multi-dimensional variant: map the Cartesian indices to a linear index and
# delegate to the tuple `setindex`, returning a fresh LocalArray.
Base.@propagate_inbounds function Base.setindex(v::LocalArray{S,T,N,L}, val, is::Int...) where {S,T,N,L}
    @boundscheck checkbounds(v, is...)
    linear = LinearIndices(v)[CartesianIndex(is...)]
    updated = Base.setindex(v.data, val, linear)
    return LocalArray{S,T,N,L}(updated)
end
55 changes: 28 additions & 27 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ module Kernel
using CUDA
using GemmKernels
using GemmKernels.Tiling
using GemmKernels: LocalArray
using KernelAbstractions.Extras: @unroll
using StaticArrays
using Base: setindex

function matmul_singlestage(a, b, c, d,
transf_gl2sh_a, transf_gl2sh_b, transf_gl2sh_c, transf_sh2gl_d,
Expand Down Expand Up @@ -42,12 +43,12 @@ function matmul_singlestage(a, b, c, d,
# (2) Load a compute_warp.M x compute_warp.N tile of C from shared memory into registers
warp_tile = subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)

c_frags = MArray{Tuple{num_fragments_m, num_fragments_n}, Operator.fragtype_accum(conf.operator, conf.shared_c_layout)}(undef)
c_frags = LocalArray{Tuple{num_fragments_m, num_fragments_n}, Operator.fragtype_accum(conf.operator, conf.shared_c_layout)}(undef)

@unroll for i = 1 : num_fragments_m
@unroll for j = 1 : num_fragments_n
tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(conf.operator, conf.shared_c_layout, shmem_c, tile), tile)
@inbounds c_frags = setindex(c_frags, transf_sh2rf_c(Operator.load_c(conf.operator, conf.shared_c_layout, shmem_c, tile), tile), i ,j)
end
end

Expand Down Expand Up @@ -83,25 +84,25 @@ function matmul_singlestage(a, b, c, d,
# (3.3) Calculate a compute_warp.M x compute_warp.N tile of D, using a compute_warp.M x compute_warp.N x compute_warp.K operation
@unroll for warp_tile = parallellise(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
# (3.3.1) Load a compute_warp.M x compute_warp.K tile of A from shared memory into registers
a_frags = MArray{Tuple{num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
a_frags = LocalArray{Tuple{num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)

@unroll for i = 1 : num_fragments_m
a_tile = translate_offset(warp_tile.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
@inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile)
@inbounds a_frags = setindex(a_frags, transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile), i)
end

# (3.3.2) Load a compute_warp.K x compute_warp.N tile of B from shared memory into registers
b_frags = MArray{Tuple{num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)
b_frags = LocalArray{Tuple{num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)

@unroll for j = 1 : num_fragments_n
b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
@inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile)
@inbounds b_frags = setindex(b_frags, transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile), j)
end

# (3.3.3) Compute a compute_warp.M x compute_warp.N x compute_warp.K matrix product within one warp
@unroll for i = 1 : num_fragments_m
@unroll for j = 1 : num_fragments_n
@inbounds c_frags[i, j] = Operator.mma(conf.operator, a_frags[i], b_frags[j], c_frags[i, j])
@inbounds c_frags = setindex(c_frags, Operator.mma(conf.operator, a_frags[i], b_frags[j], c_frags[i, j]), i, j)
end
end
end
Expand All @@ -118,7 +119,7 @@ function matmul_singlestage(a, b, c, d,
@unroll for i = 1 : num_fragments_m
@unroll for j = 1 : num_fragments_n
tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
Operator.store_d(conf.operator, conf.shared_d_layout, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
@inbounds Operator.store_d(conf.operator, conf.shared_d_layout, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
end
end

Expand Down Expand Up @@ -165,12 +166,12 @@ function matmul_pipelined(a, b, c, d,
# (2) Load a compute_warp.M x compute_warp.N tile of C from shared memory into registers
warp_tile = subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)

c_frags = MArray{Tuple{num_fragments_m, num_fragments_n}, Operator.fragtype_accum(conf.operator, conf.shared_c_layout)}(undef)
c_frags = LocalArray{Tuple{num_fragments_m, num_fragments_n}, Operator.fragtype_accum(conf.operator, conf.shared_c_layout)}(undef)

@unroll for i = 1 : num_fragments_m
@unroll for j = 1 : num_fragments_n
tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(conf.operator, conf.shared_c_layout, shmem_c, tile), tile)
@inbounds c_frags = setindex(c_frags, transf_sh2rf_c(Operator.load_c(conf.operator, conf.shared_c_layout, shmem_c, tile), tile), i, j)
end
end

Expand All @@ -187,24 +188,24 @@ function matmul_pipelined(a, b, c, d,
b_frag_i = (block_tile.size.K * block_tile.size.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block)
b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32)

a_fragment = MArray{Tuple{a_frag_i, a_frag_j}, Layout.fragtype(conf.global_a_layout, conf.mem_a_thread)}(undef)
b_fragment = MArray{Tuple{b_frag_i, b_frag_j}, Layout.fragtype(conf.global_b_layout, conf.mem_b_thread)}(undef)
a_fragment = LocalArray{Tuple{a_frag_i, a_frag_j}, Layout.fragtype(conf.global_a_layout, conf.mem_a_thread)}(undef)
b_fragment = LocalArray{Tuple{b_frag_i, b_frag_j}, Layout.fragtype(conf.global_b_layout, conf.mem_b_thread)}(undef)

a_frags = MArray{Tuple{2, num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
b_frags = MArray{Tuple{2, num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)
a_frags = LocalArray{Tuple{2, num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
b_frags = LocalArray{Tuple{2, num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)

warp_tile_mn = subdivide(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)

# ld.global(0 : block_shape.K)
@unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@inbounds a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = 0)))
@inbounds a_fragment = setindex(a_fragment, Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = 0))), i, j)
end
end

@unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@inbounds b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = 0, N = block_j)))
@inbounds b_fragment = setindex(b_fragment, Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = 0, N = block_j))), i, j)
end
end

Expand All @@ -230,24 +231,24 @@ function matmul_pipelined(a, b, c, d,

@unroll for i = 1 : num_fragments_m
a_tile = translate_offset(warp_tile.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
@inbounds a_frags[1, i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile)
@inbounds a_frags = setindex(a_frags, transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile), 1, i)
end

@unroll for j = 1 : num_fragments_n
b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
@inbounds b_frags[1, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile)
@inbounds b_frags = setindex(b_frags, transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile), 1, j)
end

# ld.global(block_shape.K : 2 * block_shape.K)
@unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@inbounds a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_tile.size.K)))
@inbounds a_fragment = setindex(a_fragment, Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_tile.size.K))), i, j)
end
end

@unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@inbounds b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_tile.size.K, N = block_j)))
@inbounds b_fragment = setindex(b_fragment, Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_tile.size.K, N = block_j))), i, j)
end
end

Expand Down Expand Up @@ -281,13 +282,13 @@ function matmul_pipelined(a, b, c, d,
# ld.global(block_k + 2 * block_shape.K : block_k + 3 * block_shape.K)
@unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@inbounds a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k + 2 * block_tile.size.K)))
@inbounds a_fragment = setindex(a_fragment, Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k + 2 * block_tile.size.K))), i, j)
end
end

@unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@inbounds b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k + 2 * block_tile.size.K, N = block_j)))
@inbounds b_fragment = setindex(b_fragment, Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k + 2 * block_tile.size.K, N = block_j))), i, j)
end
end
end
Expand All @@ -298,18 +299,18 @@ function matmul_pipelined(a, b, c, d,

@unroll for i = 1 : num_fragments_m
a_tile = translate_offset(warp_tile.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
@inbounds a_frags[nxt_stage, i] = transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile)
@inbounds a_frags = setindex(a_frags, transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile), nxt_stage, i)
end

@unroll for j = 1 : num_fragments_n
b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
@inbounds b_frags[nxt_stage, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile)
@inbounds b_frags = setindex(b_frags, transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile), nxt_stage, j)
end

# mma(cur_stage)
@unroll for i = 1 : num_fragments_m
@unroll for j = 1 : num_fragments_n
@inbounds c_frags[i, j] = Operator.mma(conf.operator, a_frags[cur_stage, i], b_frags[cur_stage, j], c_frags[i, j])
@inbounds c_frags = setindex(c_frags, Operator.mma(conf.operator, a_frags[cur_stage, i], b_frags[cur_stage, j], c_frags[i, j]), i, j)
end
end
end
Expand All @@ -325,7 +326,7 @@ function matmul_pipelined(a, b, c, d,
@unroll for i = 1 : num_fragments_m
@unroll for j = 1 : num_fragments_n
tile = translate_offset(warp_tile, (M = (i-1)*conf.compute_op_shape.M, N = (j-1)*conf.compute_op_shape.N))
Operator.store_d(conf.operator, conf.shared_d_layout, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
@inbounds Operator.store_d(conf.operator, conf.shared_d_layout, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
end
end

Expand Down
1 change: 0 additions & 1 deletion src/layout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ module Layout
using CUDA
using KernelAbstractions.Extras: @unroll
using GemmKernels.Tiling
using StaticArrays

# ---------------------
# Customise computation
Expand Down

0 comments on commit 23ad275

Please sign in to comment.