Skip to content

Commit

Permalink
fix cuda device
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 13, 2024
1 parent 13ddabd commit a2c92ba
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
8 changes: 4 additions & 4 deletions test/ext_amdgpu/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end

@testset "Convolution" begin
for conv_type in (Conv, ConvTranspose), nd in 1:3
m = conv_type(tuple(fill(2, nd)...), 3 => 4) |> f32
m = conv_type(tuple(fill(2, nd)...), 3 => 4)
x = rand(Float32, fill(10, nd)..., 3, 5)

md, xd = Flux.gpu.((m, x))
Expand All @@ -53,10 +53,10 @@ end
x = rand(Float32, fill(10, nd)..., 3, 5) |> gpu

pad = ntuple(i -> i, nd)
m = conv_type(kernel, 3 => 4, pad=pad) |> f32 |> gpu
m = conv_type(kernel, 3 => 4, pad=pad) |> gpu

expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> f32 |> gpu
m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> gpu

@test size(m(x)) == size(m_expanded(x))
end
Expand Down Expand Up @@ -92,7 +92,7 @@ end
end

@testset "Cross-correlation" begin
m = CrossCor((2, 2), 3 => 4) |> f32
m = CrossCor((2, 2), 3 => 4)
x = rand(Float32, 5, 5, 3, 2)
test_gradients(m, x, test_gpu=true)
end
Expand Down
6 changes: 6 additions & 0 deletions test/ext_cuda/get_devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ for id in 0:(length(CUDA.devices()) - 1)
@test isequal(Flux.cpu(dense_model.weight), weight)
@test isequal(Flux.cpu(dense_model.bias), bias)
end

# gpu_device remembers the last device selected
# Therefore, we need to reset it to the current cuda device
@test gpu_device().device.handle == length(CUDA.devices()) - 1
gpu_device(CUDA.device().handle + 1)

# finally move to CPU, and see if things work
cdev = cpu_device()
dense_model = cdev(dense_model)
Expand Down
1 change: 0 additions & 1 deletion test/ext_cuda/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ CUDA.allowscalar(false)
@testset "get_devices" begin
include("get_devices.jl")
end

@testset "cuda" begin
include("cuda.jl")
end
Expand Down

0 comments on commit a2c92ba

Please sign in to comment.