diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 77f4247..9abdb8c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: matrix: version: - '1.6' - - '1.8' + - '1.9' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index 0fd92c2..0cd2931 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Yota" uuid = "cd998857-8626-517d-b929-70ad188a48f0" authors = ["Andrei Zhabinski "] -version = "0.8.2" +version = "0.8.3" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" @@ -20,5 +20,5 @@ ChainRules = "1.43" ChainRulesCore = "1.15" FiniteDifferences = "0.12" NNlib = "0.8" -Umlaut = "0.4.8" +Umlaut = "0.5.1" julia = "1.6" diff --git a/src/grad.jl b/src/grad.jl index 27509ec..3426f2d 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -10,7 +10,7 @@ function has_rrule(f, args...) F = Core.Typeof(f) Args = Core.Typeof.(args) Core.Compiler.return_type(rrule, Tuple{YotaRuleConfig, F, Args...}) !== Nothing && return true - if is_kwfunc(F) + if is_kwfunc(f) # must be: Tuple{Any, typeof(rrule), YotaRuleConfig, typeof(unkwfunc(f)), Args[3:end]...} nokw_f = unkwfunc(f, args...) Args_kwrrule = Tuple{Any, typeof(rrule), YotaRuleConfig, typeof(nokw_f), Args[3:end]...} @@ -132,8 +132,11 @@ with the following chain or calls: where `val = fn(args...)` and `pb` is the pullback function. """ function chainrules_transform!(tape::Tape) + # global TAPE = tape + # error("") i = 1 while i <= length(tape) + # tape[V(i)] isa Call && tape[V(i)].fn == Core.kwcall && break op = tape[V(i)] if op isa Call && isprimitive(ChainRulesCtx(), call_values(op)...) # replace f(args...) with rrule(f, args...) @@ -180,6 +183,7 @@ function step_back!(tape::Tape, y::Variable) end for (i, x) in enumerate(y_fargs) if x isa V + global STATE = (tape, y, y_fargs, i, x) dx = push!(tape, mkcall(getfield, dxs, i; line="d$y/d$x")) # @debug "Updating derivative: $x -> $dx" set_or_add_deriv!(tape, x, dx) diff --git a/src/utils.jl b/src/utils.jl index ecc6bc5..71df5d5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,9 @@ # not the most robust function, but works in practise -is_kwfunc(f) = (name = string(f); endswith(name, "##kw") || endswith(name, "##kw\"")) +if VERSION < v"1.9.0" + is_kwfunc(f) = (name = string(f); endswith(name, "##kw") || endswith(name, "##kw\"")) +else + is_kwfunc(f) = (f === Core.kwcall) +end is_kwfunc(v::Variable) = is_kwfunc(v._op.val) function unkwfunc(f, args...) diff --git a/test/Project.toml b/test/Project.toml index 68924a4..6fc48f5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -CUDA = "3" +CUDA = "3, 4" ChainRules = "1" ChainRulesCore = "1" ChainRulesTestUtils = "1" \ No newline at end of file