A unified derivative operator #973
You also need to add support for doing the cross/dot product with these, and likely for arrays in general.
@shashi @YingboMa now would be a good time for us to talk about tensor derivatives. It might give you ideas for how to structure your new tensor derivative API. I've prepared a few pages describing how this works, but before we talk I wanted to add one more example showing convolution, and another section explaining why the method works. However, this has already taken me too long: tracking down bugs in FastDifferentiation has chewed up all my available discretionary time. We could get started on tensor derivatives with what exists now, and I can expand it over time. Are you both available this week?
I am flexible after 11 am EST this week.
Yes, let's talk soon. I'll message on Slack. I can also help you debug the
Now that we are thinking about matrix calculus and symbolic array functions [fn2], I thought it was time to finally come up with a unified derivative operator/API. Here I'm going to use "derivative" to mean scalar derivatives, gradients and Jacobians: in general, anything that is a Fréchet derivative.
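Because the proposal leans on the Fréchet derivative, the defining expansion is given here for reference. It is followed by the short calculation behind the `∂(A)(A*A)` example used below. The notation $Df(x)[\Delta]$ is chosen here for illustration only.

```math
\begin{aligned}
f(x + \Delta) &= f(x) + Df(x)[\Delta] + o(\lVert \Delta \rVert) \\
(A + \Delta)^2 &= A^2 + A\Delta + \Delta A + \Delta^2
\quad\Longrightarrow\quad D(A \mapsto A^2)(A)[\Delta] = A\Delta + \Delta A
\end{aligned}
```

Since $\lVert \Delta^2 \rVert \le \lVert \Delta \rVert^2$, the $\Delta^2$ term is $o(\lVert \Delta \rVert)$ and drops out. Unless $A$ is a multiple of the identity, $A\Delta + \Delta A$ cannot be written as $M\Delta$ for a single matrix $M$. This is one way to see why a function-like `LinOp` is a natural return value in that case.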
The general API I'm thinking of is:
`∂(x)` is the "derivative with respect to `x`" operator. `x` is not restricted to being a `Real` symbol: it can be an array of symbols, a symbol that represents an array, or a nesting of these. See [fn3] for why the operator is named `∂`.

`∂(x)(f(x))` is the derivative linear map, with an added syntax-regularizing rule that it must support right-multiplication with an element from the vector space that `x` is from. This allows us to return `3` as the derivative of `3x`, and also to return a function-like `LinOp(Δ ↦ A*Δ + Δ*A)`, which applies the function `Δ ↦ A*Δ + Δ*A` by treating its right multiplicand as `Δ`.

In the examples below, `x` and `y` are scalars, `u` is a vector and `A` is a square matrix.

- `∂(x)(3x) = 3` (scalar derivative)
- `∂([x, y])(x + y) = [1 1]` (gradient)
- `∂(u)(A*u) = A` (jacobian)
- `∂(A)(A*A) = LinOp(Δ ↦ A*Δ + Δ*A)`, which is some object that substitutes the multiplicand for `Δ` on right-multiplication. One possible variation on this is in [fn1].

`∂(..)` has some recursive structure, for example:

- `∂([A, u])(A*u) = [∂(A)(A*u) A]`; this object can be right-multiplied with a vector of elements of the same shape as `[A, u]`.
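To make the right-multiplication rule concrete, here is a minimal Julia sketch. It assumes a hypothetical `LinOp` struct (not an existing Symbolics.jl type) that only wraps a closure and applies it when right-multiplied. It also uses a hypothetical `apply_blocks` helper for the recursive `[∂(A)(A*u) A]` case. The sketch checks the `A*A` and `A*u` derivatives from the examples above against central finite differences, which are exact up to rounding for these quadratic maps. It also encodes the `x -> x[1] + x[4]` example Gaurav raised in this issue.

```julia
using LinearAlgebra

# Hypothetical stand-in for the proposed `LinOp`: it stores the linear map
# Δ ↦ f(Δ) and applies it on right-multiplication, so `L * Δ` is a jvp.
struct LinOp{F}
    f::F
end
Base.:*(L::LinOp, Δ) = L.f(Δ)

# Hypothetical helper for the recursive case: a row of blocks such as
# [∂(A)(A*u) A] times a matching row of perturbations (ΔA, Δu) is the sum
# of each block applied to its own perturbation.
apply_blocks(blocks, Δs) = sum(b * d for (b, d) in zip(blocks, Δs))

A  = [1.0 2.0; 3.0 4.0]
u  = [1.0, -1.0]
ΔA = [0.5 -1.0; 2.0 0.0]
Δu = [0.25, 0.75]

# ∂(A)(A*A) = LinOp(Δ ↦ A*Δ + Δ*A)
dAA = LinOp(Δ -> A*Δ + Δ*A)

# ∂([A, u])(A*u) = [∂(A)(A*u) A], where ∂(A)(A*u) = LinOp(Δ ↦ Δ*u)
dAu = (LinOp(Δ -> Δ*u), A)

# Central differences: the quadratic terms cancel, leaving only the linear part.
h = 1e-3
fd_AA = ((A + h*ΔA)^2 - (A - h*ΔA)^2) / (2h)
fd_Au = ((A + h*ΔA)*(u + h*Δu) - (A - h*ΔA)*(u - h*Δu)) / (2h)

@show dAA * ΔA ≈ fd_AA
@show apply_blocks(dAu, (ΔA, Δu)) ≈ fd_Au

# Derivative of x -> x[1] + x[4] as a LinOp: applying it needs no
# length(x)-sized storage, unlike the dense gradient it is equivalent to.
dsum = LinOp(Δ -> Δ[1] + Δ[4])
v = collect(1.0:1_000_000.0)
@show dsum * v   # v[1] + v[4] = 1.0 + 4.0
```

In this sketch the same `*` covers plain matrices (`A * Δu`) and `LinOp`s. That uniformity is what lets the block row be applied with a single generic sum.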
A cool consequence is that `∂(x)(f(x)) * ..` is the "jvp" (Jacobian-vector product). For example, `LinOp(Δ ↦ A*Δ + Δ*A)` would be the jvp-calculating function. Gaurav used the example of a function `x -> x[1] + x[4]`, whose derivative is `LinOp(Δ ↦ Δ[1] + Δ[4])`. This would have a smaller memory footprint than the gradient if `x` has a million elements (assuming `∂` could choose to return the `LinOp` instead of the gradient). It's possible to make it so that `* I` "materializes" a linear operator in general, whenever possible.

@YingboMa this is the API we thought of for ForwardDiff2. (TODO: think about how `*` would work to do the chain rule in the case of `LinOp`.)

Present state of affairs: we have `Differential` (from pre-Symbolics days), which represents scalar differentiation. We don't really represent other kinds of derivatives, like gradients and Jacobians, in the expression trees, so `expand_derivatives` only expands scalar derivatives. I think `expand_derivatives` should be part of `simplify`. Just to be clear, the current `derivative`, `gradient` and `jacobian` functions will continue to work as they do now.

Higher-order notation could be:
fn1: Lambdas: Gaurav suggested there could be an object that behaves like a "hole" for the right multiplicand, like `⍵` in APL. I think that should be syntactic sugar for a lambda. I've been thinking of adding lambdas for a while. I think it's going to be interesting to have lambdas with unbound closed symbolic variables in them, which take part in partial evaluation and get bound later if and only when the expression is compiled into code.

fn2: Array function registration (#292, #753, etc.) needed some syntax. The ∂ scheme comes in handy here: it allows the various derivatives to be defined while registering a function as a symbolic primitive. This would look something like:
We can just add the derivative rules as rewrite rules into `expand_derivatives` (in spirit; in practice we can also do something faster).

cc @alanedelman @stevengj @YingboMa @Roni-Edwin
@brianguenter I didn't think too much about how this would play with FastDifferentiation... We can talk more over slack.
fn3: I'm sure most of us would prefer `∂` as the exported symbol, as opposed to `D`. It makes expressions nicely readable.

Thanks to @gaurav-arya and @avik-pal for helping me brainstorm this, and to scmutils for allowing some experiments.