Skip to content

Commit

Permalink
Try to load CUDA libraries if they are available (instead of relying …
Browse files Browse the repository at this point in the history
…on Requires)
  • Loading branch information
dfdx committed Sep 22, 2019
1 parent bf0e4b9 commit 35d6831
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.2.0"

[deps]
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
12 changes: 11 additions & 1 deletion src/Yota.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
module Yota

export grad, update!, @diffrule, @diffrule_kw, @nodiff
export
grad,
update!,
@diffrule,
@diffrule_kw,
@nodiff,
best_available_device,
to_device,
CPU,
GPU


include("core.jl")

Expand Down
26 changes: 23 additions & 3 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Cassette: Tagged, tag, untag, istagged, metadata, hasmetadata,
enabletagging, @overdub, overdub, canrecurse, similarcontext, fallback
using JuliaInterpreter
using Espresso
using Requires
using CUDAapi


include("utils.jl")
Expand All @@ -21,6 +21,26 @@ include("update.jl")
include("transform.jl")


function __init__()
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda.jl")
# function __init__()
# @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda.jl")
# end

const BEST_AVAILABLE_DEVICE = Ref{AbstractDevice}(CPU())

if has_cuda()
try
using CuArrays
using CUDAnative

BEST_AVAILABLE_DEVICE[] = GPU(0)

include("cuda.jl")
catch ex
# something is wrong with the user's set-up (or there's a bug in CuArrays)
@warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())

end
end


best_available_device() = BEST_AVAILABLE_DEVICE[]

0 comments on commit 35d6831

Please sign in to comment.