Skip to content

Commit

Permalink
docs: keep the ConvMixer default backend as cuda.jl for now
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 15, 2024
1 parent 424663a commit 442bb41
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end
Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
backend::String="reactant")
backend::String="gpu_if_available")
rng = StableRNG(seed)

if backend == "gpu_if_available"
Expand Down
3 changes: 3 additions & 0 deletions ext/LuxReactantExt/patches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x)
function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber)
return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y))
end

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
(g::Lux.GlobalPoolMode)(::TracedRArray) = g

Check warning on line 10 in ext/LuxReactantExt/patches.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt/patches.jl#L10

Added line #L10 was not covered by tests
3 changes: 1 addition & 2 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ end

struct GlobalPoolMode <: AbstractPoolMode end

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
(::GlobalPoolMode)() = GlobalPoolMode()
(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])

@concrete struct AdaptivePoolMode <: AbstractPoolMode
out_size <: Tuple{Vararg{IntegerType}}
Expand Down

0 comments on commit 442bb41

Please sign in to comment.