Skip to content

Commit

Permalink
Merge pull request #46 from dfdx/cuda-dev
Browse files Browse the repository at this point in the history
Cuda dev
  • Loading branch information
dfdx authored Sep 22, 2019
2 parents f2f76e9 + 35d6831 commit f1f7f69
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 11 deletions.
193 changes: 193 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,55 @@
# This file is machine-generated - editing it directly is not advised

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.4.1"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.0.0"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BinaryProvider]]
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.6"

[[CEnum]]
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.2.0"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.2"

[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "1.2.0"

[[CUDAdrv]]
deps = ["CUDAapi", "Libdl", "Printf"]
git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "3.1.0"

[[CUDAnative]]
deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
git-tree-sha1 = "52ae1ce10ebfa686e227655c47b19add89308623"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.3.1"

[[Cassette]]
git-tree-sha1 = "da85d135b6048d3e78603e277cf9a4609f7e0673"
uuid = "7057c7e9-c182-5462-911a-8362d720325c"
Expand All @@ -14,6 +61,44 @@ git-tree-sha1 = "0becdab7e6fbbcb7b88d8de5b72e5bb2f28239f3"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "0.5.8"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"

[[Conda]]
deps = ["JSON", "VersionParsing"]
git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.3.0"

[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "1.2.1"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -26,16 +111,49 @@ repo-url = "https://github.com/dfdx/Espresso.jl.git"
uuid = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
version = "0.6.0"

[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.0.1"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.6.4"

[[GPUArrays]]
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
git-tree-sha1 = "77e27264276fe97a7e7fb928bf8999a145abc018"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "1.0.3"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.0"

[[JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "5020abc08c5c9f7d47ec7e861309bc79ed74aec7"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.7.2"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "4a05f742837779a00bd8c9a18da6817367c4245d"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.3.0"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -46,14 +164,59 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.1"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.0"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.7"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"

[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
Expand All @@ -66,13 +229,23 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.11.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -81,6 +254,26 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.6"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[VersionParsing]]
deps = ["Compat"]
git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669"
uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
version = "1.1.3"
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@ authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.2.0"

[deps]
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1"
Cassette = "0.2.6"
Espresso = "0.6.0"
Espresso = "0.6.0"
julia = "1"
12 changes: 11 additions & 1 deletion src/Yota.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
module Yota

export grad, update!, @diffrule, @diffrule_kw, @nodiff
export
grad,
update!,
@diffrule,
@diffrule_kw,
@nodiff,
best_available_device,
to_device,
CPU,
GPU


include("core.jl")

Expand Down
26 changes: 23 additions & 3 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Cassette: Tagged, tag, untag, istagged, metadata, hasmetadata,
enabletagging, @overdub, overdub, canrecurse, similarcontext, fallback
using JuliaInterpreter
using Espresso
using Requires
using CUDAapi


include("utils.jl")
Expand All @@ -21,6 +21,26 @@ include("update.jl")
include("transform.jl")


function __init__()
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda.jl")
# function __init__()
# @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda.jl")
# end

const BEST_AVAILABLE_DEVICE = Ref{AbstractDevice}(CPU())

if has_cuda()
try
using CuArrays
using CUDAnative

BEST_AVAILABLE_DEVICE[] = GPU(0)

include("cuda.jl")
catch ex
# something is wrong with the user's set-up (or there's a bug in CuArrays)
@warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())

end
end


best_available_device() = BEST_AVAILABLE_DEVICE[]
6 changes: 2 additions & 4 deletions src/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ CUDANATIVE_OPS[^] = CUDAnative.pow
CUDANATIVE_OPS[ones] = CUDAnative.ones

device_function(device::GPU, f::Function) = get(CUDANATIVE_OPS, f, f)
# device_function(::GPU, f::Function) = CuArrays.cufunc(f)
to_device(device::GPU, x) = cu(x)


function to_cuda(x)
function to_device(device::GPU, x)
T = typeof(x)
flds = fieldnames(T)
if is_cuarray(x)
Expand All @@ -36,7 +34,7 @@ function to_cuda(x)
return cu(x)
else
# struct, recursively convert and construct type from fields
fld_vals = [to_cuda(getfield(x, fld)) for fld in flds]
fld_vals = [to_device(device, getfield(x, fld)) for fld in flds]
return T(fld_vals...)
end
end
Expand Down

0 comments on commit f1f7f69

Please sign in to comment.