From ea3246cd25a213dc4f8bfcdac0031f9cb5572381 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sun, 7 Jan 2024 20:15:49 -0800 Subject: [PATCH 1/2] Hotfix for new OneElement on GPU --- ext/FluxCUDAExt/functor.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index e8a89c5553..dc8649fff0 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -29,6 +29,10 @@ adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = # TODO: figure out the correct design for OneElement adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) +# Patch for GPU support until we can make OneElement smarter +if isdefined(Zygote.ChainRules, :OneElement) + adapt_storage(to::FluxCUDAAdaptor, x::Zygote.ChainRules.OneElement) = CUDA.cu(collect(x)) +end adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng() From 53897c4700599c37fb2023cc3c7b1ac27997a3d1 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 8 Jan 2024 15:10:02 -0500 Subject: [PATCH 2/2] Remove deprecated env var and bump version --- .buildkite/pipeline.yml | 1 - Project.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5bb1dbfe3c..59538d1772 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -10,7 +10,6 @@ steps: queue: "juliagpu" cuda: "*" env: - JULIA_CUDA_USE_BINARYBUILDER: "true" FLUX_TEST_CUDA: "true" FLUX_TEST_CPU: "false" timeout_in_minutes: 60 diff --git a/Project.toml b/Project.toml index f40a8cdab7..9c332a8fc4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.8" +version = "0.14.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"