diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index c29f1c73a4..f025dd65ed 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -40,8 +40,7 @@ end adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) = ROCArray(collect(x)) adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x)) -adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = - AMDGPU.rocRAND.default_rng() +adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error(""" Cannot map RNG of type $(typeof(x)) to AMDGPU. diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl index fde8103bbb..385f0e713f 100644 --- a/test/ext_amdgpu/basic.jl +++ b/test/ext_amdgpu/basic.jl @@ -86,6 +86,13 @@ end @test parent(Flux.gpu(g3)) isa ROCMatrix{Float32} end +@testset "cpu and gpu on RNGs" begin + crng = Random.default_rng() + grng = gpu(rng) + @test grng isa AMDGPU.rocRAND.RNG + @test cpu(grng) === crng +end + @testset "Flux.onecold gpu" begin y = Flux.onehotbatch(ones(3), 1:10) |> Flux.gpu l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']