Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Test Latest Polyester with Enzyme #70

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.3.23"
version = "0.3.24"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -57,6 +58,7 @@ LuxCore = "0.1.13"
LuxTestUtils = "0.1.15"
Markdown = "1.10"
NNlib = "0.9.13"
Polyester = "0.7.14"
PrecompileTools = "1.2"
Random = "1.10"
ReTestItems = "1.23.1"
Expand Down
1 change: 1 addition & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using PrecompileTools: @recompile_invalidations
using LuxCore: LuxCore
using Markdown: @doc_str
using NNlib: NNlib
using Polyester: @batch
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using Statistics: Statistics, mean, var
Expand Down
11 changes: 7 additions & 4 deletions src/impl/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,19 @@ function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:Abstrac
return x_, _vec(rμ), _vec(rσ²)
end

# Normalize `x` with the statistics `xmean` / `xvar` when both trailing array
# arguments are `nothing` and the activation is `identity`.
#
# Here we reorder the operations a bit for better performance: instead of computing
# `(x - xmean) / sqrt(xvar + epsilon)` per element of `x`, the reciprocal standard
# deviation (`_scale`) and the pre-scaled mean (`_bias`) are computed first from the
# statistics alone, so the final fused broadcast over `x` is one multiply and one
# subtract. The result is mathematically identical to the direct formula.
#
# Arguments:
#   - `x`: input array
#   - `xmean`, `xvar`: statistics, broadcast-compatible with `x`
#   - `epsilon`: value added to `xvar` before taking the square root
#
# Returns a new array (or whatever broadcasting over the inputs yields) holding
# `x * _scale - _bias`.
function _affine_normalize(::typeof(identity), x::AbstractArray, xmean,
        xvar, ::Nothing, ::Nothing, epsilon::Real)
    # 1 / sqrt(var + eps), computed only over the statistics.
    _scale = @. inv(sqrt(xvar + epsilon))
    # mean / sqrt(var + eps); subtracting this after scaling equals centering first.
    _bias = @. xmean * _scale
    return @. x * _scale - _bias
end
# Normalize `x` with the statistics `xmean` / `xvar` and apply the activation `act`
# elementwise, for the case where both trailing array arguments are `nothing`.
#
# The operations are reordered relative to the direct formula
# `act((x - xmean) / sqrt(xvar + epsilon))`: the reciprocal standard deviation
# (`_scale`) and the pre-scaled mean (`_bias`) are computed first from the statistics
# alone, so the final fused broadcast over `x` is a multiply, a subtract, and the
# activation call. The result is mathematically identical to the direct formula.
#
# Arguments:
#   - `act`: activation function applied to every normalized element
#   - `x`: input array
#   - `xmean`, `xvar`: statistics, broadcast-compatible with `x`
#   - `epsilon`: value added to `xvar` before taking the square root
#
# Returns the broadcast result of `act(x * _scale - _bias)`.
function _affine_normalize(act::F, x::AbstractArray, xmean, xvar,
        ::Nothing, ::Nothing, epsilon::Real) where {F}
    # 1 / sqrt(var + eps), computed only over the statistics.
    _scale = @. inv(sqrt(xvar + epsilon))
    # mean / sqrt(var + eps); subtracting this after scaling equals centering first.
    _bias = @. xmean * _scale
    return @. act(x * _scale - _bias)
end

# Here we reorder the operations a bit for better performance
function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar,
scale::AbstractArray, bias::AbstractArray, epsilon::Real)
_scale = @. scale / sqrt(xvar + epsilon)
Expand Down
8 changes: 6 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ end
end
@inline function __fast_broadcast!(f::F, x, args...) where {F}
if ArrayInterface.fast_scalar_indexing(x)
@.. x = f(x, args...)
if maximum(length, (x, args...)) > 100_000
@.. thread=true x=f(x, args...)
else
@.. x = f(x, args...)
end
elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1
y = first(args)
@. x = sigmoid_fast(x + y) # Has GPU Compilation Problems
Expand All @@ -123,7 +127,7 @@ end
if ArrayInterface.fast_scalar_indexing(x)
if maximum(length, (x, args...)) > 100_000
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...))
@simd ivdep for I in eachindex(bc)
@batch for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
else
Expand Down
Loading