You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
#83 (comment)
Blocked on ptxas bug, so will have to wait for CUDA 11.6 or so.
diff --git a/src/array.jl b/src/array.jl
index b4d3067..ab6b236 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -19,6 +19,15 @@ end
end
end
+@inline @generated function LocalArray{S}(data::NTuple{L,T}) where {S,T,L}
+    dims = (S.parameters...,)
+    N = length(dims)
+    @assert L == prod(dims)
+    quote
+        LocalArray{S, T, $N, L}(data)
+    end
+end
+
# array interface
Base.IndexStyle(::Type{<:LocalArray}) = IndexLinear()
Base.size(x::LocalArray{S}) where {S} = (S.parameters...,)
diff --git a/src/kernel.jl b/src/kernel.jl
index 5f12e51..2e1f426 100644
--- a/src/kernel.jl
+++ b/src/kernel.jl
@@ -84,12 +84,11 @@ function matmul_singlestage(a, b, c, d,
# (3.3) Calculate a compute_warp.M x compute_warp.N tile of D, using a compute_warp.M x compute_warp.N x compute_warp.K operation
@unroll for warp_tile = parallellise(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
# (3.3.1) Load a compute_warp.M x compute_warp.K tile of A from shared memory into registers
- a_frags = LocalArray{Tuple{num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
-
- @unroll for i = 1 : num_fragments_m
+ a_frag_data = ntuple(Val(num_fragments_m)) do i
a_tile = translate_offset(warp_tile.MK, (M = (i-1)*conf.compute_op_shape.M, K = 0))
- @inbounds a_frags = setindex(a_frags, transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile), i)
+ transf_sh2rf_a(Operator.load_a(conf.operator, conf.shared_a_layout, shmem_a, a_tile), a_tile)
end
+ a_frags = LocalArray{Tuple{num_fragments_m}}(a_frag_data)
# (3.3.2) Load a compute_warp.K x compute_warp.N tile of B from shared memory into registers
b_frags = LocalArray{Tuple{num_fragments_n}, Operator.fragtype_b(conf.operator, conf.shared_b_layout)}(undef)
The text was updated successfully, but these errors were encountered:
#83 (comment)
Blocked on ptxas bug, so will have to wait for CUDA 11.6 or so.
The text was updated successfully, but these errors were encountered: