Add layouts for accessing unaligned or non tile-sized global. #130

Merged: 5 commits, Jul 5, 2023

maleadt commented Jul 4, 2023

Step towards making the BLAS API more generally usable.

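Performance is worse, of course: in the benchmarks below, the median time goes from 931 μs at 4096 × 4096 to 1.40 ms at 4095 × 4095, about 1.5× slower. We could try to do better, e.g. by using vectorized loads for most of the input, but that gets complicated quickly, so let's do the simple thing first.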
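To illustrate what "the simple thing" means here, the following is a minimal CPU-side sketch of an element-wise, bounds-checked tile load. It is not the code from this PR: `load_tile_checked` is a hypothetical name, and padding out-of-bounds elements with zeros is an assumption made for the example.

```julia
# Hypothetical helper, not part of GemmKernels.jl: copy a tile out of `A`
# one element at a time, checking each index against the array bounds.
# Out-of-bounds elements are left at zero (an assumption for this sketch).
function load_tile_checked(A::AbstractMatrix{T}, row0::Int, col0::Int,
                           tile_rows::Int, tile_cols::Int) where {T}
    tile = zeros(T, tile_rows, tile_cols)
    for j in 1:tile_cols, i in 1:tile_rows
        r = row0 + i - 1
        c = col0 + j - 1
        # One bounds check per element: this is the per-element guard that
        # a vectorized load of a whole row or column would avoid.
        if checkbounds(Bool, A, r, c)
            tile[i, j] = A[r, c]
        end
    end
    return tile
end

# A 5×3 matrix is not a multiple of a 2×2 tile, so the last tile is partial.
A = reshape(collect(1.0f0:15.0f0), 5, 3)
last_tile = load_tile_checked(A, 5, 3, 2, 2)
```

Only the top-left element of `last_tile` comes from `A`; the rest is padding. A faster variant would presumably take the checked path only for tiles that cross the array edge.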

4096 x 4096

julia> @benchmark CUDA.@sync GemmKernels.BLAS.gemmEx!('N', 'N', true, A, B, false, C)
BenchmarkTools.Trial: 5342 samples with 1 evaluation.
 Range (min … max):  813.384 μs …   7.196 ms  ┊ GC (min … max): 0.00% … 78.88%
 Time  (median):     931.268 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   932.315 μs ± 121.528 μs  ┊ GC (mean ± σ):  0.23% ±  1.54%

                                                    ▁▄▆▆█▅▂
  ▂▂▂▂▂▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▂▁▁▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▃▅▆███████▇▄▃▃ ▃
  813 μs           Histogram: frequency by time          948 μs <

4095 x 4095

julia> @benchmark CUDA.@sync GemmKernels.BLAS.gemmEx!('N', 'N', true, A, B, false, C)
BenchmarkTools.Trial: 3555 samples with 1 evaluation.
 Range (min … max):  1.244 ms …   8.118 ms  ┊ GC (min … max): 0.00% … 74.44%
 Time  (median):     1.401 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.403 ms ± 114.594 μs  ┊ GC (mean ± σ):  0.12% ±  1.25%

                                                  ▃█▅▅▃
  ▂▂▁▁▂▂▁▁▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▂▁▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▃▄▆█████▇▆▄▃▃▂ ▃
  1.24 ms         Histogram: frequency by time        1.43 ms <

 Memory estimate: 17.58 KiB, allocs estimate: 514.

The generated PTX looks pretty bad: LLVM does not predicate the element-wise loads/stores, and instead emits a branch around each one:

	setp.gt.u64 	%p387, %rd594, %rd140;
	@%p387 bra 	LBB0_678;
	{ .reg .b16 	%tmp_lo;
	  mov.b32 	{%tmp_lo, %h2044}, %hh125; }
	st.global.b16 	[%rd191+-14], %h2044;
LBB0_678:
	mov.u32 	%hh126, %r244;
	setp.gt.u64 	%p388, %rd595, %rd140;
	@%p388 bra 	LBB0_680;
	{ .reg .b16 	%tmp_hi;
	  mov.b32 	{%h2045, %tmp_hi}, %hh126; }
	st.global.b16 	[%rd191+-12], %h2045;
LBB0_680:
	setp.gt.u64 	%p389, %rd596, %rd140;
	@%p389 bra 	LBB0_682;
	{ .reg .b16 	%tmp_lo;
	  mov.b32 	{%tmp_lo, %h2046}, %hh126; }
	st.global.b16 	[%rd191+-10], %h2046;

I guess LLVM lacks a good if-conversion pass? In any case, because the pattern is so clear, ptxas does detect it and generates predicated SASS:

        ISETP.GT.U32.AND P0, PT, R2, UR6, PT ;
        ISETP.GT.U32.AND.EX P4, PT, R20, UR7, PT, P4 ;
        ISETP.GT.U32.AND.EX P0, PT, R3, UR7, PT, P0 ;
   @!P4 STG.E.U16 [R6.64+-0x4], R11 ;
   @!P0 STG.E.U16 [R6.64+-0x2], R9 ;
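For readers unfamiliar with if-conversion: it replaces a branch that guards an operation with straight-line code that selects or predicates the result. The sketch below shows the two shapes for a guarded load on the CPU. The function names are hypothetical, and it makes no claim about what LLVM or ptxas emit for either form.

```julia
# Branching form: the load only happens inside the guarded block.
function load_branchy(A::AbstractVector{T}, i::Int) where {T}
    if i <= length(A)
        return A[i]
    else
        return zero(T)
    end
end

# Select form: always load from a clamped, valid index, then pick the result.
# Assumes `A` is non-empty and `i >= 1`.
function load_select(A::AbstractVector{T}, i::Int) where {T}
    j = min(i, length(A))  # clamp so the load itself is always in bounds
    return ifelse(i <= length(A), A[j], zero(T))
end

A = Float32[1, 2, 3, 4]
same = all(load_branchy(A, i) == load_select(A, i) for i in 1:6)
```

Both functions return the same values; the second has no data-dependent control flow, which is the property the predicated `STG` instructions above have.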

Fixes #52

Use it to extend the BLAS API to all input sizes.

maleadt commented Jul 4, 2023

Benchmark results for commit 3fafafc (comparing to 033da06):

| ID | mean₁ | mean₂ | Δmin |
|---|---|---|---|
| ["BLAS", "FPU", "Float32'*Float32'=Float32 (256×256×256, alpha)"] | 114.220 μs ± 1.796 μs | 80.826 μs ± 153.441 ns | 27.7% ✅ |
| ["BLAS", "FPU", "Float32'*Float32=Float32 (256×256×256, alpha)"] | 113.011 μs ± 223.516 ns | 76.379 μs ± 148.552 ns | 32.4% ✅ |
| ["BLAS", "FPU", "Float32*Float32'=Float32 (256×256×256, alpha)"] | 112.001 μs ± 247.370 ns | 77.959 μs ± 161.475 ns | 30.5% ✅ |
| ["BLAS", "FPU", "Float32*Float32=Float32 (256×256×256, alpha)"] | 116.009 μs ± 2.242 μs | 79.385 μs ± 503.833 ns | 29.5% ✅ |
| ["BLAS", "FPU", "Float32*Float32=Float32 (256×256×256, alpha, beta)"] | 118.840 μs ± 2.203 μs | 83.108 μs ± 504.933 ns | 28.2% ✅ |
| ["BLAS", "FPU", "Float32*Float32=Float32 (256×256×256, beta)"] | 125.610 μs ± 4.159 μs | 90.338 μs ± 531.585 ns | 25.9% ✅ |
| ["BLAS", "WMMA", "Float16'*Float16'=Float16 (256×256×256, alpha)"] | 91.687 μs ± 240.173 ns | 55.345 μs ± 3.727 μs | 47.6% ✅ |
| ["BLAS", "WMMA", "Float16'*Float16'=Float16 (4096×4096×4096, alpha)"] | 5.862 ms ± 4.084 μs | 5.482 ms ± 669.187 μs | 29.8% ✅ |
| ["BLAS", "WMMA", "Float16'*Float16'=Float32 (256×256×256, alpha)"] | 77.963 μs ± 3.430 μs | 44.076 μs ± 2.355 μs | 45.8% ✅ |
| ["BLAS", "WMMA", "Float16'*Float16=Float16 (256×256×256, alpha)"] | 93.272 μs ± 2.250 μs | 51.176 μs ± 2.044 μs | 47.6% ✅ |
| ["BLAS", "WMMA", "Float16'*Float16=Float32 (256×256×256, alpha)"] | 79.090 μs ± 4.021 μs | 43.802 μs ± 2.543 μs | 45.9% ✅ |
| ["BLAS", "WMMA", "Float16*Float16'=Float16 (256×256×256, alpha)"] | 93.384 μs ± 2.279 μs | 53.597 μs ± 1.179 μs | 42.7% ✅ |
| ["BLAS", "WMMA", "Float16*Float16'=Float16 (4096×4096×4096, alpha)"] | 5.410 ms ± 2.762 μs | 5.078 ms ± 603.052 μs | 28.1% ✅ |
| ["BLAS", "WMMA", "Float16*Float16'=Float32 (256×256×256, alpha)"] | 80.596 μs ± 2.224 μs | 45.164 μs ± 137.152 ns | 42.1% ✅ |
| ["BLAS", "WMMA", "Float16*Float16'=Float32 (4096×4096×4096, alpha)"] | 6.113 ms ± 2.612 μs | 5.635 ms ± 735.072 μs | 28.8% ✅ |
| ["BLAS", "WMMA", "Float16*Float16=Float16 (256×256×256, alpha)"] | 89.811 μs ± 5.068 μs | 47.416 μs ± 462.692 ns | 41.7% ✅ |
| ["BLAS", "WMMA", "Float16*Float16=Float16 (256×256×256, alpha, beta)"] | 95.352 μs ± 1.894 μs | 49.096 μs ± 448.378 ns | 47.1% ✅ |
| ["BLAS", "WMMA", "Float16*Float16=Float16 (256×256×256, beta)"] | 90.329 μs ± 1.873 μs | 46.189 μs ± 513.800 ns | 47.5% ✅ |
| ["BLAS", "WMMA", "Float16*Float16=Float32 (256×256×256, alpha)"] | 79.431 μs ± 4.231 μs | 39.048 μs ± 470.520 ns | 46.2% ✅ |
| ["BLAS", "WMMA", "Float16*Float16=Float32 (256×256×256, alpha, beta)"] | 81.168 μs ± 4.389 μs | 41.197 μs ± 484.514 ns | 44.4% ✅ |
| ["BLAS", "WMMA", "Float16*Float16=Float32 (256×256×256, beta)"] | 83.236 μs ± 1.755 μs | 41.303 μs ± 484.387 ns | 49.2% ✅ |


codecov bot commented Jul 4, 2023

Codecov Report

Patch coverage: 53.15% and project coverage change: +0.52 🎉

Comparison is base (033da06) 30.67% compared to head (3fafafc) 31.20%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #130      +/-   ##
==========================================
+ Coverage   30.67%   31.20%   +0.52%     
==========================================
  Files          11       11              
  Lines         802      875      +73     
==========================================
+ Hits          246      273      +27     
- Misses        556      602      +46     
| Impacted Files | Coverage Δ |
|---|---|
| src/config.jl | 82.10% <ø> (ø) |
| src/layout.jl | 16.31% <10.90%> (-3.97%) ⬇️ |
| src/operator.jl | 11.58% <25.00%> (-0.15%) ⬇️ |
| src/blas.jl | 93.54% <100.00%> (+6.70%) ⬆️ |



maleadt commented Jul 4, 2023

The slowdown is entirely spent on the host. Let's add a simple cache.
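As a rough illustration of what a simple host-side cache could look like, here is a minimal sketch that memoizes a derived configuration by problem size and element type. Every name in it (`CONFIG_CACHE`, `build_config`, `cached_config`, the 16-element alignment check) is hypothetical and chosen for the example; it is not the cache added in this PR.

```julia
# Hypothetical cache, keyed on everything that affects the derived configuration.
const CONFIG_CACHE = Dict{Any,Any}()

# Counts how often the expensive path runs (only here to make the effect visible).
const BUILD_COUNT = Ref(0)

# Stand-in for an expensive host-side computation. The 16-element alignment
# check is an arbitrary assumption for this sketch.
function build_config(m, n, k, T)
    BUILD_COUNT[] += 1
    aligned = m % 16 == 0 && n % 16 == 0 && k % 16 == 0
    return (; m, n, k, T, aligned)
end

# Look the configuration up, computing and storing it only on the first request.
function cached_config(m, n, k, T)
    get!(CONFIG_CACHE, (m, n, k, T)) do
        build_config(m, n, k, T)
    end
end

c1 = cached_config(4095, 4095, 4095, Float16)
c2 = cached_config(4095, 4095, 4095, Float16)  # served from the cache
c3 = cached_config(4096, 4096, 4096, Float16)  # different key, built once
```

After these three calls `build_config` has run twice. A plain `Dict` like this is not thread-safe, so concurrent callers would need a lock around the lookup.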

maleadt marked this pull request as draft July 4, 2023 14:29
maleadt marked this pull request as ready for review July 4, 2023 20:18
maleadt merged commit 92d26eb into master Jul 5, 2023
maleadt deleted the tb/unsafe branch July 5, 2023 07:16

Successfully merging this pull request may close these issues.

Errors on small array inputs