Skip to content

Commit

Permalink
test: reenable flux testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 5, 2024
1 parent ef0d450 commit 6089c4a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.4.0"
version = "1.4.1-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -79,7 +79,7 @@ DispatchDoctor = "0.4.12"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
FastClosures = "0.3.2"
Flux = "0.14.25"
Flux = "0.15"
ForwardDiff = "0.10.36"
FunctionWrappers = "1.1.3"
Functors = "0.5"
Expand Down
8 changes: 4 additions & 4 deletions ext/LuxFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,11 @@ function Lux.convert_flux_model(
return Lux.GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ)
end

const _INVALID_TRANSFORMATION_TYPES = Union{<:Flux.Recur}
# const _INVALID_TRANSFORMATION_TYPES = Union{}

function Lux.convert_flux_model(l::T; kwargs...) where {T <: _INVALID_TRANSFORMATION_TYPES}
throw(FluxModelConversionException("Transformation of type $(T) is not supported."))
end
# function Lux.convert_flux_model(l::T; kwargs...) where {T <: _INVALID_TRANSFORMATION_TYPES}
# throw(FluxModelConversionException("Transformation of type $(T) is not supported."))
# end

for cell in (:RNNCell, :LSTMCell, :GRUCell)
msg = "Recurrent Cell: $(cell) for Flux has semantical difference with Lux, \
Expand Down
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ if ("all" in LUX_TEST_GROUP || "misc" in LUX_TEST_GROUP)
push!(EXTRA_PKGS, Pkg.PackageSpec("MPI"))
(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
push!(EXTRA_PKGS, Pkg.PackageSpec("NCCL"))
# XXX: Reactivate once Flux is compatible with Functors 0.5
# push!(EXTRA_PKGS, Pkg.PackageSpec("Flux"))
push!(EXTRA_PKGS, Pkg.PackageSpec("Flux"))
end

if !Sys.iswindows()
Expand Down
2 changes: 1 addition & 1 deletion test/transform/flux_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] skip=:(true) begin
@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] begin
import Flux

toluxpsst = FromFluxAdaptor(; preserve_ps_st=true)
Expand Down

0 comments on commit 6089c4a

Please sign in to comment.