From 8249408c17e10b4073e80fe1e9558673e05d119a Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 10:39:11 -0600 Subject: [PATCH] Revert "feat: tracing Random.jl functionality correctly (#363)" This reverts commit 94e9576f5610f244a8ed379fced26b272c9f7f0b. --- .github/workflows/CI.yml | 6 - Project.toml | 9 +- docs/make.jl | 1 - docs/src/.vitepress/config.mts | 4 +- docs/src/api/internal.md | 12 -- ext/ReactantRandom123Ext.jl | 11 -- src/Ops.jl | 141 +------------ src/Overlay.jl | 95 +-------- src/Reactant.jl | 11 +- .../LinearAlgebra.jl => linear_algebra.jl} | 0 src/stdlibs/Random.jl | 168 ---------------- src/utils.jl | 7 +- test/Project.toml | 5 - test/integration/random.jl | 187 ------------------ test/nn/lux.jl | 2 +- test/ops.jl | 46 +---- test/runtests.jl | 1 - 17 files changed, 16 insertions(+), 690 deletions(-) delete mode 100644 docs/src/api/internal.md delete mode 100644 ext/ReactantRandom123Ext.jl rename src/{stdlibs/LinearAlgebra.jl => linear_algebra.jl} (100%) delete mode 100644 src/stdlibs/Random.jl delete mode 100644 test/integration/random.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 508ff06b9..66882fb6a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,12 +50,6 @@ jobs: version: '1.10' assertions: true test_group: neural_networks - - os: ubuntu-20.04 - arch: x64 - libReactant: packaged - version: '1.10' - assertions: true - test_group: integration - os: ubuntu-20.04 arch: x86 libReactant: packaged diff --git a/Project.toml b/Project.toml index b6851867b..f5dcc854d 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,6 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Preferences = "21216c6a-2e73-6563-6e65-726566657250" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" Scratch = "6c6a2e73-6563-6170-7368-637461726353" @@ -24,19 +23,17 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Random123 = "74087812-796a-5b5d-8853-05524746bad3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources] -ReactantCore = {path = "lib/ReactantCore"} +[sources.ReactantCore] +path = "lib/ReactantCore" [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" -ReactantRandom123Ext = "Random123" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -54,8 +51,6 @@ LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" Preferences = "1.4" -Random = "1.10" -Random123 = "1.7" ReactantCore = "0.1.3" Reactant_jll = "0.0.26" Scratch = "1.2" diff --git a/docs/make.jl b/docs/make.jl index fcbaca60e..7515a566d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -43,7 +43,6 @@ pages = [ ], "MLIR API" => "api/mlirc.md", "XLA" => "api/xla.md", - "Internal API" => "api/internal.md", ], ] diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 1dc25f2ad..942a9415d 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -78,8 +78,7 @@ export default defineConfig({ { text: "MLIR API", link: "/api/mlirc" }, { text: "XLA", link: "/api/xla" 
}, ], - }, - { text: "Internal API", link: "/api/internal" }, + } ], }, { @@ -133,7 +132,6 @@ export default defineConfig({ { text: "XLA", link: "/api/xla" }, ], }, - { text: "Internal API", link: "/api/internal" }, ], }, }, diff --git a/docs/src/api/internal.md b/docs/src/api/internal.md deleted file mode 100644 index a8788e5fb..000000000 --- a/docs/src/api/internal.md +++ /dev/null @@ -1,12 +0,0 @@ -```@meta -CollapsedDocStrings = true -``` - -# Internal API - -These functions are not part of the public API and are subject to change at any time. - -```@docs -Reactant.REDUB_ARGUMENTS_NAME -Reactant.within_reactant_interpreter -``` diff --git a/ext/ReactantRandom123Ext.jl b/ext/ReactantRandom123Ext.jl deleted file mode 100644 index d701fdc7e..000000000 --- a/ext/ReactantRandom123Ext.jl +++ /dev/null @@ -1,11 +0,0 @@ -module ReactantRandom123Ext - -using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x -using Reactant: TracedRandom - -TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY" -TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" -TracedRandom.rng_algorithm(::Philox4x) = "PHILOX" -TracedRandom.rng_algorithm(::Philox2x) = "PHILOX" - -end diff --git a/src/Ops.jl b/src/Ops.jl index 18ab2d7d4..fa8c17b3c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1016,150 +1016,19 @@ end end # random ops -""" - rng_bit_generator( - ::Type{T}, - seed::TracedRArray{UInt64,1}, - shape; - algorithm::String="DEFAULT", - location=mlir_stacktrace("rand", @__FILE__, @__LINE__), - ) - -Generate a random array of type `T` with the given shape and seed from a uniform random -distribution between 0 and 1. Returns a NamedTuple with the following fields: - -- `output_state`: The state of the random number generator after the operation. -- `output`: The generated array. - -# Arguments - -- `T`: The type of the generated array. -- `seed`: The seed for the random number generator. -- `shape`: The shape of the generated array. -- `algorithm`: The algorithm to use for generating the random numbers. Defaults to - "DEFAULT". Other options include "PHILOX" and "THREE_FRY". -""" @noinline function rng_bit_generator( - ::Type{T}, seed::TracedRArray{UInt64,1}, shape; algorithm::String="DEFAULT", location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) where {T<:Integer} - @assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY") - if algorithm == "PHILOX" - @assert length(seed) ∈ (2, 3) - elseif algorithm == "THREE_FRY" - @assert length(seed) == 2 - end - - output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) - output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64)) +) + output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape) rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm) - op = stablehlo.rng_bit_generator( - seed.mlir_data; output, output_state, rng_algorithm, location - ) + op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location) return (; - output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)), - output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)), - ) -end - -@noinline function rng_bit_generator( - ::Type{T}, - seed::TracedRArray{UInt64,1}, - shape; - algorithm::String="DEFAULT", - location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), -) where {T<:AbstractFloat} - nbits = sizeof(T) * 8 - uT = nbits == 16 ? UInt16 : (nbits == 32 ? 
UInt32 : UInt64) - (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location) - output = divide( - convert(TracedRArray{T,ndims(output)}, output), - constant(fill(T(typemax(uT)), Tuple(shape)); location), - ) - return (; output_state, output) -end - -""" - randn( - ::Type{T}, - seed::TracedRArray{UInt64,1}, - shape; - algorithm::String="DEFAULT", - location=mlir_stacktrace("rand", @__FILE__, @__LINE__), + output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)), + output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape), ) - -Generate a random array of type `T` with the given shape and seed from a standard normal -distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following -fields: - -- `output_state`: The state of the random number generator after the operation. -- `output`: The generated array. - -# Arguments - -- `T`: The type of the generated array. -- `seed`: The seed for the random number generator. -- `shape`: The shape of the generated array. -- `algorithm`: The algorithm to use for generating the random numbers. Defaults to - "DEFAULT". Other options include "PHILOX" and "THREE_FRY". -""" -@noinline function randn( - ::Type{T}, - seed::TracedRArray{UInt64,1}, - shape; - algorithm::String="DEFAULT", - location=mlir_stacktrace("rand", @__FILE__, @__LINE__), -) where {T} - res = rng_bit_generator(T, seed, shape; algorithm, location) - rand_uniform = res.output - seed = res.output_state - scaled_uniform = subtract( - multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))), - constant(fill(T(1), size(rand_uniform))), - ) - probit = erf_inv(scaled_uniform) - rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform)))) - return (; output_state=seed, output=rand_normal) -end - -""" - randexp( - ::Type{T}, - seed::TracedRArray{UInt64,1}, - shape; - algorithm::String="DEFAULT", - location=mlir_stacktrace("rand", @__FILE__, @__LINE__), - ) - -Generate a random array of type `T` with the given shape and seed from an exponential -distribution with rate 1. Returns a NamedTuple with the following fields: - -- `output_state`: The state of the random number generator after the operation. -- `output`: The generated array. - -# Arguments - -- `T`: The type of the generated array. -- `seed`: The seed for the random number generator. -- `shape`: The shape of the generated array. -- `algorithm`: The algorithm to use for generating the random numbers. Defaults to - "DEFAULT". Other options include "PHILOX" and "THREE_FRY". -""" -@noinline function randexp( - ::Type{T}, - seed::TracedRArray{UInt64,1}, - shape; - algorithm::String="DEFAULT", - location=mlir_stacktrace("rand", @__FILE__, @__LINE__), -) where {T} - res = rng_bit_generator(T, seed, shape; algorithm, location) - rand_uniform = res.output - seed = res.output_state - rand_exp = negate(log_plus_one(negate(rand_uniform))) - return (; output_state=seed, output=rand_exp) end # functional ops diff --git a/src/Overlay.jl b/src/Overlay.jl index b9785b7fa..6d4752acd 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -3,15 +3,6 @@ # correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved # we should move all the reactant_overrides to relevant files. -# Helper Function to determine if we are inside the ReactantInterpreter -""" - within_reactant_interpreter() - -Returns `true` if we are currently inside the ReactantInterpreter. 
-""" -@noinline within_reactant_interpreter() = false -@reactant_overlay @noinline within_reactant_interpreter() = true - # Compiling within a compile should return simply the original function @reactant_overlay function Compiler.compile( f, args; client=nothing, optimize=true, sync=false @@ -19,7 +10,7 @@ Returns `true` if we are currently inside the ReactantInterpreter. return f end -# Enzyme.jl overlays +# Enzyme overrides @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} @@ -31,87 +22,3 @@ end ) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end - -# Random.jl overlays -@reactant_overlay @noinline function Random.default_rng() - return call_with_reactant(TracedRandom.default_rng) -end - -## Only problematic edge case here is the direct `(rng, A::AbstractArray)` call -## We can't directly overlay that call without breaking the semantics of inplace update -for randfun in (:rand, :randn, :randexp) - randfun! = Symbol(randfun, :!) - overload_randfun = Symbol(:overload_, randfun) - overload_randfun! = Symbol(:overload_, randfun!) - - @eval begin - @reactant_overlay @noinline function Random.$(randfun)( - rng::AbstractRNG, ::Type{T}, dims::Dims - ) where {T} - if T <: ReactantPrimitive - return TracedRandom.$(overload_randfun)(rng, T, dims) - end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T, dims) - end - - @reactant_overlay @noinline function Random.$(randfun)( - rng::AbstractRNG, dim1::Integer, dims::Integer... - ) - return TracedRandom.$(overload_randfun)(rng, dim1, dims...) - end - - @reactant_overlay @noinline function Random.$(randfun)( - rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... - ) where {T} - if T <: ReactantPrimitive - return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) - end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." maxlog = 1 - # return Random.$(randfun)(rng, T, dim1, dims...) - end - - # scalars - @reactant_overlay @noinline function Random.$(randfun)( - rng::AbstractRNG, ::Type{T}=Float64 - ) where {T} - if T <: ReactantPrimitive - return TracedRandom.$(overload_randfun)(rng, T) - end - return error( - "Reactant doesn't support sampling of $(T) with the current interpreter." - ) - # XXX: The following will lead to illegal instruction - # @warn "Reactant doesn't support sampling of $(T) with the current \ - # interpreter. Falling back to native interpreter." 
maxlog = 1 - # return Random.$(randfun)(rng, T) - end - - # inplace - @reactant_overlay @noinline function Random.$(randfun!)( - rng::AbstractRNG, A::AnyTracedRArray - ) - return TracedRandom.$(overload_randfun!)(rng, A) - end - - # XXX: Uncomment once AbsInt issues with recursive calls are resolved - # @reactant_overlay @noinline function Random.$(randfun!)( - # rng::AbstractRNG, A::AbstractArray - # ) - # @warn "Directly writing to an array using Random.jl functions inside \ - # ReactantInterpreter will generate a constant array in the IR. Use with \ - # caution." maxlog = 1 - # return Random.$(randfun!)(rng, A) - # end - end -end diff --git a/src/Reactant.jl b/src/Reactant.jl index bea015074..e7c8805de 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -3,8 +3,6 @@ module Reactant using ReactantCore: ReactantCore, @trace, MissingTracedValue using LinearAlgebra: LinearAlgebra -using Random: Random, AbstractRNG - using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` @@ -124,14 +122,7 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") -mutable struct TracedRNG <: Random.AbstractRNG - seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} - const algorithm::String -end - -# StdLib Overloads -include("stdlibs/LinearAlgebra.jl") -include("stdlibs/Random.jl") +include("linear_algebra.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/src/stdlibs/LinearAlgebra.jl b/src/linear_algebra.jl similarity index 100% rename from src/stdlibs/LinearAlgebra.jl rename to src/linear_algebra.jl diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl deleted file mode 100644 index 271b78f80..000000000 --- a/src/stdlibs/Random.jl +++ /dev/null @@ -1,168 +0,0 @@ -module TracedRandom - -# Implementation based on the following: -# 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl -# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl - -using ..Reactant: - Reactant, - TracedRArray, - TracedRNumber, - TracedRNG, - AnyTracedRArray, - Reactant, - TracedUtils, - Ops, - ConcreteRArray -using Random: Random, AbstractRNG - -@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) - # XXX: We should really be able to call this here. But with our AbsInt it leads to a - # segfault. So we'll just call it in the rand! method. - # return rand(rng, UInt64, 2) - seed = Array{UInt64}(undef, 2) - Random.rand!(rng, seed) - return seed -end - -function Random.seed!(rng::TracedRNG, seed::Number) - if seed isa TracedRNumber - error("Passing in `TracedRNumber` as a seed is not supported. 
Please pass in a \ - `TracedRArray` of the appropriate size instead.") - end - - seed = reinterpret(UInt64, Random.hash_seed(seed)) - seed = if Reactant.within_reactant_interpreter() - TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) - else - ConcreteRArray(seed[1:length(rng.seed)]) - end - return Random.seed!(rng, seed) -end - -function Random.seed!(rng::TracedRNG, seed::AbstractArray{<:Integer,1}) - return Random.seed!(rng, UInt64.(seed)) -end - -function Random.seed!(rng::TracedRNG, seed::AbstractArray{UInt64,1}) - return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed)) -end - -function Random.seed!( - rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}} -) - rng.seed = seed - return rng -end - -@noinline TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) -@noinline TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") - -@noinline function default_rng() - Reactant.within_reactant_interpreter() || return TracedRNG() - return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") -end - -@noinline rng_algorithm(rng::TracedRNG) = rng.algorithm -@noinline rng_algorithm(::AbstractRNG) = "DEFAULT" - -@noinline function internal_overload_rand!( - rng::TracedRNG, A::AnyTracedRArray{T,N} -) where {T,N} - length(A) == 0 && return A - res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) - rng.seed = res.output_state - TracedUtils.set_mlir_data!(A, res.output.mlir_data) - return A -end - -@noinline function internal_overload_randn!( - rng::TracedRNG, A::AnyTracedRArray{T,N} -) where {T,N} - length(A) == 0 && return A - res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm) - rng.seed = res.output_state - TracedUtils.set_mlir_data!(A, res.output.mlir_data) - return A -end - -@noinline function internal_overload_randexp!( - rng::TracedRNG, A::AnyTracedRArray{T,N} -) where {T,N} - length(A) == 0 && return A - res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm) - rng.seed = res.output_state - TracedUtils.set_mlir_data!(A, res.output.mlir_data) - return A -end - -for randfun in (:rand, :randn, :randexp) - randfun! = Symbol(randfun, :!) - overload_randfun = Symbol(:internal_overload_, randfun) - overload_randfun! = Symbol(:internal_overload_, randfun!) - - @eval begin - @noinline function $(overload_randfun)( - rng::TracedRNG, ::Type{T}, dims::Dims - ) where {T} - return $(overload_randfun!)( - rng, TracedRArray{T,length(dims)}((), nothing, dims) - ) - end - - @noinline function $(overload_randfun)(rng::TracedRNG, dims::Dims) - return $(overload_randfun)(rng, Float64, dims) - end - - @noinline function $(overload_randfun)( - rng::TracedRNG, dim1::Integer, dims::Integer... - ) - return $(overload_randfun)(rng, Dims((dim1, dims...))) - end - - @noinline function $(overload_randfun)( - rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer... - ) where {T} - return $(overload_randfun)(rng, T, Dims((dim1, dims...))) - end - - @noinline function $(overload_randfun!)(A::AnyTracedRArray) - return $(overload_randfun!)(default_rng(), A) - end - - # scalars - @noinline function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} - A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0))) - $(overload_randfun!)(rng, A) - return TracedRNumber{T}((), A.mlir_data) - end - end -end - -# call from overlay-ed variants. 
we write this with 2 tiers -- overload_* and -# internal_overload_* -- to avoid method ambiguities -for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!) - overload_randfun = Symbol(:overload_, randfun) - internal_overload_randfun = Symbol(:internal_overload_, randfun) - @eval begin - @noinline function $(overload_randfun)(rng::AbstractRNG, args...) - rng = TracedRNG( - TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed(rng)), - rng_algorithm(rng), - ) - return $(internal_overload_randfun)(rng, args...) - end - - @noinline function $(overload_randfun)(rng::TracedRNG, args...) - return $(internal_overload_randfun)(rng, args...) - end - end -end - -# TODO: At some later point we might want to implement the sampler API as well since it -# makes all RNG implementation work by default. From the post-optimize IR we need to -# confirm that the dynamic_update_slice calls are optimized away into a single -# `stablehlo.rng_bit_generator` call -- confirm that this should be the case based on -# how the seeding should work? - -end diff --git a/src/utils.jl b/src/utils.jl index b8eb02849..16b784d58 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -99,8 +99,7 @@ function should_rewrite_ft(@nospecialize(ft)) # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions if has_ancestor(mod, Reactant.Ops) || has_ancestor(mod, Reactant.TracedUtils) || - has_ancestor(mod, Reactant.MLIR) || - has_ancestor(mod, Reactant.TracedRandom) + has_ancestor(mod, Reactant.MLIR) return false end end @@ -306,7 +305,7 @@ function call_with_reactant_generator( overdubbed_codelocs = Int32[] # No method could be found (including in our method table), bail with an error - if lookup_result === nothing + if lookup_result == nothing return stub(world, source, method_error) end @@ -502,7 +501,7 @@ function call_with_reactant_generator( # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). - # Opaque closures also require taking the function argument. We can work around the latter + # Opaque closures also require takign the function argument. We can work around the latter # if the function is stateless. 
But regardless, to work around this we sadly create/compile the opaque closure oc = if false && Base.issingletontype(args[1]) res = Core._call_in_world_total( diff --git a/test/Project.toml b/test/Project.toml index d8861a1aa..e7e33313b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,12 +2,10 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -17,13 +15,10 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Random123 = "74087812-796a-5b5d-8853-05524746bad3" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/integration/random.jl b/test/integration/random.jl deleted file mode 100644 index 275e0e244..000000000 --- a/test/integration/random.jl +++ /dev/null @@ -1,187 +0,0 @@ -using Reactant, Test, Random, Random123, StableRNGs, Statistics -using StatsBase, Statistics, HypothesisTests, Distributions - -# First Testing overlay works correctly -@testset "Random.jl Overlay" begin - hlo = @code_hlo rand(Float32, 2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - hlo = @code_hlo rand(MersenneTwister(), Float32, 2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - hlo = @code_hlo rand(2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - hlo = @code_hlo rand(MersenneTwister(), 2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - hlo = @code_hlo rand(MersenneTwister(), Float64, (2, 3)) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - hlo = @code_hlo rand(MersenneTwister(), Float64) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - hlo = @code_hlo rand(MersenneTwister()) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - fn(x) = begin - # XXX: MersenneTwister without seed leads to illegal instructions - rng = MersenneTwister(0) - Random.rand!(rng, x) - return x - end - hlo = @code_hlo fn(Reactant.to_rarray(rand(Float64, 2, 3))) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - - fn2() = begin - # XXX: MersenneTwister without seed leads to illegal instructions - rng = MersenneTwister(0) - x = zeros(Float64, 2, 3) - Random.rand!(rng, x) - return x - end - hlo = @code_hlo fn2() - @test !contains(repr(hlo), "stablehlo.rng_bit_generator") -end - -@testset "Random123" begin - hlo = @code_hlo rand(Random123.Threefry4x(), Float32, 2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - @test contains(repr(hlo), "THREE_FRY") - - hlo = @code_hlo rand(Random123.Threefry2x(), Float64, 2, 3) - @test contains(repr(hlo), 
"stablehlo.rng_bit_generator") - @test contains(repr(hlo), "THREE_FRY") - - hlo = @code_hlo rand(Random123.Philox4x(), Float64, 2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - @test contains(repr(hlo), "PHILOX") - - hlo = @code_hlo rand(Random123.Philox2x(), Float64, 2, 3) - @test contains(repr(hlo), "stablehlo.rng_bit_generator") - @test contains(repr(hlo), "PHILOX") -end - -# Next we test that the random number generators actually generate data from the correct -# distributions -@testset "Uniform Random" begin - @testset "Deterministic Seed" begin - seed1 = ConcreteRArray(UInt64[1, 3]) - seed2 = ConcreteRArray(UInt64[1, 5]) - - fn(seed) = begin - rng = Random.default_rng() - Random.seed!(rng, seed) - return rand(rng, 10000) - end - - fn_compiled = @compile fn(seed1) - @test fn_compiled(seed1) ≈ fn_compiled(seed1) - @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2)))) - end - - @testset "Correct Distribution" begin - X = Array(@jit(rand(StableRNG(0), 10000))) - ks_test = ExactOneSampleKSTest(X, Uniform(0.0, 1.0)) - @test pvalue(ks_test) > 0.05 - end - - @testset "AutoCorrelation" begin - X = Array(@jit(rand(StableRNG(0), 10000))) - autocorr = cor(X[1:(end - 1)], X[2:end]) - @test abs(autocorr) < 0.05 - end - - @testset "Correct Range" begin - X = Array(@jit(rand(StableRNG(0), 10000))) - X_min, X_max = extrema(X) - @test X_min ≥ 0.0 - @test X_max ≤ 1.0 - end - - @testset "Mean & Variance" begin - X = Array(@jit(rand(StableRNG(0), 10000))) - μ = mean(X) - σ² = var(X) - @test μ ≈ 0.5 atol = 0.05 rtol = 0.05 - @test σ² ≈ (1//12) atol = 0.05 rtol = 0.05 - end -end - -@testset "Normal Distribution" begin - @testset "Deterministic Seed" begin - seed1 = ConcreteRArray(UInt64[1, 3]) - seed2 = ConcreteRArray(UInt64[1, 5]) - - fn(seed) = begin - rng = Random.default_rng() - Random.seed!(rng, seed) - return randn(rng, 10000) - end - - fn_compiled = @compile fn(seed1) - @test fn_compiled(seed1) ≈ fn_compiled(seed1) - @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2)))) - end - - @testset "Correct Distribution" begin - X = Array(@jit(randn(StableRNG(0), 10000))) - sw_test = ShapiroWilkTest(X) - @test pvalue(sw_test) > 0.05 - end - - @testset "AutoCorrelation" begin - X = Array(@jit(randn(StableRNG(0), 10000))) - autocorr = cor(X[1:(end - 1)], X[2:end]) - @test abs(autocorr) < 0.05 - end - - @testset "Mean & Variance" begin - X = Array(@jit(randn(StableRNG(0), 10000))) - μ = mean(X) - σ² = var(X) - @test μ ≈ 0.0 atol = 0.05 rtol = 0.05 - @test σ² ≈ 1.0 atol = 0.05 rtol = 0.05 - end -end - -@testset "Exponential Distribution" begin - @testset "Deterministic Seed" begin - seed1 = ConcreteRArray(UInt64[1, 3]) - seed2 = ConcreteRArray(UInt64[1, 5]) - - fn(seed) = begin - rng = Random.default_rng() - Random.seed!(rng, seed) - return randexp(rng, 10000) - end - - fn_compiled = @compile fn(seed1) - @test fn_compiled(seed1) ≈ fn_compiled(seed1) - @test !(all(Array(fn_compiled(seed1)) .≈ Array(fn_compiled(seed2)))) - end - - @testset "Correct Distribution" begin - X = Array(@jit(randexp(StableRNG(0), 10000))) - ks_test = ExactOneSampleKSTest(X, Exponential(1.0)) - @test pvalue(ks_test) > 0.05 - end - - @testset "AutoCorrelation" begin - X = Array(@jit(randexp(StableRNG(0), 10000))) - autocorr = cor(X[1:(end - 1)], X[2:end]) - @test abs(autocorr) < 0.05 - end - - @testset "Correct Range" begin - X = Array(@jit(randexp(StableRNG(0), 10000))) - X_min, X_max = extrema(X) - @test X_min ≥ 0.0 - end - - @testset "Mean" begin - X = Array(@jit(randexp(StableRNG(0), 
10000))) - μ = mean(X) - @test μ ≈ 1.0 atol = 0.05 rtol = 0.05 - end -end diff --git a/test/nn/lux.jl b/test/nn/lux.jl index 7916ce10f..49fa37f52 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -8,7 +8,7 @@ end function gradient_loss_function(model, x, y, ps, st) dps = Enzyme.make_zero(ps) _, res = Enzyme.autodiff( - set_runtime_activity(ReverseWithPrimal), + ReverseWithPrimal, loss_function, Active, Const(model), diff --git a/test/ops.jl b/test/ops.jl index 82ec4cc8b..07f911e88 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -538,50 +538,8 @@ end end @testset "rng_bit_generator" begin - genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4]) - genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4]) - genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4]) - genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4]) - genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4]) - - @testset for (alg, sz) in - [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)] - seed = ConcreteRArray(zeros(UInt64, sz)) - - res = @jit genInt32(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Int32,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genInt64(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Int64,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genUInt64(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{UInt64,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genFloat32(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Float32,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genFloat64(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Float64,2} - @test size(res.output) == (2, 4) - end + # seed = ConcreteRArray([0, 0]) + # @jit Ops.rng_bit_generator(seed, [2]) end @testset "round_nearest_afz" begin diff --git a/test/runtests.jl b/test/runtests.jl index 68dfcaead..fddc963ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,7 +61,6 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") - @safetestset "Random" include("integration/random.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
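
For context: the deleted test/integration/random.jl above documents how the
reverted Random.jl tracing was exercised. Below is a minimal sketch of that
pre-revert usage, reconstructed directly from those deleted tests (fn,
ConcreteRArray, @compile all appear there verbatim); it assumes a tree that
still contains commit 94e9576f, i.e. a checkout from before this revert is
applied.

    using Reactant, Random

    # Pre-revert, Random.default_rng(), Random.seed!, and rand were overlaid
    # during tracing, so sampling lowered to stablehlo.rng_bit_generator.
    fn(seed) = begin
        rng = Random.default_rng()
        Random.seed!(rng, seed)
        return rand(rng, 10000)
    end

    seed = ConcreteRArray(UInt64[1, 3])
    fn_compiled = @compile fn(seed)
    fn_compiled(seed)  # deterministic for a fixed seed, per the deleted tests

Once this revert is applied, those overlays (src/Overlay.jl, src/stdlibs/Random.jl)
are gone, so such calls are no longer traced into stablehlo.rng_bit_generator.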