diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index ff37ac278f..de94621009 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -176,7 +176,11 @@ end n_max_threads::Integer, ) (Nij, _, _, _, Nh) = DataLayouts.universal_size(us) - Nh_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nh) + Nh_thread = min( + Int(fld(n_max_threads, Nij * Nij)), + maximum_allowable_threads()[3], + Nh, + ) Nh_blocks = cld(Nh, Nh_thread) @assert prod((Nij, Nij, Nh_thread)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nh_thread))),$n_max_threads)" return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))