From 367680bdee0368ed4c2e9f98eb914acf8117cc50 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 1 Jan 2025 09:58:38 -0500 Subject: [PATCH] fix: update default rng for reactant (#1152) * fix: update default rng for reactant * feat: handle RNGs in layers correctly --- Project.toml | 4 +-- lib/MLDataDevices/Project.toml | 4 +-- .../ext/MLDataDevicesReactantExt.jl | 11 +++--- lib/MLDataDevices/test/xla_tests.jl | 3 +- test/reactant/layer_tests.jl | 34 +++++++++++++++++++ 5 files changed, 46 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 227398b6d..a08546d73 100644 --- a/Project.toml +++ b/Project.toml @@ -100,7 +100,7 @@ LinearAlgebra = "1.10" LossFunctions = "0.11.1, 1" LuxCore = "1.2" LuxLib = "1.3.7" -MLDataDevices = "1.6" +MLDataDevices = "1.6.6" MLUtils = "0.4.4" MPI = "0.20.19" MacroTools = "0.5.13" @@ -110,7 +110,7 @@ NNlib = "0.9.26" Optimisers = "0.4.1" Preferences = "1.4.3" Random = "1.10" -Reactant = "0.2.12" +Reactant = "0.2.13" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2bc461363..87f9a2650 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.6.5" +version = "1.6.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -66,7 +66,7 @@ Metal = "1" OneHotArrays = "0.2.5" Preferences = "1.4" Random = "1.10" -Reactant = "0.2.6" +Reactant = "0.2.13" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl index 4e55940d0..9cc1f082c 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -2,22 +2,21 @@ module MLDataDevicesReactantExt using Adapt: Adapt using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type +using Random: Random using Reactant: Reactant, XLA, ConcreteRArray, ConcreteRNumber, TracedRArray, TracedRNumber MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true -# Default RNG: Forward to CPU, we will compile it +# Default RNG function MLDataDevices.default_device_rng(::ReactantDevice) - return MLDataDevices.default_device_rng(CPUDevice()) + return Reactant.TracedRandom.default_rng() end # Query Device from Array function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray}) - client = XLA.client(x.data) - device = XLA.device(x.data) - return ReactantDevice(client, device) + return ReactantDevice(XLA.client(x.data), XLA.device(x.data)) end function Internal.get_device(::Union{TracedRArray, TracedRNumber}) @@ -54,4 +53,6 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::ConcreteRArray) return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x)) end +Adapt.adapt_storage(::CPUDevice, ::Reactant.ConcreteRNG) = Random.default_rng() + end diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl index 30377c828..bf39be0c7 100644 --- a/lib/MLDataDevices/test/xla_tests.jl +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -39,7 +39,8 @@ using FillArrays, Zygote # Extensions device = reactant_device() aType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRArray : Array - rngType = Random.AbstractRNG + rngType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRNG : + Random.AbstractRNG ps_xpu = ps |> device @test get_device(ps_xpu) isa ReactantDevice diff --git a/test/reactant/layer_tests.jl b/test/reactant/layer_tests.jl index 8130691cb..e0e0fb526 100644 --- a/test/reactant/layer_tests.jl +++ b/test/reactant/layer_tests.jl @@ -63,3 +63,37 @@ end end end end + +@testitem "Dropout Layers" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin + using Reactant, Lux, Random + + @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES + if mode == "amdgpu" + @warn "Skipping AMDGPU tests for Reactant" + continue + end + + dev = reactant_device(; force=true) + + if ongpu + Reactant.set_default_backend("gpu") + else + Reactant.set_default_backend("cpu") + end + + @testset for layer in (AlphaDropout, Dropout, VariationalHiddenDropout) + model = layer(0.5f0) + ps, st = Lux.setup(Random.default_rng(), model) |> dev + x = randn(Float32, 10, 10) |> dev + + @test st.rng isa Reactant.ConcreteRNG + + hlo = @code_hlo model(x, ps, st) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + y, st2 = @jit model(x, ps, st) + @test st2.rng isa Reactant.ConcreteRNG + @test st.rng.seed != st2.rng.seed + end + end +end