diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index 81d8122b..d5c22729 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -3,7 +3,9 @@ module BijectorsEnzymeExt using Enzyme: @import_rrule, @import_frule, Enzyme, EnzymeCore using Bijectors: find_alpha, ChainRulesCore -@static if VERSION == v"1.11.1" +# This is solution (2) listed in https://github.com/TuringLang/Bijectors.jl/issues/339. +# Eventually we would like to move to solution (4) but we don't know when we can. +@static if VERSION >= v"1.11.1" # @import_rrule function (Enzyme.EnzymeRules).augmented_primal(var"#8#config", var"#9#fn"::var"#16#FA", ::Enzyme.Type{var"#15#RetAnnotation"}, var"#11#arg_1"::var"#17#AN_1", var"#12#arg_2"::var"#18#AN_2", var"#13#arg_3"::var"#19#AN_3"; var"#14#kwargs"...) where {var"#15#RetAnnotation", var"#16#FA" <: Enzyme.Annotation{<:typeof(find_alpha)}, var"#17#AN_1" <: Enzyme.Annotation{<:Real}, var"#18#AN_2" <: Enzyme.Annotation{<:Real}, var"#19#AN_3" <: Enzyme.Annotation{<:Real}} var"#1#primcopy_1" = if ((EnzymeCore.EnzymeRules.overwritten)(var"#8#config"))[1 + 1]