Skip to content

Commit

Permalink
Merge pull request #75 from dfdx/more-consistency
Browse files Browse the repository at this point in the history
More consistency
  • Loading branch information
dfdx authored Jan 23, 2021
2 parents 50b2817 + bc633b6 commit 8f5f274
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 386 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.4.3"
version = "0.4.4"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Espresso = "6912e4f1-e036-58b0-9138-08d1e6358ea9"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CUDA = "1.2, 2.3"
Cassette = "0.2.6, 0.3"
ChainRulesCore = "0.9.5"
Distributions = "0.23.2"
Espresso = "0.6.0"
IRTools = "0.4.0"
JuliaInterpreter = "0.7.2"
julia = "1.4"
7 changes: 0 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,6 @@ compile!(tape)
# 492.063 ns (2 allocations: 144 bytes)
```

Note that `trace()` is an alias to `irtrace()` - IRTools-based tracer. As of Yota 0.4.0, two other tracers are available:

* `ctrace()`, based on [Cassette.jl](https://github.com/jrevels/Cassette.jl)
* `itrace()`, based on [JuliaInterpreter.jl](https://github.com/JuliaDebug/JuliaInterpreter.jl)

These tracers can be used for experimental purposes, but **their reliability or even existence is not guaranteed in future**. For any long-term code please use alias `trace()` which always points to the most recent and well-tested implementation.

## CUDA support

`CuArray` is fully supported. If you encounter an issue with CUDA arrays which you don't have with ordinary arrays, please file a bug.
Expand Down
2 changes: 1 addition & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include("helpers.jl")
include("devices.jl")
include("tape.jl")
include("tapeutils.jl")
include("trace/trace.jl")
include("trace.jl")
include("diffrules/diffrules.jl")
include("grad.jl")
include("compile.jl")
Expand Down
21 changes: 5 additions & 16 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,14 @@ end

GPU() = GPU(1)


"""
Check if the argument is of type CuArray. Doesn't require CuArrays.jl to be loaded
"""
is_cuarray(x) = startswith(string(typeof(x)), "CuArray")

# function has_cuda_inputs(tape::Tape)
# res = false
# for op in tape
# if op isa Input && op.val isa CuArray
# res = true
# break
# end
# end
# return res
# end
is_cuarray(x) = x isa CuArray


# currently GPU's ID is just a placeholder
guess_device(args) = any(is_cuarray, args) ? GPU(1) : CPU()
device_of(A) = A isa CuArray ? GPU(1) : CPU()


"""
Retrieve function compatible with specified device
Expand All @@ -56,3 +42,6 @@ to_device(device::CPU, f::Function, args) = f

(device::CPU)(x) = to_device(device, x)
(device::GPU)(x) = to_device(device, x)


to_same_device(A, example) = device_of(example)(A)
12 changes: 7 additions & 5 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ function unbroadcast_prod_x(x::AbstractArray, y::AbstractArray, Δ)
end
unbroadcast_prod_y(x::AbstractArray, y::AbstractArray, Δ) = unbroadcast_prod_x(y, x, Δ)

device_like(example, a) = (device = guess_device([example]); device(a))
unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1]
unbroadcast_prod_x(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_x(x, device_like(x, [y]), Δ)
unbroadcast_prod_y(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_y(x, device_like(x, [y]), Δ)[1]
unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y(device_like(y, [x]), y, Δ)
# device_like(example, a) = (device = guess_device([example]); device(a))

# unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1]
unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(to_same_device([x], y), y, Δ)[1]
unbroadcast_prod_x(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_x(x, to_same_device([y], x), Δ)
unbroadcast_prod_y(x::AbstractArray, y::Number, Δ) = unbroadcast_prod_y(x, to_same_device([y], x), Δ)[1]
unbroadcast_prod_y(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_y(to_same_device([x], y), y, Δ)


untranspose_vec(ds::Transpose{T, <:AbstractVector{T}}) where T = transpose(ds)
Expand Down
45 changes: 45 additions & 0 deletions src/trace/irtools.jl → src/trace.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,45 @@
function __new__(T, args...)
# note: we also add __new__() to the list of primitives so it's not overdubbed recursively
if T <: NamedTuple
return T(args)
else
return T(args...)
end
end


__tuple__(args...) = tuple(args...)
__getfield__(args...) = getfield(args...)


function module_functions(modl)
res = Vector{Function}()
for s in Base.names(modl; all=true)
isdefined(modl, s) || continue
fn = getfield(modl, s)
if fn isa Function # && match(r"^[a-z#]+$", string(s)) != nothing
push!(res, fn)
end
end
return res
end

const PRIMITIVES = Set{Any}(vcat(
module_functions(Base),
module_functions(Core),
module_functions(Core.Intrinsics),
[Broadcast.materialize, Broadcast.broadcasted, Colon(), (:),
Base.not_int,
# our own special functions
__new__, __tuple__, __getfield__, namedtuple, guess_device]));


################################################################################
################################################################################
# IRTools-based Tracer #
################################################################################
################################################################################

import IRTools
import IRTools: IR, @dynamo, self, insertafter!

Expand Down Expand Up @@ -204,3 +246,6 @@ function irtrace(f, args...; primitives=PRIMITIVES, optimize=true)
end
return val, tape
end


trace = irtrace
118 changes: 0 additions & 118 deletions src/trace/cassette.jl

This file was deleted.

Loading

0 comments on commit 8f5f274

Please sign in to comment.