Enzyme rules for cuBLASLt #148

Closed
wsmoses opened this issue Sep 2, 2024 · 9 comments · Fixed by #151 or #159

Comments

@wsmoses

wsmoses commented Sep 2, 2024

function fused_dense!(

I imagine it would be fairly similar to https://github.com/FluxML/NNlib.jl/blob/master/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl

@avik-pal let me know how I can help here.

x/ref JuliaGPU/CUDA.jl#2478

@avik-pal
Member

avik-pal commented Sep 2, 2024

Yeah, I think it makes sense to add rules for this.

However, in the long term, wouldn't it be better to add these to Enzyme proper, similar to the cuBLAS rules? cuBLASLt is an official NVIDIA extension to the GEMM API that supports fusing activations and bias into the GEMM kernel; see https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api. FWIW, from the debug logs, I think even cuBLAS calls into cuBLASLt when gemm is called.
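For reference, here is a rough plain-Julia sketch of what a fused dense operation computes and of the gradients a reverse-mode rule for it would need to produce. The names `dense_tanh` and `dense_tanh_pullback!` are hypothetical, and `tanh` stands in for `gelu` only because its derivative has a simple closed form. This is not how LuxLib, cuBLASLt, or Enzyme implement it; it only illustrates the math.

```julia
# Plain-Julia sketch; `dense_tanh` and `dense_tanh_pullback!` are hypothetical
# names for illustration, not functions from LuxLib, NNlib, or Enzyme.

# Forward pass: the bias add and activation that a fused kernel would fold
# into the GEMM are written out as separate broadcast steps here.
function dense_tanh(weight, x, b)
    z = weight * x .+ b      # GEMM followed by a column-wise bias add
    return tanh.(z)          # elementwise activation
end

# Reverse pass: given the output y and its cotangent dy, accumulate the
# gradients into the shadow buffers in place.
function dense_tanh_pullback!(dweight, dx, db, dy, y, weight, x)
    dz = dy .* (1 .- y .^ 2)       # tanh'(z) = 1 - tanh(z)^2
    dweight .+= dz * x'            # gradient w.r.t. weight
    dx .+= weight' * dz            # gradient w.r.t. input
    db .+= vec(sum(dz; dims=2))    # bias is broadcast over columns, so sum them
    return nothing
end

weight = rand(Float32, 2, 2)
x = rand(Float32, 2, 2)
b = rand(Float32, 2)
y = dense_tanh(weight, x, b)

dy = ones(Float32, 2, 2)
dweight = zeros(Float32, 2, 2)
dx = zeros(Float32, 2, 2)
db = zeros(Float32, 2)
dense_tanh_pullback!(dweight, dx, db, dy, y, weight, x)
```

The shadow buffers are accumulated into rather than overwritten, which matches how `Duplicated` arguments are typically treated in reverse mode.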

@avik-pal avik-pal changed the title from "Custom Rule for fused_dense!" to "Enzyme rules for cuBLASLt" on Sep 2, 2024
@wsmoses
Author

wsmoses commented Sep 2, 2024

Eventually perhaps, but it may still make sense to have it here, especially for the CPU and other backends.

@avik-pal
Member

avik-pal commented Sep 3, 2024

Minimal Reproducer

using CUDA, LuxLib, Enzyme, NNlib

function fused_dense!(y, act, weight, x, b)
    op = LuxLib.internal_operation_mode((y, weight, x, b))
    LuxLib.Impl.fused_dense!(y, op, act, weight, x, b)
end

# CPU case
y = zeros(Float32, 2, 2)
weight = rand(Float32, 2, 2)
x = rand(Float32, 2, 2)
b = rand(Float32, 2)

fused_dense!(y, gelu, weight, x, b)

dy = rand(Float32, 2, 2)
dweight = zeros(Float32, 2, 2)
dx = zeros(Float32, 2, 2)
db = zeros(Float32, 2)

Enzyme.autodiff(
    Reverse, fused_dense!, Duplicated(y, dy), Const(gelu), Duplicated(weight, dweight),
    Duplicated(x, dx), Duplicated(b, db)) # Works

# GPU case
y = zeros(Float32, 2, 2) |> cu
weight = rand(Float32, 2, 2) |> cu
x = rand(Float32, 2, 2) |> cu
b = rand(Float32, 2) |> cu

fused_dense!(y, gelu, weight, x, b)

dy = rand(Float32, 2, 2) |> cu
dweight = zeros(Float32, 2, 2) |> cu
dx = zeros(Float32, 2, 2) |> cu
db = zeros(Float32, 2) |> cu

Enzyme.autodiff(
    Reverse, fused_dense!, Duplicated(y, dy), Const(gelu), Duplicated(weight, dweight),
    Duplicated(x, dx), Duplicated(b, db)) # Fails

@avik-pal
Member

avik-pal commented Sep 9, 2024

Missed a couple of dispatches here: matmuladd! (called from fused_dense, not fused_dense!) also calls into cuBLASLt.

using Lux, Random, LuxCUDA, Enzyme

gdev = gpu_device()

rng = Random.default_rng()

model = Chain(Dense(2 => 3, tanh), Dense(3 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> gdev
x = rand(rng, Float32, 2, 10) |> gdev

y = first(model(x, ps, st))

function loss_function(y, model, ps, st, x)
    y .= first(model(x, ps, st))
    return
end

begin
    y = zeros(Float32, 2, 10) |> gdev
    dy = ones(Float32, 2, 10) |> gdev
    dx = zeros(Float32, 2, 10) |> gdev
    dps = Enzyme.make_zero(ps)

    Enzyme.autodiff(Reverse, loss_function, Const, Duplicated(y, dy),
        Const(model), Duplicated(ps, dps), Const(st), Duplicated(x, dx))

    @show dx
    @show dps
end

@avik-pal
Member

@wsmoses there is a regression in 0.13, even with the latest CUDA release:

https://buildkite.com/julialang/luxlib-dot-jl/builds/1552#0192361b-c14e-4dd2-a7fa-22e41f9b46ef/340-903

 call fastcc void @julia_active_state_97656({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }* noalias nocapture nofree noundef nonnull writeonly sret({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }) align 8 dereferenceable(48) %3), !dbg !2688
define private fastcc void @julia_active_state_97656({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }* noalias nocapture nofree noundef nonnull writeonly sret({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }) align 8 dereferenceable(48) %0) unnamed_addr #649 !dbg !80049 {
top:
  %1 = call {}*** @julia.get_pgcstack()
  %2 = call {}*** @julia.get_pgcstack()
  %ptls_field6 = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field6 to i64***
  %ptls_load78 = load i64**, i64*** %3, align 8, !tbaa !2681
  %4 = getelementptr inbounds i64*, i64** %ptls_load78, i64 2
  %safepoint = load i64*, i64** %4, align 8, !tbaa !2685
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint), !dbg !80050
  fence syncscope("singlethread") seq_cst
  %5 = call fastcc nonnull dereferenceable(48) {} addrspace(10)* @julia_task_local_state__96760(), !dbg !80051
  %6 = addrspacecast {} addrspace(10)* %5 to [1 x i32] addrspace(11)*, !dbg !80052
  %memcpy_refined_src = getelementptr inbounds [1 x i32], [1 x i32] addrspace(11)* %6, i64 0, i64 0, !dbg !80052
  %7 = load i32, i32 addrspace(11)* %memcpy_refined_src, align 8, !dbg !80052, !tbaa !2728, !alias.scope !2732, !noalias !2733
  %8 = addrspacecast {} addrspace(10)* %5 to i8 addrspace(11)*, !dbg !80052
  %9 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 8, !dbg !80052
  %aggregate_load_box2.sroa.0.0..sroa_idx = bitcast i8 addrspace(11)* %9 to i64 addrspace(11)*, !dbg !80052
  %aggregate_load_box2.sroa.0.0.copyload = load i64, i64 addrspace(11)* %aggregate_load_box2.sroa.0.0..sroa_idx, align 1, !dbg !80052, !tbaa !2782, !alias.scope !2803, !noalias !80055
  %aggregate_load_box2.sroa.2.0..sroa_idx4 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 16, !dbg !80052
  %10 = bitcast i8 addrspace(11)* %aggregate_load_box2.sroa.2.0..sroa_idx4 to i64 addrspace(11)*, !dbg !80052
  %aggregate_load_box2.sroa.2.0.copyload = load i64, i64 addrspace(11)* %10, align 1, !dbg !80052, !tbaa !2782, !alias.scope !2803, !noalias !80055
  %11 = call fastcc nonnull {} addrspace(10)* @julia_stream_97660({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(48) %5), !dbg !80054
  %12 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 32, !dbg !80052
  %13 = bitcast i8 addrspace(11)* %12 to i32 addrspace(11)*, !dbg !80052
  %14 = load i32, i32 addrspace(11)* %13, align 8, !dbg !80052, !tbaa !2728, !alias.scope !2732, !noalias !2733
  %getfield_addr = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 40, !dbg !80052
  %15 = bitcast i8 addrspace(11)* %getfield_addr to {} addrspace(10)* addrspace(11)*, !dbg !80052
  %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %15 unordered, align 8, !dbg !80052, !tbaa !2728, !alias.scope !2732, !noalias !2733, !nonnull !0
  %unbox.fca.0.insert = insertvalue [1 x i32] poison, i32 %7, 0, !dbg !80058
  %16 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } zeroinitializer, [1 x i32] %unbox.fca.0.insert, 0, !dbg !80058
  %unbox3.fca.0.insert = insertvalue [2 x i64] poison, i64 %aggregate_load_box2.sroa.0.0.copyload, 0, !dbg !80058
  %unbox3.fca.1.insert = insertvalue [2 x i64] %unbox3.fca.0.insert, i64 %aggregate_load_box2.sroa.2.0.copyload, 1, !dbg !80058
  %17 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %16, [2 x i64] %unbox3.fca.1.insert, 1, !dbg !80058
  %18 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %17, {} addrspace(10)* %11, 2, !dbg !80058
  %19 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %18, i32 %14, 3, !dbg !80058
  %20 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %19, {} addrspace(10)* %getfield, 4, !dbg !80058
  store { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %20, { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }* %0, align 8, !dbg !80054, !noalias !80060
  ret void, !dbg !80054
}

@avik-pal avik-pal reopened this Sep 28, 2024
@wsmoses
Author

wsmoses commented Sep 28, 2024

I think this is fixed on main, which will be released shortly.

@wsmoses
Author

wsmoses commented Sep 28, 2024

Patch just dropped, closing.

@wsmoses wsmoses closed this as completed Sep 28, 2024
@avik-pal
Member

@wsmoses
Author

wsmoses commented Sep 28, 2024 via email
