Skip to content

Commit

Permalink
fix: task switching in AMDGPU complex batched_matmul (#178)
Browse files — browse the repository at this point in the history
* ci(buildkite): add downstream testing for NeuralOperators

* perf: restore old batched_mul

* fix: disable threading for certain devices

* revert: "perf: restore old batched_mul"

This reverts commit a8c0f3b4615f96a8773577e16fac61ba310d8123.
  • Loading branch information
avik-pal authored Oct 25, 2024
1 parent ceb36a1 commit c63829b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
5 changes: 2 additions & 3 deletions lib/LuxLib/.buildkite/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ steps:
- src
- ext
env:
RETESTITEMS_NWORKERS: 2
BACKEND_GROUP: "AMDGPU"
agents:
queue: "juliagpu"
Expand Down Expand Up @@ -126,6 +125,7 @@ steps:
repo:
- "Boltz"
- "Lux"
- "NeuralOperators"

- group: ":telescope: Downstream AMD GPU"
steps:
Expand All @@ -143,15 +143,14 @@ steps:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
env:
RETESTITEMS_NWORKERS: 2
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main"
timeout_in_minutes: 240
matrix:
setup:
repo:
- "Boltz"
- "Lux"
- "NeuralOperators"

env:
JULIA_PKG_SERVER: ""
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.3.4"
version = "1.3.5"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
41 changes: 37 additions & 4 deletions lib/LuxLib/src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ end
function batched_matmul_loopvec_impl! end

"""
    fallback_batched_matmul(opmode, x, y)

Allocate the output for a batched matrix multiply of `x` and `y` and fill it via
`fallback_batched_matmul!`.

The output element type is `promote_type(eltype(x), eltype(y))`, and its batch
dimension is `max(size(x, 3), size(y, 3))` so that a singleton batch in either
input broadcasts against the other. `opmode` is forwarded untouched to the
in-place method.
"""
function fallback_batched_matmul(
        opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT}
    z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1),
        size(y, 2), max(size(x, 3), size(y, 3)))
    fallback_batched_matmul!(z, opmode, x, y)
    return z
end

function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3},
z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
# XXX: bring back once the enzyme segfault is fixed
# @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
Expand All @@ -90,6 +90,36 @@ function fallback_batched_matmul!(
throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul."))
end

if use_threaded_batched_matmul(get_device_type(x))
unsafe_fallback_threaded_batched_matmul!(z, x, y)
else
unsafe_fallback_serial_batched_matmul!(z, x, y)
end

return
end

"""
    unsafe_fallback_serial_batched_matmul!(z, x, y)

Serially multiply the matrix batches of `x` and `y` into `z` using `mul!` on
`batchview` slices.

"Unsafe" because batch dimensions are NOT re-validated here: the caller (see the
dimension check in `fallback_batched_matmul!`) must already have ensured that
`size(x, 3)` and `size(y, 3)` are equal or that one of them is `1`. A singleton
batch dimension is broadcast against every batch of the other array.

Returns `nothing` (consistent with the threaded variant); `z` is mutated in place.
"""
function unsafe_fallback_serial_batched_matmul!(
        z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
        y::AbstractArray{yT, 3}) where {zT, xT, yT}
    if size(x, 3) == size(y, 3)
        for L in axes(z, 3)
            mul!(batchview(z, L), batchview(x, L), batchview(y, L))
        end
    elseif size(x, 3) == 1
        # x has a single batch: reuse it against every batch of y
        for L in axes(z, 3)
            mul!(batchview(z, L), batchview(x, 1), batchview(y, L))
        end
    else # has to be size(y, 3) == 1
        for L in axes(z, 3)
            mul!(batchview(z, L), batchview(x, L), batchview(y, 1))
        end
    end
    # Explicit bare return: avoid leaking the last `mul!` result (a batch view)
    # and match the return convention of the threaded sibling.
    return
end

function unsafe_fallback_threaded_batched_matmul!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
old_threads = maybe_reduce_BLAS_threads(z)

if size(x, 3) == size(y, 3)
Expand All @@ -107,10 +137,13 @@ function fallback_batched_matmul!(
end

reset_BLAS_threads(old_threads)

return
end

# Trait controlling whether the batched matmul fallback may use task-based
# threading for a given device type. Threading is off by default for unknown
# devices; CUDA and CPU devices opt in explicitly. (Per the commit title, task
# switching is problematic on AMDGPU — presumably why it stays on the default.)
function use_threaded_batched_matmul(::Type)
    return false
end
function use_threaded_batched_matmul(::Type{CUDADevice})
    return true
end
function use_threaded_batched_matmul(::Type{CPUDevice})
    return true
end

function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {xT, yT}
∇batched_matmul = @closure Δ_ -> begin
Expand Down

0 comments on commit c63829b

Please sign in to comment.