Possible way to implement a LoopVectorization extension for conv2d & meanpool2d & activations #540

Open
wants to merge 37 commits into base: master

Conversation

jonas208

@jonas208 jonas208 commented Sep 26, 2023

Some tests have shown that straightforward for-loop implementations of conv using LoopVectorization (LV) can be faster than NNlib's standard im2col + gemm approach. There has already been some talk about a possible LV extension on Slack (and in one of the ML community meetings). Meanpooling can also be implemented with LV, but the small performance boost is probably negligible in large networks (pooling is not as performance-critical as conv). A minimal sketch of the direct-loop approach is shown at the end of this description. For conv I found that:

  • Acceleration is usually greatest when the inputs have a large spatial size and few channels.
  • Using stride > 1, dilation > 1, or groups > 1 can slow things down a bit (there are multiple implementations for specialized cases).
  • The current LV ∇conv_filter! isn't really faster than the original implementation in some situations, so I left it out for the moment. (The same applies to the backward pass of meanpool.)
  • Important: for inputs with very many channels (and a relatively small spatial size), the LV implementation is often slower!

Tests were run on a Ryzen 9 5900X (on Windows 10).

I also wrote some tests (not sure whether they fit properly into NNlib's test strategy).
I hope the code is somewhat usable and helpful.
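
A minimal, hedged sketch of the direct-loop approach described above (stride = 1, no padding, no dilation, no groups), using NNlib's WHCN layout. The function name and exact structure are illustrative only, not the code in this PR:

using LoopVectorization

function conv2d_direct!(output::Array{T,4}, input::Array{T,4}, weight::Array{T,4}) where {T<:Real}
    output_width, output_height, out_channels, batch_size = size(output)
    weight_width, weight_height, in_channels, _ = size(weight)
    Threads.@threads for index_batch in 1:batch_size
        # @turbo vectorizes the loops below; the innermost loops form the reduction
        @turbo for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
            value = zero(T)
            for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
                value += input[x_out + x_w - 1, y_out + y_w - 1, in_channel, index_batch] *
                         weight[x_w, y_w, in_channel, out_channel]
            end
            output[x_out, y_out, out_channel, index_batch] = value
        end
    end
    return output
end

# Illustrative usage: 3x3 kernel, 3 input channels, 4 output channels, batch of 2.
x = rand(Float32, 32, 32, 3, 2)
w = rand(Float32, 3, 3, 3, 4)
y = zeros(Float32, 30, 30, 4, 2)   # valid convolution: 32 - 3 + 1 = 30
conv2d_direct!(y, x, w)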

@ToucheSir added the benchmark label (Run automated benchmarks on CI) on Sep 27, 2023
@fluxml-benchmark-bot

Judge result

Benchmark Report for /home/runner/work/FluxMLBenchmarks.jl/FluxMLBenchmarks.jl/benchmark/script/..

Job Properties

  • Time of benchmarks:
    • Target: 27 Sep 2023 - 00:39
    • Baseline: 27 Sep 2023 - 00:38
  • Package commits:
    • Target: non gi
    • Baseline: non gi
  • Julia commits:
    • Target: bed2cd
    • Baseline: bed2cd
  • Julia command flags:
    • Target: None
    • Baseline: None
  • Environment variables:
    • Target: FLUXML_BENCHMARK_FLUX_MLP => true FLUXML_BENCHMARK_FLUX => true JULIA_NUM_THREADS => 1
    • Baseline: FLUXML_BENCHMARK_FLUX_MLP => true FLUXML_BENCHMARK_FLUX => true JULIA_NUM_THREADS => 1

Results

A ratio greater than 1.0 denotes a possible regression (marked with ❌), while a ratio less
than 1.0 denotes a possible improvement (marked with ✅). Only significant results - results
that indicate possible regressions or improvements - are shown below (thus, an empty table means that all
benchmark results remained invariant between builds).

| ID | time ratio | memory ratio |
|----|------------|--------------|
| ["flux", "mlp", "Float32"] | 9754.02 (5%) ❌ | 47703.65 (1%) ❌ |
| ["flux", "mlp", "Float64"] | 2.00 (5%) ❌ | 1.00 (1%) |

Benchmark Group List

Here's a list of all the benchmark groups executed by this job:

  • ["flux", "mlp"]

Julia versioninfo

Target

Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
      Ubuntu 22.04.3 LTS
  uname: Linux 6.2.0-1011-azure #11~22.04.1-Ubuntu SMP Wed Aug 23 19:26:19 UTC 2023 x86_64 x86_64
  CPU: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz: 
              speed         user         nice          sys         idle          irq
       #1  2294 MHz       2134 s          0 s        264 s       2108 s          0 s
       #2  2294 MHz       2492 s          0 s        291 s       1698 s          0 s
  Memory: 6.759757995605469 GB (5210.56640625 MB free)
  Uptime: 458.14 sec
  Load Avg:  1.17  1.03  0.55
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, broadwell)
  Threads: 1 on 2 virtual cores

Baseline

Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
      Ubuntu 22.04.3 LTS
  uname: Linux 6.2.0-1011-azure #11~22.04.1-Ubuntu SMP Wed Aug 23 19:26:19 UTC 2023 x86_64 x86_64
  CPU: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz: 
              speed         user         nice          sys         idle          irq
       #1  2294 MHz       1796 s          0 s        233 s       1819 s          0 s
       #2  2294 MHz       2197 s          0 s        259 s       1367 s          0 s
  Memory: 6.759757995605469 GB (5358.67578125 MB free)
  Uptime: 392.15 sec
  Load Avg:  1.26  1.02  0.51
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, broadwell)
  Threads: 1 on 2 virtual cores

Target result

Benchmark Report for /home/runner/work/FluxMLBenchmarks.jl/FluxMLBenchmarks.jl/benchmark/script/..

Job Properties

  • Time of benchmark: 27 Sep 2023 - 0:39
  • Package commit: non gi
  • Julia commit: bed2cd
  • Julia command flags: None
  • Environment variables: FLUXML_BENCHMARK_FLUX_MLP => true FLUXML_BENCHMARK_FLUX => true JULIA_NUM_THREADS => 1

Results

Below is a table of this job's results, obtained by running the benchmarks.
The values listed in the ID column have the structure [parent_group, child_group, ..., key], and can be used to
index into the BaseBenchmarks suite to retrieve the corresponding benchmarks.
The percentages accompanying time and memory values in the below table are noise tolerances. The "true"
time/memory value for a given benchmark is expected to fall within this percentage of the reported value.
An empty cell means that the value was zero.

| ID | time | GC time | memory | allocations |
|----|------|---------|--------|-------------|
| ["flux", "mlp", "Float16"] | 160.600 μs (5%) | | 23.30 KiB (1%) | 8 |
| ["flux", "mlp", "Float32"] | 1.529 s (5%) | 56.245 ms | 151.40 MiB (1%) | 2297238 |
| ["flux", "mlp", "Float64"] | 342.000 μs (5%) | | 23.30 KiB (1%) | 8 |

Benchmark Group List

Here's a list of all the benchmark groups executed by this job:

  • ["flux", "mlp"]

Julia versioninfo

Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
      Ubuntu 22.04.3 LTS
  uname: Linux 6.2.0-1011-azure #11~22.04.1-Ubuntu SMP Wed Aug 23 19:26:19 UTC 2023 x86_64 x86_64
  CPU: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz: 
              speed         user         nice          sys         idle          irq
       #1  2294 MHz       2134 s          0 s        264 s       2108 s          0 s
       #2  2294 MHz       2492 s          0 s        291 s       1698 s          0 s
  Memory: 6.759757995605469 GB (5210.56640625 MB free)
  Uptime: 458.14 sec
  Load Avg:  1.17  1.03  0.55
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, broadwell)
  Threads: 1 on 2 virtual cores

Baseline result

Benchmark Report for /home/runner/work/FluxMLBenchmarks.jl/FluxMLBenchmarks.jl/benchmark/script/..

Job Properties

  • Time of benchmark: 27 Sep 2023 - 0:38
  • Package commit: non gi
  • Julia commit: bed2cd
  • Julia command flags: None
  • Environment variables: FLUXML_BENCHMARK_FLUX_MLP => true FLUXML_BENCHMARK_FLUX => true JULIA_NUM_THREADS => 1

Results

Below is a table of this job's results, obtained by running the benchmarks.
The values listed in the ID column have the structure [parent_group, child_group, ..., key], and can be used to
index into the BaseBenchmarks suite to retrieve the corresponding benchmarks.
The percentages accompanying time and memory values in the below table are noise tolerances. The "true"
time/memory value for a given benchmark is expected to fall within this percentage of the reported value.
An empty cell means that the value was zero.

| ID | time | GC time | memory | allocations |
|----|------|---------|--------|-------------|
| ["flux", "mlp", "Float16"] | 157.599 μs (5%) | | 23.30 KiB (1%) | 8 |
| ["flux", "mlp", "Float32"] | 156.800 μs (5%) | | 3.25 KiB (1%) | 6 |
| ["flux", "mlp", "Float64"] | 170.700 μs (5%) | | 23.30 KiB (1%) | 8 |

Benchmark Group List

Here's a list of all the benchmark groups executed by this job:

  • ["flux", "mlp"]

Julia versioninfo

Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
      Ubuntu 22.04.3 LTS
  uname: Linux 6.2.0-1011-azure #11~22.04.1-Ubuntu SMP Wed Aug 23 19:26:19 UTC 2023 x86_64 x86_64
  CPU: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz: 
              speed         user         nice          sys         idle          irq
       #1  2294 MHz       1796 s          0 s        233 s       1819 s          0 s
       #2  2294 MHz       2197 s          0 s        259 s       1367 s          0 s
  Memory: 6.759757995605469 GB (5358.67578125 MB free)
  Uptime: 392.15 sec
  Load Avg:  1.26  1.02  0.51
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, broadwell)
  Threads: 1 on 2 virtual cores

Runtime information

| Runtime Info | |
|---|---|
| BLAS #threads | 1 |
| BLAS.vendor() | lbt |
| Sys.CPU_THREADS | 2 |

lscpu output:

Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             2
On-line CPU(s) list:                0,1
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 1
Core(s) per socket:                 2
Socket(s):                          1
Stepping:                           4
BogoMIPS:                           4190.18
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap clflushopt avx512cd avx512bw avx512vl xsaveopt xsavec xsaves md_clear
Hypervisor vendor:                  Microsoft
Virtualization type:                full
L1d cache:                          64 KiB (2 instances)
L1i cache:                          64 KiB (2 instances)
L2 cache:                           2 MiB (2 instances)
L3 cache:                           35.8 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0,1
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT Host state unknown
| Cpu Property | Value |
|---|---|
| Brand | Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz |
| Vendor | :Intel |
| Architecture | :Skylake |
| Model | Family: 0x06, Model: 0x55, Stepping: 0x04, Type: 0x00 |
| Cores | 2 physical cores, 2 logical cores (on executing CPU) |
| | No Hyperthreading hardware capability detected |
| Clock Frequencies | Not supported by CPU |
| Data Cache | Level 1:3 : (32, 1024, 36608) kbytes |
| | 64 byte cache line size |
| Address Size | 48 bits virtual, 46 bits physical |
| SIMD | 512 bit = 64 byte max. SIMD vector size |
| Time Stamp Counter | TSC is accessible via rdtsc |
| | TSC increased at every clock cycle (non-invariant TSC) |
| Perf. Monitoring | Performance Monitoring Counters (PMC) are not supported |
| Hypervisor | Yes, Microsoft |

@jonas208
Author

Was this script run as well? https://github.com/FluxML/FluxMLBenchmarks.jl/blob/main/benchmark/benchmark/nnlib/conv.jl
I'm not sure, but it seems that only [flux] [mlp] was run (or is conv part of that group as well?).

@ToucheSir
Member

It should have run; I've asked folks more familiar with the benchmarking setup why it didn't. For now, your local benchmarking results are fine.

@jonas208
Author

jonas208 commented Sep 27, 2023

I just updated /test/ext_loopvectorizationruntests.jl to run some benchmarks (on CI probably with just one thread). To make some of the test results more precise with actual numbers, here are some local results from the modified script:

without LoopVectorization
  153.013 ms (171 allocations: 651.41 MiB) # conv fwd stride=1, dil=1, pad=0, groups=1
  169.605 ms (173 allocations: 369.11 MiB) # bwd

  83.041 ms (170 allocations: 319.79 MiB) # conv fwd stride=2, dil=2, pad=0, groups=1
  89.114 ms (172 allocations: 199.92 MiB) # bwd

  63.385 ms (431 allocations: 325.76 MiB) # conv fwd stride=2, dil=2, pad=2, groups=3
  86.451 ms (431 allocations: 202.99 MiB) # bwd

  8.591 ms (13 allocations: 35.61 MiB) # pool fwd stride=1, dil=1, pad=0
  2.694 ms (13 allocations: 1.81 MiB) # pool fwd stride=5, dil=2, pad=2
with LoopVectorization
  42.939 ms (41 allocations: 319.05 MiB)
  18.518 ms (2164 allocations: 36.86 MiB)

  37.051 ms (40 allocations: 156.63 MiB)
  14.204 ms (2096 allocations: 36.87 MiB)

  65.591 ms (40 allocations: 196.93 MiB)
  85.284 ms (9990 allocations: 111.64 MiB)

  5.962 ms (4 allocations: 35.61 MiB)
  810.100 μs (4 allocations: 1.80 MiB)
Test Summary:         | Pass  Total     Time
Convolution & Pooling |    3      3  3m48.3s
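
For reference, timings like the ones above can be reproduced roughly as follows; the sizes and parameters here are illustrative, not the exact benchmark configuration:

using NNlib, BenchmarkTools
# also load LoopVectorization to activate the extension being benchmarked:
# using LoopVectorization

x = rand(Float32, 128, 128, 3, 32)   # WHCN input
w = rand(Float32, 5, 5, 3, 16)       # 5x5 kernel, 3 => 16 channels
@btime NNlib.conv($x, $w; stride=1, pad=0, dilation=1, groups=1)   # conv fwd
@btime NNlib.meanpool($x, NNlib.PoolDims($x, 2))                   # pool fwd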

@jonas208
Author

Does anyone have an idea why the results on some CI devices are correct but sometimes totally wrong? Some weird LV behavior (the version should be the same on all devices)?
@ToucheSir Maybe @chriselrod has an idea?

@chriselrod

chriselrod commented Sep 27, 2023

Note that @tturbo and @turbo may really benefit from knowing the sizes at compile time.
I had thought that this was part of the ConvDims type, but that appears not to be the case.
Probably not worth compiling a bunch of extra convolution types, but could be worth a try.
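
Illustrative only (not the PR's code): one way to give @turbo compile-time sizes is to lift them into the type domain, so each distinct configuration compiles its own specialized loop, which is exactly the compilation trade-off mentioned above.

using LoopVectorization

# W and H are type parameters, so the loop bounds are compile-time constants.
function kernel_sum(x::AbstractMatrix{T}, ::Val{W}, ::Val{H}) where {T,W,H}
    s = zero(T)
    @turbo for j in 1:H, i in 1:W
        s += x[i, j]
    end
    return s
end

kernel_sum(rand(Float32, 3, 3), Val(3), Val(3))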

Note also that SimpleChains.jl defines a lot of layer types, as well as a ChainRules.rrule to make SimpleChains objects themselves usable as layers in other networks. Defining the ChainRules methods makes them Zygote-compatible (i.e., differentiable by Zygote; they do not support using Zygote internally).
https://github.com/PumasAI/SimpleChains.jl/blob/87a5b400d798bd38dfde8ab93cce959a2b7d3ce3/test/runtests.jl#L116-L121

> Maybe @chriselrod has an idea?

Bugs?
"Totally wrong" is worrisome.
LV itself has a lot of tests for convolution examples.

It would be good if someone has a minimal example of results being totally wrong, plus the CPU model/arch where it went wrong.

Comment on lines 57 to 62
m = y_out + (y_stride - 1) * (y_out - 1)
n = x_out + (x_stride - 1) * (x_out - 1)
value = zero(T)
for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
y_in = m + (y_w - 1) * y_dilation
x_in = n + (x_w - 1) * x_dilation


I'm not sure if it makes a difference, but LV's handling of indices written within an index expression might be better.
It has separate code for building indices by parsing index expressions vs parsing the loop body.
I'd have to look at the details to see the degree to which they're different. In theory, they should do the same thing.

LV would definitely benefit from the special case of x_dilation being 1.
Might be worth branching over (also, special-case -1 or some other common small factors).
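
A hedged sketch of the branching suggested here; the helper names are hypothetical, not the PR's actual code:

# Dispatch to a specialized kernel when there is no dilation, so the index
# expression LoopVectorization has to handle stays as simple as possible.
if x_dilation == 1 && y_dilation == 1
    # fast path: indices reduce to plain offsets, e.g. input[n + x_w - 1, m + y_w - 1, ...]
    conv2d_unit_dilation!(output, input, weight)                        # hypothetical helper
else
    # general path keeps the dilation factors inside the index expression,
    # e.g. input[n + (x_w - 1) * x_dilation, m + (y_w - 1) * y_dilation, ...]
    conv2d_dilated!(output, input, weight, x_dilation, y_dilation)      # hypothetical helper
end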

@jonas208
Author

@chriselrod Thank you for the quick reply! I just added a minimal example in the runtests.jl script: https://github.com/jonas208/NNlib.jl/blob/lv-ext2/test/ext_loopvectorization/minimal_test.jl
To make sure the implementation isn't simply unequal to NNlib because of an error in the logic, the minimal example compares two identical implementations which differ only in that the first uses @turbo and the second doesn't.

1):

for index_batch in 1:current_batch_size # normally @threads is used here
    @turbo for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
        for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
            input_gradient[x_out + x_w - 1, y_out + y_w - 1, in_channel, index_batch] +=
                weight[x_w, y_w, in_channel, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
        end
    end
end

2):

for index_batch in 1:current_batch_size # normally @threads is used here
    for out_channel in 1:out_channels, y_out in 1:output_height, x_out in 1:output_width
        for in_channel in 1:in_channels, y_w in 1:weight_height, x_w in 1:weight_width
            input_gradient[x_out + x_w - 1, y_out + y_w - 1, in_channel, index_batch] +=
                weight[x_w, y_w, in_channel, out_channel] * output_gradient[x_out, y_out, out_channel, index_batch]
        end
    end
end

On my local machine, I never got unequal results (using Julia 1.9.3 on Windows 10, running on an AMD Ryzen 9 5900X). It seems that the results are also always correct on the CI devices running macOS (maybe the macOS runner always uses the same CPU).
On Windows or Ubuntu CI, results are sometimes wrong and sometimes not. I also added cpuinfo() from CpuId.jl to get some information about the CPUs used. After rerunning the tests, maybe some regularities in the failures can be found.

@jonas208
Author

Some logs:
Worked on

| Brand              | Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz               |
| Vendor             | :Intel                                                  |
| Architecture       | :Broadwell                                              |

Worked on

| Brand              | Intel(R) Xeon(R) CPU E5-1650 v2 @ 3.50GHz                  |
| Vendor             | :Intel                                                     |
| Architecture       | :IvyBridge                                                 |

Failed on

| Brand              | Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz           |
| Vendor             | :Intel                                                  |
| Architecture       | :UnknownIntel                                           |

Failed on

| Brand              | Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz          |
| Vendor             | :Intel                                                  |
| Architecture       | :Skylake                                                |

@jonas208
Author

Oh right, thank you; apparently that had slipped my mind. Thanks for linking the article!

@jonas208
Author

jonas208 commented Oct 8, 2023

After a lot of benchmarking and many smaller attempts to raise performance a bit further, I mainly found that:

  • for inference with a batch size of 1, LV conv clearly outperforms im2col (up to 10 times faster in MobileNetv3 or EfficientNetv2)
  • for inference with batch sizes above roughly 16, LV conv usually brings a modest performance advantage in bigger models (around 25%)
  • LV conv allocates (much) less memory (in particular when no padding is needed)

I also implemented LV conv without naive padding (using calc_padding_regions and bounds checks), but LV sometimes failed to process these loops (especially when stride > 1 or dilation > 1).

  • for LV ∇conv_data!, only the most specialized case remains; stride > 1 or dilation > 1 made the index calculation too complicated
  • in some cases LV ∇conv_data! is slower, in others it's faster, so I'm not quite sure whether we should keep it
  • im2col struggles with EfficientNet(v2) (maybe because of the high group counts and 1x1 convs)
  • setting the number of Julia threads to 1 can sometimes speed up im2col convolution when groups are used (threading over groups can result in an enormous number of allocations, e.g. >1M for im2col versus a few thousand with LV per inference run in EfficientNet)

For benchmarking with whole models, I wrote this script, which compares inference/backward times for a customizable set of models and batch sizes.
The output is a .csv table with the timings and the acceleration factor. To compare Flux with PyTorch, this Python script can be useful.
These were my results (benchmarked on a Ryzen 9 5900X):

It looks like only the Buildkite CI runs on this PR anymore.

This PR should now be ready for review :)
By the way, does anyone have further ideas to improve performance (e.g. memory blocking)?
@NNlibMaintainers, e.g. @ToucheSir @mcabbott @CarloLucibello

@ToucheSir
Member

If it's not too much trouble, could you try setting LinearAlgebra.BLAS.set_num_threads(1) and seeing whether that affects the im2col results? Presently we heavily oversubscribe the CPU, because BLAS and NNlib/Julia each manage their own threadpool.
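
For anyone reproducing this, the suggested setting is simply:

using LinearAlgebra
# Pin BLAS to one thread before the im2col benchmarks, so only Julia's own threadpool is active.
LinearAlgebra.BLAS.set_num_threads(1)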

@jonas208
Author

jonas208 commented Oct 9, 2023

@ToucheSir Sure, got these results.

@jonas208
Author

jonas208 commented Oct 9, 2023

I just had a quick look at the Buildkite results on the AMD GPU server. I saw massive (very unusual) differences between LV (75 ms) and im2col (15+ seconds), which never occurred otherwise, e.g. not even on the CUDA GPU server. Does anyone know what the reasons for this could be? https://buildkite.com/julialang/nnlib-dot-jl/builds/1116#018b13d7-e2fe-469f-8b5a-ac06f25b0bee/286-568
Starting from the 4th row under "with LV", every two rows there are measurements that actually used im2col (because from then on stride > 1, which means im2col is used again for the backward pass). The last 8 lines are timings for the activations.

@jonas208 changed the title from "Possible way to implement a LoopVectorization extension for conv2d & meanpool2d" to "Possible way to implement a LoopVectorization extension for conv2d & meanpool2d & activations" on Oct 10, 2023
@mcabbott
Member

A quick bug report from trying this out on this example: FluxML/Flux.jl#2350 fails:

(jl_fgGTBI) pkg> add https://github.com/jonas208/NNlib.jl#lv-ext2  ^C

julia> using NNlib, LoopVectorization

julia> function train_loop(model, optimizer, train_loader, test_loader; epochs=5)
           for epoch in 1:epochs
               iter = tqdm(train_loader)
               total = 0
               corrects = 0
               for (X, Y) in iter
                   grads = Flux.gradient(model) do m
                       predicted = m(X)
                       ignore() do 
                           b_size = size(X)[end]
                           corrects += sum(onecold(predicted, 0:9) .== onecold(Y, 0:9))
                           total += b_size
                       end
                       logitcrossentropy(predicted, Y)
                   end
                   optimizer, model = Flux.Optimise.update!(optimizer, model, grads[1])
                   set_postfix(iter, accuracy=corrects / total)
               end

               val_accuracy = accuracy(model, test_loader)
               @info "Epoch $epoch/5 | Accuracy : $val_accuracy"
           end
       end
train_loop (generic function with 1 method)

julia> train_loop(model, optimizer, TRAIN_LOADER, TEST_LOADER)
0.0%┣                                                                  ┫ 0/469 [00:00<00:00, -0s/it]
ERROR: conversion to pointer not defined for Array{Float32, 4}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] unsafe_convert(::Type{Ptr{Float32}}, a::Array{Float32, 4})
    @ Base ./pointer.jl:67
  [3] unsafe_convert
    @ ~/.julia/packages/OffsetArrays/0MOrf/src/OffsetArrays.jl:464 [inlined]
  [4] pointer(x::OffsetArrays.OffsetArray{Float32, 4, Array{Float32, 4}})
    @ Base ./abstractarray.jl:1253
  [5] memory_reference
    @ ~/.julia/packages/LayoutPointers/qGkBo/src/stridedpointers.jl:21 [inlined]
  [6] memory_reference
    @ ~/.julia/packages/LayoutPointers/qGkBo/src/stridedpointers.jl:18 [inlined]
  [7] stridedpointer_preserve
    @ ~/.julia/packages/LayoutPointers/qGkBo/src/stridedpointers.jl:100 [inlined]
  [8] ∇conv_data!(input_gradient::Array{…}, output_gradient::Array{…}, weight::Array{…}, cdims::DenseConvDims{…})
    @ NNlibLoopVectorizationExt ~/.julia/packages/NNlib/HV0kw/ext/NNlibLoopVectorizationExt/conv.jl:160
  [9] #∇conv_data#241
    @ ~/.julia/packages/NNlib/HV0kw/src/conv.jl:99 [inlined]
 [10] ∇conv_data
    @ ~/.julia/packages/NNlib/HV0kw/src/conv.jl:95 [inlined]
 [11] #380
    @ ~/.julia/packages/NNlib/HV0kw/src/conv.jl:350 [inlined]
 [12] unthunk
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
 [13] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
 [14] map
    @ ./tuple.jl:283 [inlined]
 [15] map
    @ ./tuple.jl:284 [inlined]
 [16] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
 [17] ZBack
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
 [18] Conv
    @ ~/.julia/packages/Flux/u7QSl/src/layers/conv.jl:202 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [20] macro expansion
    @ ~/.julia/packages/Flux/u7QSl/src/layers/basic.jl:53 [inlined]
 [21] _applychain
    @ ~/.julia/packages/Flux/u7QSl/src/layers/basic.jl:53 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [23] Chain
    @ ~/.julia/packages/Flux/u7QSl/src/layers/basic.jl:51 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [25] #25
    @ ./REPL[53]:8 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
 [28] gradient(f::Function, args::Chain{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
 [29] train_loop(model::Chain{…}, optimizer::@NamedTuple{}, train_loader::DataLoader{…}, test_loader::DataLoader{…}; epochs::Int64)
    @ Main ./REPL[53]:7
 [30] train_loop(model::Chain{…}, optimizer::@NamedTuple{}, train_loader::DataLoader{…}, test_loader::DataLoader{…})
    @ Main ./REPL[53]:1
 [31] top-level scope
    @ REPL[54]:1
Some type information was truncated. Use `show(err)` to see complete types.

(jl_fgGTBI) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_fgGTBI/Project.toml`
  [bdcacae8] LoopVectorization v0.12.165
  [872c559c] NNlib v0.9.7 `https://github.com/jonas208/NNlib.jl#lv-ext2`
  
julia> versioninfo()
Julia Version 1.11.0-DEV.773
Commit 855cd5662d* (2023-10-29 15:45 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.6.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
  Threads: 5 on 4 virtual cores
Environment:
  JULIA_NUM_THREADS = 4

@chriselrod

chriselrod commented Oct 29, 2023

That's because you're on Julia master, which removed Base.unsafe_convert(::Type{Ptr{T}}, x::Array{T}).

I guess I should've been calling pointer instead.

@mcabbott
Member

Ah, thanks, I didn't know. Piratically defining Base.unsafe_convert(::Type{Ptr{T}}, x::Array{T}) where T = pointer(x), it now runs, and @time train_loop(model, optimizer, TRAIN_LOADER, TEST_LOADER; epochs=1) reports 48s with this PR and 34s without. At least on this machine (M1 with Apple BLAS); I haven't investigated further.
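
Spelled out, the temporary (piratical) workaround mentioned above is the one-liner below; it only papers over the removed method on Julia master and is not something to ship:

Base.unsafe_convert(::Type{Ptr{T}}, x::Array{T}) where {T} = pointer(x)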

One minor comment is that Flux.nil numbers as used by @autosize seem to trigger warnings (on the first run):

julia> Flux.@autosize (28,28,1,1) Chain(MeanPool((2,2)), Flux.flatten, Dense(_ => 10))
┌ Warning: #= /Users/me/.julia/packages/NNlib/HV0kw/ext/NNlibLoopVectorizationExt/pooling.jl:57 =#:
│ `LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
│ Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.
└ @ NNlibLoopVectorizationExt ~/.julia/packages/LoopVectorization/xHfLl/src/condense_loopset.jl:1148
Chain(
  MeanPool((2, 2)),
  Flux.flatten,
  Dense(196 => 10),                     # 1_970 parameters
) 

@chriselrod

chriselrod commented Oct 30, 2023

That message is a @warn with maxlog=1, so it should generally only display once.
It means (as it says) that @turbo is not actually being used, even on repeat runs; it just isn't repeating the message.

@mcabbott
Member

Yes, I understand about the logging. But I think it's a sign that e.g. NNlib.meanpool!(output::Array{T,4}, ...) where {T<:Real} here is too wide a signature. Shouldn't it restrict T to the list of types that will work, and leave the existing ::AbstractArray{<:Any, 4} method as the fallback?
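
A sketch of the narrower signature meant here; the element-type union and the omitted body are assumptions, not the PR's actual code:

import NNlib

# Only element types expected to pass LoopVectorization.check_args dispatch to the
# LV path; everything else keeps hitting the generic ::AbstractArray{<:Any, 4} fallback.
function NNlib.meanpool!(output::Array{T,4}, input::Array{T,4},
                         pdims::NNlib.PoolDims) where {T<:Union{Float32,Float64}}
    # ... the extension's existing @turbo pooling loop would go here ...
end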

@jonas208
Author

@mcabbott (and maybe @chriselrod, if you're interested) Sorry for the late reply! I only just saw this again.

> Shouldn't it restrict T to the list of types which will work? Leave the existing ::AbstractArray{<:Any, 4} method as the fallback.

Yes, thanks, this would be better. I will probably change this in the next few days.

> reports 48s with this PR, 34s without

Given the large number of channels relative to the small 28x28 spatial size, it's somewhat expected that this example isn't optimal for the current LV implementation. On the other hand, no stride/dilation/groups are used, so the most optimized path is chosen, and in general it shouldn't be significantly slower than im2col in such cases (probably at least on x86 with AVX2).

On my Ryzen processor, the LV implementation took ~33s and the default implementation took ~38s with 24 Julia threads and 12 BLAS threads, or ~30s with 24 Julia threads and 1 BLAS thread. In my experience, limiting BLAS to a single thread can often speed up the default im2col approach, but sometimes it also makes things worse (especially the backward passes).
At the moment, however, I can't think of any other ways to further optimize the LV method. Are there any? If possible, it may be worth trying to include oneDNN as an accelerator package (similar to NNPACK, but better and more stable).

Apart from that, if you have time, it would be interesting to see some more benchmark results on the M1, for example the tests in the benchmark script which produce the .csv table comparing LV against im2col.

@maxfreu
Contributor

maxfreu commented Dec 14, 2023

Just a side note: the PyTorch timings are probably sub-optimal. Nowadays one would use torch.compile(), and torch.no_grad() for inference to actually suppress the computation of gradients. Note that torch.compile() then requires benchmarking the second run, as in Julia. So ultimately, the bar is even higher...
