From 36399682d018a185916db483157457d91d68cf7d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 11:45:02 -0500 Subject: [PATCH] fix: pass in RNG to shuffle --- lib/WeightInitializers/Project.toml | 2 +- lib/WeightInitializers/src/WeightInitializers.jl | 2 +- lib/WeightInitializers/src/initializers.jl | 14 +++++++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index bb39b7955..4f22301a3 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.4" +version = "1.0.5" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl index 6702f3fec..831f638a2 100644 --- a/lib/WeightInitializers/src/WeightInitializers.jl +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -3,7 +3,7 @@ module WeightInitializers using ArgCheck: @argcheck using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr -using Random: Random, AbstractRNG, shuffle +using Random: Random, AbstractRNG using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0 using Statistics: Statistics, std diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl index 81de6a17c..5cc4efef1 100644 --- a/lib/WeightInitializers/src/initializers.jl +++ b/lib/WeightInitializers/src/initializers.jl @@ -225,7 +225,19 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparse_array .*= T(std) fill!(view(sparse_array, 1:num_zeros, :), zero(T)) - return @allowscalar mapslices(shuffle, sparse_array; dims=1) + if applicable(Random.rng_native_52, rng) + @inbounds for i in axes(sparse_array, 2) + @allowscalar Random.shuffle!(rng, view(sparse_array, :, i)) + end + else + @warn "`rng` is not supported by `Random.shuffle!`. Ignoring the `rng` for \ + shuffle." maxlog=1 + @inbounds for i in axes(sparse_array, 2) + @allowscalar Random.shuffle!(view(sparse_array, :, i)) + end + end + + return sparse_array end """