-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: tracing Random.jl functionality correctly #363
Conversation
julia> using Reactant, Random
julia> fn() = randn(Random.default_rng(), 2, 3)
fn (generic function with 1 method)
julia> @code_hlo optimize = false fn()
module {
func.func @main() -> tensor<3x2xf64> {
%c = stablehlo.constant dense<[9454987348304227925, 11257230962712577529]> : tensor<2xui64>
%output_state, %output = stablehlo.rng_bit_generator %c, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x3xui64>)
%0 = stablehlo.convert %output : (tensor<2x3xui64>) -> tensor<2x3xf64>
%cst = stablehlo.constant dense<1.8446744073709552E+19> : tensor<2x3xf64>
%1 = stablehlo.divide %0, %cst : tensor<2x3xf64>
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf64>
%2 = stablehlo.multiply %1, %cst_0 : tensor<2x3xf64>
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf64>
%3 = stablehlo.subtract %2, %cst_1 : tensor<2x3xf64>
%4 = chlo.erf_inv %3 : tensor<2x3xf64> -> tensor<2x3xf64>
%cst_2 = stablehlo.constant dense<1.4142135623730951> : tensor<2x3xf64>
%5 = stablehlo.multiply %4, %cst_2 : tensor<2x3xf64>
%6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %6 : tensor<3x2xf64>
}
}
julia> @code_hlo fn()
module {
func.func @main() -> tensor<3x2xf64> {
%cst = stablehlo.constant dense<1.4142135623730951> : tensor<2x3xf64>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf64>
%cst_1 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf64>
%cst_2 = stablehlo.constant dense<1.8446744073709552E+19> : tensor<2x3xf64>
%c = stablehlo.constant dense<[17523564455668573441, 5342821220909967229]> : tensor<2xui64>
%output_state, %output = stablehlo.rng_bit_generator %c, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x3xui64>)
%0 = stablehlo.convert %output : (tensor<2x3xui64>) -> tensor<2x3xf64>
%1 = stablehlo.divide %0, %cst_2 : tensor<2x3xf64>
%2 = stablehlo.multiply %1, %cst_1 : tensor<2x3xf64>
%3 = stablehlo.subtract %2, %cst_0 : tensor<2x3xf64>
%4 = chlo.erf_inv %3 : tensor<2x3xf64> -> tensor<2x3xf64>
%5 = stablehlo.multiply %4, %cst : tensor<2x3xf64>
%6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %6 : tensor<3x2xf64>
}
} |
99e0e76
to
7d17faf
Compare
This is kind of working now. Can I get an initial review? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me but should wait until the interpreter lands (which will simplify the override and can just use the macro)
dd1f9e6
to
e99089e
Compare
The overlay mechanism fixed the previous unreachable instruction issue 🎉 |
8af045f
to
22f6810
Compare
9aec916
to
6c21721
Compare
src/Overlay.jl
Outdated
@warn "Directly writing to an array using Random.jl functions inside \ | ||
ReactantInterpreter will generate a constant array in the IR. Use with \ | ||
caution." maxlog = 1 | ||
return Random.$(randfun!)(rng, A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wsmoses my understanding was that this should call the non-overlayed version. But here I get
┌ Warning: Directly writing to an array using Random.jl functions inside ReactantInterpreter will generate a constant array in the IR. Use with caution.
└ @ Reactant /mnt/software/lux/Reactant.jl/src/Overlay.jl:84
Unreachable reached at 0x7a64877b6b16
[1417783] signal 4 (2): Illegal instruction
in expression starting at REPL[7]:1
fn2 at ./REPL[6]:4 [inlined]
opaque closure at ./<missing>:0
unknown function (ip: 0x7a64877b6bff)
fn2 at ./REPL[6]:2 [inlined]
call_with_reactant at /mnt/software/lux/Reactant.jl/src/utils.jl:0
#8 at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:210
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
unknown function (ip: 0x7a64877b6316)
#make_mlir_fn#1 at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:197
make_mlir_fn at /mnt/software/lux/Reactant.jl/src/TracedUtils.jl:117 [inlined]
#10 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:295 [inlined]
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
#9 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:294 [inlined]
mmodule! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:92
unknown function (ip: 0x7a64877b5976)
#compile_mlir!#8 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:291
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:290 [inlined]
#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:285 [inlined]
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x7a64877b2a76)
#compile_mlir#5 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
compile_mlir at /mnt/software/lux/Reactant.jl/src/Compiler.jl:280
unknown function (ip: 0x7a64877b0a66)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:625
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10705 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_FGbh7.so (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_15174 at /mnt/.julia/compiled/v1.11/REPL/u0gqU_FGbh7.so (unknown line)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73406.1 at /home/avikpal/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x7a65341e2e07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 48616895 (Pool: 48615481; Big: 1414); GC: 47
[1] 1417783 illegal hardware instruction (core dumped) julia --project=envs --threads=4 --check-bounds=yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using this trick #369 (comment) seems to work correctly (though this has other weird edge-cases)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is definitely a general issue with any kind of recursion of an overlayed function. Do we have a way to force using the NativeInterpreter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/stdlibs/Random.jl
Outdated
function $(overload_randfun)(rng::AbstractRNG, args...) | ||
# XXX: Ideally the following should just work but currently it gives an illegal | ||
# instruction error. Maybe an issue with Julia's AbsInt? | ||
# seed_uint64 = rand(rng, UInt64, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same problem here. calling this leads to the illegal instruction
src/Overlay.jl
Outdated
if T <: ReactantPrimitive | ||
return TracedRandom.$(overload_randfun)(rng, T, dims) | ||
end | ||
error("Reactant doesn't support sampling of $(T) with the current interpreter.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
error("Reactant doesn't support sampling of $(T) with the current interpreter.") | |
return error( | |
"Reactant doesn't support sampling of $(T) with the current interpreter." | |
) |
src/Overlay.jl
Outdated
if T <: ReactantPrimitive | ||
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) | ||
end | ||
error("Reactant doesn't support sampling of $(T) with the current interpreter.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
error("Reactant doesn't support sampling of $(T) with the current interpreter.") | |
return error( | |
"Reactant doesn't support sampling of $(T) with the current interpreter." | |
) |
src/Overlay.jl
Outdated
if T <: ReactantPrimitive | ||
return TracedRandom.$(overload_randfun)(rng, T) | ||
end | ||
error("Reactant doesn't support sampling of $(T) with the current interpreter.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
error("Reactant doesn't support sampling of $(T) with the current interpreter.") | |
return error( | |
"Reactant doesn't support sampling of $(T) with the current interpreter." | |
) |
I am working around the AbsInt issues for now, but those would be nice to sort out, especially for handling the wrapped arrays. Also, I am explicitly throwing errors in the cases where otherwise Julia would crash with an invalid instruction |
ba82aa5
to
eaa1f30
Compare
TODOs
Overlay seems to act weird. Can't switch to non-reactant interpreter nicelyUsing a workaround for nowRandom123
to generate specific RNGs (Threefry and Philox)