Skip to content

Commit

Permalink
Docs: describe differentiable types (EnzymeAD#1433)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored May 12, 2024
1 parent 19dbbb2 commit 75e5311
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
109 changes: 109 additions & 0 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,112 @@ For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \fr
```

Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space).

## What types are differentiable?

Enzyme tracks differentiable dataflow through values. Specifically Enzyme tracks differentiable data in base types like Float32, Float64, Float16, BFloat16, etc.

As a simple example:

```jldoctest types
f(x) = x * x
Enzyme.autodiff(Forward, f, Duplicated(3.0, 1.0))
# output
(6.0,)
```

Enzyme also tracks differentiable data in any types containing these base types (e.g. floats). For example, consider a struct or array containing floats.

```jldoctest types
struct Pair
lhs::Float64
rhs::Float64
end
f_pair(x) = x.lhs * x.rhs
Enzyme.autodiff(Forward, f_pair, Duplicated(Pair(3.0, 2.0), Pair(1.0, 0.0)))
# output
(2.0,)
```

```jldoctest types
Enzyme.autodiff(Forward, sum, Duplicated([1.0, 2.0, 3.0], [5.0, 0.0, 100.0]))
# output
(105.0,)
```

A differentiable data structure can be arbitrarily complex, such as a linked list.


```jldoctest types
struct LList
prev::Union{Nothing, LList}
value::Float64
end
function make_list(x::Vector)
result = nothing
for value in reverse(x)
result = LList(result, value)
end
return result
end
function list_sum(list::Union{Nothing, LList})
result = 0.0
while list != nothing
result += list.value
list = list.prev
end
return result
end
list = make_list([1.0, 2.0, 3.0])
dlist = make_list([5.0, 0.0, 100.0])
Enzyme.autodiff(Forward, list_sum, Duplicated(list, dlist))
# output
(105.0,)
```

Presently Enzyme only considers floats as base types. As a result, Enzyme does not support differentiating data contained in Ints, Strings, or Vals. If it is desirable for Enzyme to add a base type, please open an issue.

```jldoctest types
f_int(x) = x * x
Enzyme.autodiff(Forward, f_int, DuplicatedNoNeed, Duplicated(3, 1))
# output
ERROR: Return type `Int64` not marked Const, but type is guaranteed to be constant
```

```jldoctest types
f_str(x) = parse(Float64, x) * parse(Float64, x)
autodiff(Forward, f_str, Duplicated("1.0", "1.0"))
# output
(0.0,)
```

```jldoctest types
f_val(::Val{x}) where x = x * x
autodiff(Forward, f_val, Duplicated(Val(1.0), Val(1.0)))
# output
ERROR: Type of ghost or constant type Duplicated{Val{1.0}} is marked as differentiable.
```


3 changes: 3 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}(
end

const nofreefns = Set{String}((
"ijl_try_substrtod", "jl_try_substrtod",
"jl_f__apply_iterate",
"ijl_field_index", "jl_field_index",
"julia.call", "julia.call2",
Expand Down Expand Up @@ -178,6 +179,7 @@ const nofreefns = Set{String}((
))

const inactivefns = Set{String}((
"ijl_try_substrtod", "jl_try_substrtod",
"ijl_tagged_gensym", "jl_tagged_gensym",
"jl_get_world_counter", "ijl_get_world_counter",
"memhash32_seed", "memhash_seed",
Expand Down Expand Up @@ -2790,6 +2792,7 @@ function annotate!(mod, mode)
"ijl_reshape_array", "jl_reshape_array",
"ijl_eqtable_get", "jl_eqtable_get",
"jl_gc_run_pending_finalizers",
"ijl_try_substrtod", "jl_try_substrtod",
)
if haskey(fns, fname)
fn = fns[fname]
Expand Down

0 comments on commit 75e5311

Please sign in to comment.