Enzyme rules for cuBLASLt
#148
Comments
Yeah, I think it makes sense to add rules for this. However, in the long term, wouldn't it be better to add these to Enzyme proper, similar to the cuBLAS rules? cuBLASLt is an official NVIDIA extension to the GEMM API and supports fusing activations and bias into the GEMM kernel; see https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api. FWIW, from debug logs, I think even cuBLAS calls into cuBLASLt when |
Eventually perhaps, but it still may make sense to have it here — especially for CPU and other backends |
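For readers following along, the operation under discussion computes `y = act.(weight * x .+ b)`. The sketch below is an unfused CPU reference in plain Julia, shown only to make the semantics concrete. The name `dense_reference!` is hypothetical and this is not LuxLib's implementation; a fused cuBLASLt kernel would perform the bias add and activation inside the GEMM instead of as a second pass.

```julia
using LinearAlgebra

# Hypothetical reference (not LuxLib's code): the unfused form of
# y = act.(weight * x .+ b), written as a GEMM followed by a broadcast.
function dense_reference!(y, act, weight, x, b)
    mul!(y, weight, x)   # GEMM: y = weight * x
    y .= act.(y .+ b)    # bias add and activation as a separate pass
    return y
end

weight = Float32[1 2; 3 4]
x = Float32[1 0; 0 1]    # identity, so weight * x == weight
b = Float32[1, -1]       # one bias entry per output row
y = zeros(Float32, 2, 2)
dense_reference!(y, identity, weight, x, b)
```

The bias is broadcast along the batch (column) dimension, which is why a reverse rule has to sum the bias gradient over that dimension.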
Minimal Reproducer

```julia
using CUDA, LuxLib, Enzyme, NNlib

function fused_dense!(y, act, weight, x, b)
    op = LuxLib.internal_operation_mode((y, weight, x, b))
    LuxLib.Impl.fused_dense!(y, op, act, weight, x, b)
end

# CPU case
y = zeros(Float32, 2, 2)
weight = rand(Float32, 2, 2)
x = rand(Float32, 2, 2)
b = rand(Float32, 2)

fused_dense!(y, gelu, weight, x, b)

dy = rand(Float32, 2, 2)
dweight = zeros(Float32, 2, 2)
dx = zeros(Float32, 2, 2)
db = zeros(Float32, 2)

Enzyme.autodiff(
    Reverse, fused_dense!, Duplicated(y, dy), Const(gelu), Duplicated(weight, dweight),
    Duplicated(x, dx), Duplicated(b, db)) # Works

# GPU case
y = zeros(Float32, 2, 2) |> cu
weight = rand(Float32, 2, 2) |> cu
x = rand(Float32, 2, 2) |> cu
b = rand(Float32, 2) |> cu

fused_dense!(y, gelu, weight, x, b)

dy = rand(Float32, 2, 2) |> cu
dweight = zeros(Float32, 2, 2) |> cu
dx = zeros(Float32, 2, 2) |> cu
db = zeros(Float32, 2) |> cu

Enzyme.autodiff(
    Reverse, fused_dense!, Duplicated(y, dy), Const(gelu), Duplicated(weight, dweight),
    Duplicated(x, dx), Duplicated(b, db)) # Fails
```
|
Missed a couple of dispatches here.

```julia
using Lux, Random, LuxCUDA, Enzyme

gdev = gpu_device()
rng = Random.default_rng()

model = Chain(Dense(2 => 3, tanh), Dense(3 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> gdev

x = rand(rng, Float32, 2, 10) |> gdev
y = first(model(x, ps, st))

function loss_function(y, model, ps, st, x)
    y .= first(model(x, ps, st))
    return
end

begin
    y = zeros(Float32, 2, 10) |> gdev
    dy = ones(Float32, 2, 10) |> gdev
    dx = zeros(Float32, 2, 10) |> gdev
    dps = Enzyme.make_zero(ps)

    Enzyme.autodiff(Reverse, loss_function, Const, Duplicated(y, dy),
        Const(model), Duplicated(ps, dps), Const(st), Duplicated(x, dx))

    @show dx
    @show dps
end
```
|
@wsmoses regression in 0.13 even with the latest CUDA release

```
call fastcc void @julia_active_state_97656({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }* noalias nocapture nofree noundef nonnull writeonly sret({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }) align 8 dereferenceable(48) %3), !dbg !2688
define private fastcc void @julia_active_state_97656({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }* noalias nocapture nofree noundef nonnull writeonly sret({ [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }) align 8 dereferenceable(48) %0) unnamed_addr #649 !dbg !80049 {
top:
  %1 = call {}*** @julia.get_pgcstack()
  %2 = call {}*** @julia.get_pgcstack()
  %ptls_field6 = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field6 to i64***
  %ptls_load78 = load i64**, i64*** %3, align 8, !tbaa !2681
  %4 = getelementptr inbounds i64*, i64** %ptls_load78, i64 2
  %safepoint = load i64*, i64** %4, align 8, !tbaa !2685
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint), !dbg !80050
  fence syncscope("singlethread") seq_cst
  %5 = call fastcc nonnull dereferenceable(48) {} addrspace(10)* @julia_task_local_state__96760(), !dbg !80051
  %6 = addrspacecast {} addrspace(10)* %5 to [1 x i32] addrspace(11)*, !dbg !80052
  %memcpy_refined_src = getelementptr inbounds [1 x i32], [1 x i32] addrspace(11)* %6, i64 0, i64 0, !dbg !80052
  %7 = load i32, i32 addrspace(11)* %memcpy_refined_src, align 8, !dbg !80052, !tbaa !2728, !alias.scope !2732, !noalias !2733
  %8 = addrspacecast {} addrspace(10)* %5 to i8 addrspace(11)*, !dbg !80052
  %9 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 8, !dbg !80052
  %aggregate_load_box2.sroa.0.0..sroa_idx = bitcast i8 addrspace(11)* %9 to i64 addrspace(11)*, !dbg !80052
  %aggregate_load_box2.sroa.0.0.copyload = load i64, i64 addrspace(11)* %aggregate_load_box2.sroa.0.0..sroa_idx, align 1, !dbg !80052, !tbaa !2782, !alias.scope !2803, !noalias !80055
  %aggregate_load_box2.sroa.2.0..sroa_idx4 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 16, !dbg !80052
  %10 = bitcast i8 addrspace(11)* %aggregate_load_box2.sroa.2.0..sroa_idx4 to i64 addrspace(11)*, !dbg !80052
  %aggregate_load_box2.sroa.2.0.copyload = load i64, i64 addrspace(11)* %10, align 1, !dbg !80052, !tbaa !2782, !alias.scope !2803, !noalias !80055
  %11 = call fastcc nonnull {} addrspace(10)* @julia_stream_97660({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(48) %5), !dbg !80054
  %12 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 32, !dbg !80052
  %13 = bitcast i8 addrspace(11)* %12 to i32 addrspace(11)*, !dbg !80052
  %14 = load i32, i32 addrspace(11)* %13, align 8, !dbg !80052, !tbaa !2728, !alias.scope !2732, !noalias !2733
  %getfield_addr = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 40, !dbg !80052
  %15 = bitcast i8 addrspace(11)* %getfield_addr to {} addrspace(10)* addrspace(11)*, !dbg !80052
  %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %15 unordered, align 8, !dbg !80052, !tbaa !2728, !alias.scope !2732, !noalias !2733, !nonnull !0
  %unbox.fca.0.insert = insertvalue [1 x i32] poison, i32 %7, 0, !dbg !80058
  %16 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } zeroinitializer, [1 x i32] %unbox.fca.0.insert, 0, !dbg !80058
  %unbox3.fca.0.insert = insertvalue [2 x i64] poison, i64 %aggregate_load_box2.sroa.0.0.copyload, 0, !dbg !80058
  %unbox3.fca.1.insert = insertvalue [2 x i64] %unbox3.fca.0.insert, i64 %aggregate_load_box2.sroa.2.0.copyload, 1, !dbg !80058
  %17 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %16, [2 x i64] %unbox3.fca.1.insert, 1, !dbg !80058
  %18 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %17, {} addrspace(10)* %11, 2, !dbg !80058
  %19 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %18, i32 %14, 3, !dbg !80058
  %20 = insertvalue { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %19, {} addrspace(10)* %getfield, 4, !dbg !80058
  store { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* } %20, { [1 x i32], [2 x i64], {} addrspace(10)*, i32, {} addrspace(10)* }* %0, align 8, !dbg !80054, !noalias !80060
  ret void, !dbg !80054
}
```
|
I think this is fixed on main, which will be released shortly |
patch just dropped, closing |
Can you open an issue with an MWE?
…On Sat, Sep 28, 2024 at 11:05 AM Avik Pal ***@***.***> wrote:
Still persists
https://buildkite.com/julialang/luxlib-dot-jl/builds/1552#01923945-c4f0-49b3-9209-65a47a56b909/341-858
|
LuxLib.jl/src/impl/dense.jl
Line 35 in ef784ed
I imagine it would be fairly similar to https://github.com/FluxML/NNlib.jl/blob/master/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
@avik-pal let me know how I can help here.
x/ref JuliaGPU/CUDA.jl#2478
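As a rough illustration of what a hand-written reverse rule for the fused dense call would have to compute, the sketch below writes out the gradient accumulation for `y = tanh.(weight * x .+ b)` in plain Julia. It deliberately does not use the EnzymeRules API, and `dense_tanh_pullback!` is a hypothetical name, not something from LuxLib, NNlib, or Enzyme. The choice of `tanh` is an assumption made to keep the derivative simple; an actual rule would need to handle the activation generically.

```julia
using LinearAlgebra

# Hypothetical helper (not part of any package): accumulate the gradients of
# y = tanh.(weight * x .+ b) into dweight, dx, db given the output cotangent dy.
function dense_tanh_pullback!(dweight, dx, db, dy, weight, x, b)
    y = tanh.(weight * x .+ b)         # recompute the forward result
    dz = dy .* (1 .- y .^ 2)           # tanh'(z) = 1 - tanh(z)^2
    mul!(dweight, dz, x', true, true)  # dweight += dz * x'
    mul!(dx, weight', dz, true, true)  # dx += weight' * dz
    db .+= vec(sum(dz; dims=2))        # bias gradient sums over the batch dimension
    return nothing
end

weight = [0.1 -0.2; 0.3 0.4]
x = [0.5 1.0; -1.0 2.0]
b = [0.05, -0.1]
dy = [1.0 0.5; -0.5 2.0]

dweight = zeros(2, 2)
dx = zeros(2, 2)
db = zeros(2)
dense_tanh_pullback!(dweight, dx, db, dy, weight, x, b)
```

The gradients are accumulated with `+=` rather than overwritten, matching how the shadow arrays passed via `Duplicated` are used in the reproducers in this thread.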