Skip to content

Commit

Permalink
feat: overlay all generators
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 17, 2024
1 parent 36e56cb commit 6c21721
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 49 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

Expand All @@ -35,6 +36,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantRandom123Ext = "Random123"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

Expand All @@ -52,6 +54,7 @@ NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
Expand Down
11 changes: 11 additions & 0 deletions ext/ReactantRandom123Ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module ReactantRandom123Ext

using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
using Reactant: TracedRandom

TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"

end
103 changes: 103 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,29 @@ end
end

# random ops
"""
rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from a uniform random
distribution between 0 and 1. Returns a NamedTuple with the following fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
Expand Down Expand Up @@ -1059,6 +1082,86 @@ end
return (; output_state, output)
end

"""
randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from a standard normal
distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
scaled_uniform = subtract(
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
constant(fill(T(1), size(rand_uniform))),
)
probit = erf_inv(scaled_uniform)
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
return (; output_state=seed, output=rand_normal)
end

"""
randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from an exponential
distribution with rate 1. Returns a NamedTuple with the following fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
rand_exp = negate(log_plus_one(negate(rand_uniform)))
return (; output_state=seed, output=rand_exp)
end

# functional ops
@noinline function return_(
results::Union{TracedRArray,TracedRNumber}...;
Expand Down
52 changes: 52 additions & 0 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,55 @@ end
@reactant_overlay @noinline function Random.default_rng()
return call_with_reactant(TracedRandom.default_rng)
end

## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
## We can't directly overlay that call without breaking the semantics of inplace update
for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
overload_randfun = Symbol(:overload_, randfun)
overload_randfun! = Symbol(:overload_, randfun!)

@eval begin
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T}
return TracedRandom.$(overload_randfun)(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, dim1::Integer, dims::Integer...
)
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T}
return TracedRandom.$(overload_randfun)(rng, T)
end

# inplace
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AnyTracedRArray
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# warn about direct writing to arrays
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AbstractArray
)
@warn "Directly writing to an array using Random.jl functions inside \
ReactantInterpreter will generate a constant array in the IR. Use with \
caution." maxlog = 1
return Random.$(randfun!)(rng, A)
end
end
end
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Reactant
using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Random: Random
using Random: Random, AbstractRNG

using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`
Expand Down
116 changes: 68 additions & 48 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ using ..Reactant:
TracedUtils,
Ops,
ConcreteRArray
using Random: Random
using Random: Random, AbstractRNG

function Random.seed!(rng::TracedRNG, seed::Number)
if seed isa TracedRNumber
error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \
`TracedRArray` of the appropriate size instead.")
end

seed = reinterpret(UInt64, Random.hash_seed(seed))
seed = if Reactant.within_reactant_interpreter()
TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
Expand All @@ -26,6 +31,14 @@ function Random.seed!(rng::TracedRNG, seed::Number)
return Random.seed!(rng, seed)
end

function Random.seed!(rng::TracedRNG, seed::AbstractArray{<:Integer,1})
return Random.seed!(rng, UInt64.(seed))
end

function Random.seed!(rng::TracedRNG, seed::AbstractArray{UInt64,1})
return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed))
end

function Random.seed!(
rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
)
Expand All @@ -43,87 +56,94 @@ function default_rng()
return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT")
end

function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
rng_algorithm(rng::TracedRNG) = rng.algorithm
rng_algorithm(::AbstractRNG) = "DEFAULT"

function internal_overload_rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
rng.seed = res.output_state
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
return A
end

function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
function internal_overload_randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
Random.rand!(rng, A)
scaled_uniform = Ops.subtract(
Ops.multiply(A, Ops.constant(fill(T(2), size(A)))),
Ops.constant(fill(T(1), size(A))),
)
probit = Ops.erf_inv(scaled_uniform)
rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A))))
TracedUtils.set_mlir_data!(A, rand_normal.mlir_data)
res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm)
rng.seed = res.output_state
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
return A
end

function Random.randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
function internal_overload_randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
Random.rand!(rng, A)
TracedUtils.set_mlir_data!(
A,
Ops.negate(
Ops.log_plus_one(Ops.negate(TracedUtils.materialize_traced_array(A)))
).mlir_data,
)
res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm)
rng.seed = res.output_state
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
return A
end

for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
overload_randfun = Symbol(:internal_overload_, randfun)
overload_randfun! = Symbol(:internal_overload_, randfun!)

@eval begin
function Random.$(randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T}
return Random.$(randfun!)(rng, TracedRArray{T,length(dims)}((), nothing, dims))
function $(overload_randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T}
return $(overload_randfun!)(
rng, TracedRArray{T,length(dims)}((), nothing, dims)
)
end

function Random.$(randfun)(rng::TracedRNG, dims::Dims)
return Random.$(randfun)(rng, Float64, dims)
function $(overload_randfun)(rng::TracedRNG, dims::Dims)
return $(overload_randfun)(rng, Float64, dims)
end

function Random.$(randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...)
return Random.$(randfun)(rng, Dims((dim1, dims...)))
function $(overload_randfun)(rng::TracedRNG, dim1::Integer, dims::Integer...)
return $(overload_randfun)(rng, Dims((dim1, dims...)))
end

function Random.$(randfun)(
function $(overload_randfun)(
rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
return Random.$(randfun)(rng, T, Dims((dim1, dims...)))
return $(overload_randfun)(rng, T, Dims((dim1, dims...)))
end

Random.$(randfun!)(A::AnyTracedRArray) = Random.$(randfun!)(default_rng(), A)
$(overload_randfun!)(A::AnyTracedRArray) = $(overload_randfun!)(default_rng(), A)

# scalars
function Random.$(randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T}
A = promote_to(TracedRArray{T,0}, fill(T(0)))
Random.$(randfun!)(rng, A)
return A[]
end

# Non-Traced RNGs if used will lead to disastrous performance. We attempt to fix
# that but with a warning
function Random.$(randfun!)(rng::Random.AbstractRNG, A::AnyTracedRArray)
@warn "`rng` is not a `TracedRNG`. We will use this to seed the `TracedRNG` \
instead of generating samples from this RNG type." maxlog = 1
seed = promote_to(TracedRArray{UInt64,1}, rand(rng, UInt64, 2))
trng = TracedRNG(seed, "DEFAULT")
return Random.$(randfun!)(trng, A)
function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T}
A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0)))
$(overload_randfun!)(rng, A)
return TracedRNumber{T}((), A.mlir_data)
end
end
end

# resolve ambiguities
for randfun in (:randn, :randexp)
@eval function Random.$(randfun)(rng::TracedRNG, T::Random.BitFloatType)
A = promote_to(TracedRArray{T,0}, fill(T(0)))
Random.randn!(rng, A)
return A[]
# call from overlay-ed variants. we write this with 2 tiers -- overload_* and
# internal_overload_* -- to avoid method ambiguities
for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!)
overload_randfun = Symbol(:overload_, randfun)
internal_overload_randfun = Symbol(:internal_overload_, randfun)
@eval begin
function $(overload_randfun)(rng::AbstractRNG, args...)
seed_uint64 = Array{UInt64}(undef, 2)
sampler = Random.Sampler(rng, UInt64, Val(1))
seed_uint64[1] = rand(rng, sampler)
seed_uint64[2] = rand(rng, sampler)
# XXX: Ideally the following should just work but currently it gives an illegal
# instruction error. Maybe an issue with Julia's AbsInt?
# Random.rand!(rng, seed_uint64)
rng = TracedRNG(
TracedUtils.promote_to(TracedRArray{UInt64,1}, seed_uint64),
rng_algorithm(rng),
)
return $(internal_overload_randfun)(rng, args...)
end

function $(overload_randfun)(rng::TracedRNG, args...)
return $(internal_overload_randfun)(rng, args...)
end
end
end

Expand Down
Loading

0 comments on commit 6c21721

Please sign in to comment.