Skip to content

Commit

Permalink
fix: throw errors for now instead of crashing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2024
1 parent e370200 commit ba82aa5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
24 changes: 15 additions & 9 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Random.$(randfun)(rng, T, dims)
error("Reactant doesn't support sampling of $(T) with the current interpreter.")
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
Expand All @@ -68,9 +70,11 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Random.$(randfun)(rng, T, dim1, dims...)
error("Reactant doesn't support sampling of $(T) with the current interpreter.")
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dim1, dims...)
end

# scalars
Expand All @@ -80,9 +84,11 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Random.$(randfun)(rng, T)
error("Reactant doesn't support sampling of $(T) with the current interpreter.")
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T)
end

# inplace
Expand Down
6 changes: 4 additions & 2 deletions test/integration/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ using StatsBase, Statistics, HypothesisTests, Distributions
@test contains(repr(hlo), "stablehlo.rng_bit_generator")

fn(x) = begin
rng = MersenneTwister()
# XXX: MersenneTwister without seed leads to illegal instructions
rng = MersenneTwister(0)
Random.rand!(rng, x)
return x
end
hlo = @code_hlo fn(Reactant.to_rarray(rand(Float64, 2, 3)))
@test contains(repr(hlo), "stablehlo.rng_bit_generator")

fn2() = begin
rng = MersenneTwister()
# XXX: MersenneTwister without seed leads to illegal instructions
rng = MersenneTwister(0)
x = zeros(Float64, 2, 3)
Random.rand!(rng, x)
return x
Expand Down

0 comments on commit ba82aa5

Please sign in to comment.