Skip to content

Commit

Permalink
Revert "feat: tracing Random.jl functionality correctly (#363)"
Browse files Browse the repository at this point in the history
This reverts commit 94e9576.
  • Loading branch information
wsmoses authored Dec 18, 2024
1 parent 94e9576 commit 8249408
Show file tree
Hide file tree
Showing 17 changed files with 16 additions and 690 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ jobs:
version: '1.10'
assertions: true
test_group: neural_networks
- os: ubuntu-20.04
arch: x64
libReactant: packaged
version: '1.10'
assertions: true
test_group: integration
- os: ubuntu-20.04
arch: x86
libReactant: packaged
Expand Down
9 changes: 2 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Expand All @@ -24,19 +23,17 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources]
ReactantCore = {path = "lib/ReactantCore"}
[sources.ReactantCore]
path = "lib/ReactantCore"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantRandom123Ext = "Random123"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

Expand All @@ -54,8 +51,6 @@ LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
Expand Down
1 change: 0 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ pages = [
],
"MLIR API" => "api/mlirc.md",
"XLA" => "api/xla.md",
"Internal API" => "api/internal.md",
],
]

Expand Down
4 changes: 1 addition & 3 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ export default defineConfig({
{ text: "MLIR API", link: "/api/mlirc" },
{ text: "XLA", link: "/api/xla" },
],
},
{ text: "Internal API", link: "/api/internal" },
}
],
},
{
Expand Down Expand Up @@ -133,7 +132,6 @@ export default defineConfig({
{ text: "XLA", link: "/api/xla" },
],
},
{ text: "Internal API", link: "/api/internal" },
],
},
},
Expand Down
12 changes: 0 additions & 12 deletions docs/src/api/internal.md

This file was deleted.

11 changes: 0 additions & 11 deletions ext/ReactantRandom123Ext.jl

This file was deleted.

141 changes: 5 additions & 136 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1016,150 +1016,19 @@ end
end

# random ops
"""
rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from a uniform random
distribution between 0 and 1. Returns a NamedTuple with the following fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:Integer}
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
if algorithm == "PHILOX"
@assert length(seed) (2, 3)
elseif algorithm == "THREE_FRY"
@assert length(seed) == 2
end

output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64))
)
output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape)
rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm)
op = stablehlo.rng_bit_generator(
seed.mlir_data; output, output_state, rng_algorithm, location
)
op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location)
return (;
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)),
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)),
)
end

@noinline function rng_bit_generator(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
) where {T<:AbstractFloat}
nbits = sizeof(T) * 8
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
output = divide(
convert(TracedRArray{T,ndims(output)}, output),
constant(fill(T(typemax(uT)), Tuple(shape)); location),
)
return (; output_state, output)
end

"""
randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)),
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape),
)
Generate a random array of type `T` with the given shape and seed from a standard normal
distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randn(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
scaled_uniform = subtract(
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
constant(fill(T(1), size(rand_uniform))),
)
probit = erf_inv(scaled_uniform)
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
return (; output_state=seed, output=rand_normal)
end

"""
randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)
Generate a random array of type `T` with the given shape and seed from an exponential
distribution with rate 1. Returns a NamedTuple with the following fields:
- `output_state`: The state of the random number generator after the operation.
- `output`: The generated array.
# Arguments
- `T`: The type of the generated array.
- `seed`: The seed for the random number generator.
- `shape`: The shape of the generated array.
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
"""
@noinline function randexp(
::Type{T},
seed::TracedRArray{UInt64,1},
shape;
algorithm::String="DEFAULT",
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
) where {T}
res = rng_bit_generator(T, seed, shape; algorithm, location)
rand_uniform = res.output
seed = res.output_state
rand_exp = negate(log_plus_one(negate(rand_uniform)))
return (; output_state=seed, output=rand_exp)
end

# functional ops
Expand Down
95 changes: 1 addition & 94 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,14 @@
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Helper Function to determine if we are inside the ReactantInterpreter
"""
within_reactant_interpreter()
Returns `true` if we are currently inside the ReactantInterpreter.
"""
@noinline within_reactant_interpreter() = false
@reactant_overlay @noinline within_reactant_interpreter() = true

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme.jl overlays
# Enzyme overrides
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
Expand All @@ -31,87 +22,3 @@ end
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

# Random.jl overlays
@reactant_overlay @noinline function Random.default_rng()
return call_with_reactant(TracedRandom.default_rng)
end

## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
## We can't directly overlay that call without breaking the semantics of inplace update
for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
overload_randfun = Symbol(:overload_, randfun)
overload_randfun! = Symbol(:overload_, randfun!)

@eval begin
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, dim1::Integer, dims::Integer...
)
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
end

@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dim1, dims...)
end

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T)
end

# inplace
@reactant_overlay @noinline function Random.$(randfun!)(
rng::AbstractRNG, A::AnyTracedRArray
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# XXX: Uncomment once AbsInt issues with recursive calls are resolved
# @reactant_overlay @noinline function Random.$(randfun!)(
# rng::AbstractRNG, A::AbstractArray
# )
# @warn "Directly writing to an array using Random.jl functions inside \
# ReactantInterpreter will generate a constant array in the IR. Use with \
# caution." maxlog = 1
# return Random.$(randfun!)(rng, A)
# end
end
end
11 changes: 1 addition & 10 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ module Reactant
using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Random: Random, AbstractRNG

using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

Expand Down Expand Up @@ -124,14 +122,7 @@ include("TracedRArray.jl")

include("ConcreteRArray.jl")

mutable struct TracedRNG <: Random.AbstractRNG
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
const algorithm::String
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")
include("linear_algebra.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

Expand Down
File renamed without changes.
Loading

0 comments on commit 8249408

Please sign in to comment.