diff --git a/docs/src/faq.md b/docs/src/faq.md index 72fa1f97d93..5e57a8ada8a 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -540,3 +540,112 @@ For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \fr ``` Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space). + +## What types are differentiable? + +Enzyme tracks differentiable dataflow through values. Specifically Enzyme tracks differentiable data in base types like Float32, Float64, Float16, BFloat16, etc. + +As a simple example: + +```jldoctest types +f(x) = x * x +Enzyme.autodiff(Forward, f, Duplicated(3.0, 1.0)) + +# output + +(6.0,) +``` + +Enzyme also tracks differentiable data in any types containing these base types (e.g. floats). For example, consider a struct or array containing floats. + +```jldoctest types +struct Pair + lhs::Float64 + rhs::Float64 +end +f_pair(x) = x.lhs * x.rhs +Enzyme.autodiff(Forward, f_pair, Duplicated(Pair(3.0, 2.0), Pair(1.0, 0.0))) + +# output + +(2.0,) +``` + +```jldoctest types +Enzyme.autodiff(Forward, sum, Duplicated([1.0, 2.0, 3.0], [5.0, 0.0, 100.0])) + + +# output + +(105.0,) +``` + +A differentiable data structure can be arbitrarily complex, such as a linked list. + + +```jldoctest types + +struct LList + prev::Union{Nothing, LList} + value::Float64 +end + +function make_list(x::Vector) + result = nothing + for value in reverse(x) + result = LList(result, value) + end + return result +end + +function list_sum(list::Union{Nothing, LList}) + result = 0.0 + while list != nothing + result += list.value + list = list.prev + end + return result +end + +list = make_list([1.0, 2.0, 3.0]) +dlist = make_list([5.0, 0.0, 100.0]) + +Enzyme.autodiff(Forward, list_sum, Duplicated(list, dlist)) + +# output + +(105.0,) +``` + +Presently Enzyme only considers floats as base types. As a result, Enzyme does not support differentiating data contained in Ints, Strings, or Vals. If it is desirable for Enzyme to add a base type, please open an issue. + +```jldoctest types +f_int(x) = x * x +Enzyme.autodiff(Forward, f_int, DuplicatedNoNeed, Duplicated(3, 1)) + +# output + +ERROR: Return type `Int64` not marked Const, but type is guaranteed to be constant +``` + +```jldoctest types +f_str(x) = parse(Float64, x) * parse(Float64, x) + +autodiff(Forward, f_str, Duplicated("1.0", "1.0")) + +# output + +(0.0,) +``` + +```jldoctest types +f_val(::Val{x}) where x = x * x + +autodiff(Forward, f_val, Duplicated(Val(1.0), Val(1.0))) + +# output + +ERROR: Type of ghost or constant type Duplicated{Val{1.0}} is marked as differentiable. +``` + + diff --git a/src/compiler.jl b/src/compiler.jl index 9c780fa5980..25cfd4d4cd5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_try_substrtod", "jl_try_substrtod", "jl_f__apply_iterate", "ijl_field_index", "jl_field_index", "julia.call", "julia.call2", @@ -178,6 +179,7 @@ const nofreefns = Set{String}(( )) const inactivefns = Set{String}(( + "ijl_try_substrtod", "jl_try_substrtod", "ijl_tagged_gensym", "jl_tagged_gensym", "jl_get_world_counter", "ijl_get_world_counter", "memhash32_seed", "memhash_seed", @@ -2790,6 +2792,7 @@ function annotate!(mod, mode) "ijl_reshape_array", "jl_reshape_array", "ijl_eqtable_get", "jl_eqtable_get", "jl_gc_run_pending_finalizers", + "ijl_try_substrtod", "jl_try_substrtod", ) if haskey(fns, fname) fn = fns[fname]