-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Some tweaks to the Getting Started docs #2195
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -31,23 +31,25 @@ Also see [Implementing pullbacks](@ref) on how to implement back-propagation for | |||||||||||||||||||||||||||||
We will try a few things with the following functions: | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
```jldoctest rosenbrock | ||||||||||||||||||||||||||||||
julia> rosenbrock(x, y) = (1.0 - x)^2 + 100.0 * (y - x^2)^2 | ||||||||||||||||||||||||||||||
rosenbrock (generic function with 1 method) | ||||||||||||||||||||||||||||||
julia> rosenbrock(x, y) = (1.0 - x)^2 + 100.0 * (y - x^2)^2; | ||||||||||||||||||||||||||||||
julia> rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2 | ||||||||||||||||||||||||||||||
rosenbrock_inp (generic function with 1 method) | ||||||||||||||||||||||||||||||
julia> rosenbrock_inp(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2; | ||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
where we note for future reference that the value of this function at `x=1.0`, `y=2.0` is `100.0`, and its derivative | ||||||||||||||||||||||||||||||
with respect to `x` at that point is `-400.0`, and its derivative with respect to `y` at that point is `200.0`. | ||||||||||||||||||||||||||||||
Comment on lines
+36
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I've also removed 100.0 from the definition, as IMO this is idiomatic Julia -- rosenbrock can take & return Float32 without promoting. |
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
## Reverse mode | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
The return value of reverse mode [`autodiff`](@ref) is a tuple that contains as a first value | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
the derivative value of the active inputs and optionally the primal return value. | ||||||||||||||||||||||||||||||
the derivative value of the active inputs and optionally the _primal_ return value (i.e. the | ||||||||||||||||||||||||||||||
value of the undifferentiated function). | ||||||||||||||||||||||||||||||
Comment on lines
44
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider not using "value" to mean so many things here?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe "optionally" also seems a bit odd to describe the output not the input. It's not that you may omit this. It's that ReverseWithPrimal tells it to. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah we definitely don't need to say "derivative value" and can just say "the derivative of" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and "the value of the undifferentiated function" -> "the result of the original function without differentiation" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also consider putting the ReverseWithPrimal case first, as without it, Perhaps also write it with destructuring syntax, like:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh yeah totally fair, if you want to put that in this PR that would be fine with me! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't make suggestions across deleted lines :/ so this is going to be messy...
Comment on lines
+45
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
```jldoctest rosenbrock | ||||||||||||||||||||||||||||||
julia> autodiff(Reverse, rosenbrock, Active, Active(1.0), Active(2.0)) | ||||||||||||||||||||||||||||||
julia> autodiff(Reverse, rosenbrock, Active(1.0), Active(2.0)) | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
((-400.0, 200.0),) | ||||||||||||||||||||||||||||||
julia> autodiff(ReverseWithPrimal, rosenbrock, Active, Active(1.0), Active(2.0)) | ||||||||||||||||||||||||||||||
julia> autodiff(ReverseWithPrimal, rosenbrock, Active(1.0), Active(2.0)) | ||||||||||||||||||||||||||||||
((-400.0, 200.0), 100.0) | ||||||||||||||||||||||||||||||
Comment on lines
+52
to
53
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -62,7 +64,7 @@ julia> dx = [0.0, 0.0] | |||||||||||||||||||||||||||||
0.0 | ||||||||||||||||||||||||||||||
0.0 | ||||||||||||||||||||||||||||||
julia> autodiff(Reverse, rosenbrock_inp, Active, Duplicated(x, dx)) | ||||||||||||||||||||||||||||||
julia> autodiff(Reverse, rosenbrock_inp, Duplicated(x, dx)) | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
((nothing,),) | ||||||||||||||||||||||||||||||
julia> dx | ||||||||||||||||||||||||||||||
|
@@ -71,8 +73,9 @@ julia> dx | |||||||||||||||||||||||||||||
200.0 | ||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Both the inplace and "normal" variant return the gradient. The difference is that with | ||||||||||||||||||||||||||||||
[`Active`](@ref) the gradient is returned and with [`Duplicated`](@ref) the gradient is accumulated in place. | ||||||||||||||||||||||||||||||
Both the inplace and "normal" variant return the gradient. The difference is that with [`Active`](@ref) | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This wording seems weird. The inplace version returns
IDK how much of the end goes here, but the reader should not get the impression that all arguments must have the same type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps we should say both compute the gradient. And we can use whatever function names are here for clarity There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
the gradient is returned and with [`Duplicated`](@ref) the gradient is accumulated in-place into `dx`, | ||||||||||||||||||||||||||||||
and a value of `nothing` is placed in the corresponding slot of the returned `Tuple`. | ||||||||||||||||||||||||||||||
Comment on lines
+77
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
## Forward mode | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -121,7 +124,7 @@ julia> dx = [1.0, 1.0] | |||||||||||||||||||||||||||||
1.0 | ||||||||||||||||||||||||||||||
1.0 | ||||||||||||||||||||||||||||||
julia> autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, Duplicated(x, dx)) | ||||||||||||||||||||||||||||||
julia> autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated(x, dx)) | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
(-400.0, 400.0) | ||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider showing this as code instead of prose?
I also think you should not call the input of
rosenbrock_inp
the same thing,x == [x, y]
is weird. The namerosenbrock_inp
also seems a bit weird, maybe it can just be another method, or if that's too confusing, add a suffix more informative than "inp"? (I'm not sure what INP means, maybe input, but why?)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it was originally in place (@michel2323 were you the one to originally author this doc, just by virtue of it being rosenbrock?)
But either way sure!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok. But this function isn't in-place, it's just going to be used somewhere below in a demonstration that Enzyme likes to handle functions which accept Vector by mutating something else. The reader doesn't know that yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very true, maybe rosenbrok_array or something? or even just rosenbrock2