Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JVP, HVP [in progress] #167

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
156 changes: 147 additions & 9 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,14 @@
(str "Selectors " selectors
" not allowed for non-structural input " input)))))))

(defn multi [op f]
(fn
([] 0)
([x] ((op f) x))
([x & more]
((multi op (fn [xs] (apply f xs)))
(matrix/seq-> (cons x more))))))

(defn- multivariate
"Slightly wider version of [[euclidean]]. Accepts:

Expand All @@ -395,15 +403,145 @@
Single-argument functions don't transform their arguments."
([f] (multivariate f []))
([f selectors]
(let [d #(euclidean % selectors)
df (d f)
df* (d (fn [args] (apply f args)))]
(-> (fn
([] 0)
([x] (df x))
([x & more]
(df* (matrix/seq-> (cons x more)))))
(f/with-arity (f/arity f) {:from ::multivariate})))))
(let [d #(euclidean % selectors)]
(multi d f))))

(defn simple-jvp [f v]
(fn [x]
(let [g (fn [r]
(f (g/+ x (g/* r v))))]
((derivative g) 0))))

Check warning on line 413 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L410-L413

Added lines #L410 - L413 were not covered by tests

;; TODO do we like this, with the multi combinator?
(defn grad [f selectors]
(fn [x]
(when (and (seq selectors) (not (s/structure? x)))
(u/illegal
(str "Selectors " selectors
" not allowed for non-structural input " x)))

Check warning on line 421 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L419-L421

Added lines #L419 - L421 were not covered by tests

(let [tag (d/fresh-tag)
inputs (if (empty? selectors)
(tape/tapify x tag)
(update-in x selectors tape/tapify tag))

Check warning on line 426 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L426

Added line #L426 was not covered by tests
output (d/with-active-tag tag f [inputs])
completed (tape/->partials output tag)]
(if (empty? selectors)
(tape/interpret inputs completed tag)
(tape/interpret (get-in inputs selectors) completed tag)))))

Check warning on line 431 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L431

Added line #L431 was not covered by tests

(defn gradient
([f] (gradient f []))
([f selectors]
(multi #(grad % selectors) f)))

(defn jvp [f v]
(multi #(simple-jvp % v) f))

Check warning on line 439 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L439

Added line #L439 was not covered by tests

(defn vjp [f v]
(let [g (fn [x]
(if (or (v/scalar? x)
(v/scalar? v)
(s/compatible-for-contraction? x v))
(g/* x v)
(u/illegal "Incompatible structures!")))]

Check warning on line 447 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L447

Added line #L447 was not covered by tests
(gradient (comp g f))))

;; TODO this would be better, if we had a clear way of pulling the primals out
;; from the IPerturbed protocol... that would let `jvp` do its thing as well.

(defn primal-and-derivative [f]
(fn [x]
(let [tag (d/fresh-tag)
lifted (d/bundle-element x 1 tag)
output (d/with-active-tag tag f [lifted])]
[(d/primal output)
(d/extract-tangent output tag)])))

Check warning on line 459 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L454-L459

Added lines #L454 - L459 were not covered by tests

(defn primal-and-jvp [f v]
(fn [x]
(let [g (fn [r]
(f (g/+ x (g/* r v))))]
((primal-and-derivative g) 0))))

Check warning on line 465 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L462-L465

Added lines #L462 - L465 were not covered by tests

(defn vjp* [f]
(fn [x]
(let [tag (d/fresh-tag)
inputs (tape/tapify x tag)
output (d/with-active-tag tag f [inputs])]
[(tape/tape-primal output)
(fn [v]
(let [output (if (or (v/scalar? x)
(v/scalar? v)
(s/compatible-for-contraction? x v))
(g/* output v)
(u/illegal "Incompatible structures!"))
completed (tape/->partials output tag)]
(tape/interpret inputs completed tag)))])))

Check warning on line 480 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L468-L480

Added lines #L468 - L480 were not covered by tests

;; attempting to return the primal too

(defn primal-and-vjp [f]
(multi vjp* f))

Check warning on line 485 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L485

Added line #L485 was not covered by tests

(defn hvp [f v]
(jvp (gradient f) v))

Check warning on line 488 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L488

Added line #L488 was not covered by tests

(comment
(require 'emmy.env)
(let [f (emmy.env/literal-function 'f (-> (UP* Real 10) Real))
x (s/literal-up 'x 10)
v (s/literal-up 'dx 10)]
(g/- ((hvp f v) x)
(g/* (((g/square D) f) x) v)))

(let [f (emmy.env/literal-function 'f (-> (UP Real Real) (UP Real Real)))
x (s/up 'x 'y)
v (s/up 'dx 'dy)]

(g/- ((jvp f v) x)
(g/* ((D f) x) v))))

#_
(comment
(let [f (fn [[a b c d e f g h i :as x]]
(into []
(take 1000
(cycle [(g/expt (g/cos a) (g/atan c))
(g/expt (g/sin a) (g/atan b))
(g/expt (g/sin b) (g/atan c))
(g/exp (g/square c))
(g/expt (g/cos d) (g/atan f))
(g/expt (g/sin d) (g/atan e))
(g/expt (g/sin e) (g/atan f))
(g/exp (g/square f))
(g/expt (g/cos g) (g/atan i))
(g/expt (g/sin g) (g/atan h))
(g/expt (g/sin h) (g/atan i))
(g/exp (g/square i))]))))
v (s/up 'da 'db 'dc 'de 'de 'df 'dg 'dh 'di)
x (s/up 'a 'b 'c 'd 'e 'f 'g 'h 'i)
n 1]
(time (dotimes [_ n] ((jvp f v) x)))
(time (dotimes [_ n] (g/* ((D f) x) v))))

(let [f (fn [[x y z]]
[(g/expt (g/cos x) (g/atan z))
(g/expt (g/sin x) (g/atan y))
(g/expt (g/sin y) (g/atan z))
(g/exp (g/square z))])
v (s/up 'dx 'dx 'dz)
x (s/up 'x 'y 'z)]
(g/- ((jvp f v) x)
(g/* ((D f) x) v))))

#_
(extend-protocol v/Numerical
nil
(numerical? [_] false))

;; The result of applying the derivative `(D f)` of a multivariable function `f`


;; ## Generic [[g/partial-derivative]] Installation
;;
Expand Down
8 changes: 8 additions & 0 deletions test/emmy/calculus/derivative_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -1617,3 +1617,11 @@
(take 2)))
"symbolic-taylor-series keeps the arguments symbolic, even when they
are numbers."))))

(deftest vjp-test
(let [f (fn [[x y]] [(g/* x y) (g/sin x) (g/cos y)])
v (s/down 'dx 'dy 'dz)]
(is (g/zero?
(g/simplify
(g/- ((d/vjp f v) (s/up 'x 'y))
(g/* v ((D f) (s/up 'x 'y)))))))))
Loading