From 9e26b67bdcedf7a64417aad2dab00ae5b95673ac Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 23 Nov 2024 18:15:54 -0500 Subject: [PATCH] add one-arg Duplicated(x) --- lib/EnzymeCore/src/EnzymeCore.jl | 5 ++++- lib/EnzymeCore/test/misc.jl | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 614bedbe55..0c8cc465c9 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -55,12 +55,13 @@ Active(i::Integer) = Active(float(i)) Active(ci::Complex{T}) where T <: Integer = Active(float(ci)) """ - Duplicated(x, ∂f_∂x) + Duplicated(x, ∂f_∂x = make_zero(x)) Mark a function argument `x` of [`autodiff`](@ref Enzyme.autodiff) as duplicated, Enzyme will auto-differentiate in respect to such arguments, with `dx` acting as an accumulator for gradients (so ``\\partial f / \\partial x`` will be *added to*) `∂f_∂x`. +If the second argument is not provided, it is created by [`make_zero`](@ref). """ struct Duplicated{T} <: Annotation{T} val::T @@ -76,6 +77,8 @@ struct Duplicated{T} <: Annotation{T} end end +Duplicated(x) = Duplicated(x, make_zero(x), false) + """ DuplicatedNoNeed(x, ∂f_∂x) diff --git a/lib/EnzymeCore/test/misc.jl b/lib/EnzymeCore/test/misc.jl index 3c24ddc7c3..f55217064d 100644 --- a/lib/EnzymeCore/test/misc.jl +++ b/lib/EnzymeCore/test/misc.jl @@ -23,5 +23,8 @@ d = @view data[2:end] y = @view data[3:end] @test_skip @test_throws AssertionError Duplicated(d, y) +dup = Duplicated(data) +@test dup isa Duplicated + @test_throws ErrorException Active(data) @test_skip @test_throws ErrorException Active(d)