Skip to content

Commit

Permalink
feat: use the override macro
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 15, 2024
1 parent d29f9ec commit dd1f9e6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
12 changes: 0 additions & 12 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,6 @@ function set_reactant_abi(
return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods)
end

# ensures we are not generating a constant array in the trace
# https://github.com/EnzymeAD/Reactant.jl/issues/356
if (f === Random.default_rng || f === default_rng) && length(argtypes) == 1
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing : Any[:($(default_rng_inside_interpreter))],
Any[Core.Const(default_rng_inside_interpreter)],
)
return abstract_call_known(
interp, default_rng_inside_interpreter, arginfo2, si, sv, max_methods
)
end

return Base.@invoke abstract_call_known(
interp::AbstractInterpreter,
f::Any,
Expand Down
9 changes: 6 additions & 3 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,19 @@ TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT")

default_rng() = TracedRNG()
function default_rng_inside_interpreter()
return TracedRNG(promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT")
return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT")
end

@reactant_override @noinline Random.default_rng() = default_rng_inside_interpreter()
@reactant_override @noinline default_rng() = default_rng_inside_interpreter()

# XXX: Currently we get an illegal instruction if we don't call Random.default_rng()

function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
rng.seed = res.output_state
set_mlir_data!(A, res.output.mlir_data)
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
return A
end

Expand All @@ -49,7 +52,7 @@ function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
)
probit = Ops.erf_inv(scaled_uniform)
rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A))))
set_mlir_data!(A, rand_normal.mlir_data)
TracedUtils.set_mlir_data!(A, rand_normal.mlir_data)
return A
end

Expand Down

0 comments on commit dd1f9e6

Please sign in to comment.