Add one-arg method Duplicated(x) #2118

Open
wants to merge 1 commit into
base: main

mcabbott
Contributor

@mcabbott mcabbott commented Nov 23, 2024

Instead of writing

x = some_constructor(params...)
dx = make_zero(x)
xdx = Duplicated(x, dx)

this PR makes it easy to look after just one thing,

xdx = Duplicated(some_constructor(params...))

and later get dx = xdx.dval out when you need it. Or not: one could add methods which consume the gradient directly, like update!(::Duplicated), instead of unpacking it.
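
As a rough illustration of the workflow described above, here is a minimal sketch. The helper name `zero_duplicated` is hypothetical (it is not part of Enzyme, and it is not the method this PR adds); it only wraps the two-argument `Duplicated` constructor and `make_zero`, and the `loss` function is an assumed toy example.

```julia
using Enzyme

# Hypothetical helper (an assumption for illustration, not Enzyme API):
# pair a value with a freshly zeroed shadow of the same structure.
zero_duplicated(x) = Duplicated(x, make_zero(x))

# Toy objective, assumed for the example.
loss(v) = sum(abs2, v)

xdx = zero_duplicated([1.0, 2.0, 3.0])

# Reverse mode accumulates the gradient in place into the shadow, xdx.dval.
autodiff(Reverse, loss, Active, xdx)

# Unpack the gradient only when it is needed; for sum(abs2, v) it is 2v.
dx = xdx.dval
```

Using a separate helper name here avoids adding methods to a type owned by another package; the PR itself proposes a one-argument `Duplicated(x)` method instead.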

Should there also be DuplicatedNoNeed(x), and perhaps BatchDuplicated(x, n::Int)? Or [edit] perhaps MixedDuplicated(x) is the closest other struct here.

@codecov-commenter

codecov-commenter commented Nov 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 70.22%. Comparing base (037dfed) to head (9e26b67).
Report is 238 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2118      +/-   ##
==========================================
+ Coverage   67.50%   70.22%   +2.71%     
==========================================
  Files          31       42      +11     
  Lines       12668    15784    +3116     
==========================================
+ Hits         8552    11084    +2532     
- Misses       4116     4700     +584     

@MasonProtter

My only concern about this is that it's rather reverse-mode centric. On the other hand, since there's no universally obvious candidate for what to put in there for forward mode, I guess it's fine to make this method favour reverse mode?

@mcabbott
Contributor Author

That is true. I don't think there's a downside for forward mode, just no benefit.

Unless you think that for forward mode, something like Duplicated(x::Real) = Duplicated(x, oneunit(x)) would be sensible? That seems more surprising to me. The only precedent I can think of is this...

julia> DualNumbers.Dual(42.)
42.0 + 0.0ɛ

... where zero dual is a bit useless, but does construct the type you requested. And 42.0 + 1.0ɛ would I think be dangerous. Changing this would be clearly wrong:

julia> v = [DualNumbers.Dual(33.0, 1.0)];

julia> push!(v, 42)
2-element Vector{Dual128}:
 33.0 + 1.0ɛ
 42.0 + 0.0ɛ

@MasonProtter

Yeah, I don't see any obvious use for it in Enzyme, so it's probably fine.

@lassepe
Contributor

lassepe commented Dec 13, 2024

I feel like as long as Duplicated is user-facing in forward and reverse mode, this shorthand should not exist because it makes things more confusing. I find that the changing semantics of these types in forward vs reverse mode already contribute to a steep learning curve.

If we had different user-facing duplicated types for forward and reverse, e.g., Dual for forward and WithAccumulatedGradient for reverse mode, then each of them could have sane single-arg constructors, e.g.

  • Dual(x) = Duplicated(x, make_one(x)) (or even Dual(x::Vector) = BatchDuplicated(x, make_one_batch(x)...))
  • as well as WithAccumulatedGradient(x) = Duplicated(x, make_zero(x))

@mcabbott
Contributor Author

Oh no, now I'm worried this PR will turn into a punching bag for API / documentation complaints.

From a human UI perspective, it does seem a bit odd to re-use the exact same struct for forward & reverse. Especially since the forward mode case where you might plausibly construct Dual(pi, 1) with your bare hands is precisely the case where Duplicated is useless, perhaps even misleading, in reverse mode:

julia> xdx = Duplicated(3.0, 100.0);

julia> autodiff(Forward, abs2, xdx)  # ok, like dual numbers
(600.0,)

julia> autodiff(Reverse, abs2, Active, xdx)  # maybe this should be an error?
((nothing,),)

julia> xdx  # Duplicated{Float64} is never a useful container for reverse mode
Duplicated{Float64}(3.0, 100.0)

julia> y = (array = [1.0, 2.0], float = 3.0);

julia> ydy = Duplicated(y, make_zero(y));

julia> autodiff(Reverse, y -> sum(y.array .* y.float), Active, ydy) 
((nothing,),)

julia> ydy.dval  # this isn't useless, but does require some understanding
(array = [3.0, 3.0], float = 0.0)

You need ydy = MixedDuplicated(y, Ref(make_zero(y))) to match the result of the friendly gradient(Reverse, y -> sum(y.array .* y.float), y) here.

Anyway, this PR changes nothing at all about forward mode.

Please make a separate issue to discuss changing from Duplicated to some new Dual for forward mode. (Or to discuss more clearly documenting things.)

The one thing that is on-topic for this PR is that we could make Duplicated(x) give a helpful error on x::Float64, maybe on all isbits(x) or something. That's non-breaking, and would improve friendliness compared to the present MethodError. As far as I can tell, Duplicated(1.0, 0.0) is entirely useless for reverse mode (you may as well make Const(1.0)). In forward mode it should occur only accidentally (i.e. you get the dx from somewhere which could plausibly give you a nonzero value), hence you never want the one-argument constructor.
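
One way such a friendlier error could look is sketched below. The wrapper name `checked_duplicated` and the error message are assumptions made for illustration; this is not the PR's code, and the `isbits` check is only the rough criterion floated above.

```julia
using Enzyme

# Hypothetical wrapper (an assumption, not Enzyme API): refuse isbits values,
# whose shadow cannot be updated in place by reverse mode.
function checked_duplicated(x)
    if isbits(x)
        throw(ArgumentError(
            "a zero-initialised Duplicated of an isbits value ($(typeof(x))) " *
            "cannot accumulate a reverse-mode gradient; consider Active or Const"))
    end
    return Duplicated(x, make_zero(x))
end

ok = checked_duplicated([1.0, 2.0])  # arrays are mutable, so this is allowed
# checked_duplicated(1.0)            # would throw the ArgumentError above
```

Whether `isbits` is the right test is left open in the discussion; an immutable struct that holds an array is not `isbits`, for instance, so it would pass this check.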

@wsmoses
Member

wsmoses commented Dec 13, 2024

I'm very down to make/enforce the use of a separate Dual type for forward mode (which the Rust Enzyme stuff does), but yeah that's separate from here, if you want to open a different issue/PR on that.

Obviously that would be quite breaking [and need corresponding checks throughout Enzyme].

Part of the reason for using the one Duplicated for both is that, from the implementation side (as opposed to the user side), they are very much the same: a separate shadow data structure is created and maintained. Hence the very literal name "duplicated", as in we duplicated the data structure from primal to shadow.

@MasonProtter

MasonProtter commented Dec 13, 2024

> Oh no, now I'm worried this PR will turn into a punching bag for API / documentation complaints.
>
> Anyway, this PR changes nothing at all about forward mode.
>
> Please make a separate issue to discuss changing from Duplicated to some new Dual for forward mode. (Or to discuss more clearly documenting things.)

I don't think that's a fair response. Like it or not, Duplicated is a fundamental part of the API for Forward mode, so IMO it is quite relevant to talk about how changes to Duplicated can make the (already highly confusing and second-class) Forward mode API feel even more confusing and second-class.

If we did disentangle the APIs so Forward mode didn't use Duplicated then I think this'd be a valid dismissal, but since they share it I think it's valid to worry about this making the API even more confusing / hard to teach.
