-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path4_autodiff.jl
2918 lines (2413 loc) · 102 KB
/
4_autodiff.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
### A Pluto.jl notebook ###
# v0.20.0
using Markdown
using InteractiveUtils
# This Pluto notebook uses @bind for interactivity. When running this notebook outside of Pluto, the following 'mock version' of @bind gives bound variables a default value (instead of an error).
macro bind(def, element)
quote
local iv = try Base.loaded_modules[Base.PkgId(Base.UUID("6e696c72-6542-2067-7265-42206c756150"), "AbstractPlutoDingetjes")].Bonds.initial_value catch; b -> missing; end
local el = $(esc(element))
global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : iv(el)
el
end
end
# ╔═╡ 9f027cde-dba0-4da5-8c42-5fa79b3929d6
using Graphs, GraphPlot, Printf
# ╔═╡ f1ba3d3c-d0a5-4290-ab73-9ce34bd5e5f6
using Plots, OneHotArrays, PlutoUI
# ╔═╡ 77a7de14-87d2-11ef-21ef-937b8239db5b
md"""
# Automatic Differentiation
On peut calculer des dérivées partielles de différentes manières:
1. De façon symbolique, en fixant une des variables et en dérivant les autres soit à la main, soit par ordinateur.
2. De façon numérique, avec la formule ``f'(x) \approx (f(x + h) - f(x)) / h``.
3. De façon algorithmique, soit forward, soit reverse, c'est ce que nous verons ici.
Pour illustrer, nous utiliserons l'exemple de classification de points de deux formes de lunes.
"""
# ╔═╡ 677a40de-ef6d-4a41-84f6-05ef6580aeba
md"Nous travaillerons avec une matrice `X` contenant dans chaque ligne, les coordonnées d'un point."
# ╔═╡ 0325e4a5-f50d-4064-b558-9f6275d4cd5a
md"Le vecteur `y` contiendra `1` pour les points de la lune bleue et `-1` pour les points de la lune rouge."
# ╔═╡ 860acc4a-f9ee-49c1-a6d1-d81a3c51d9a8
md"Nous illustrons le calcul de dérivée automatique par l'entrainement du modèle linéaire ``y \approx X w``. Commençons avec des poids aléatoires. Le modèle n'est pour le moment pas très précis comme il est aléatoire."
# ╔═╡ 55a0a91d-ebe4-4d8a-a094-b6e0aeec4587
w = rand(2)
# ╔═╡ c5b0cb9c-40be-44f6-9173-5a0631ab8834
md"En effet, les prédictions ne correspondent pas à `y`."
# ╔═╡ b4d2b635-bbb3-4024-877b-a96fdd19349e
md"On peut regrouper les erreurs des estimations de tous les points en les comparant avec `y`."
# ╔═╡ af404768-0663-4bc3-81dd-6931b3a486be
md"Essayons de trouver des poids `w` qui minimisent la somme des carrés des erreurs (aka MSE):"
# ╔═╡ 277bd2ce-fa7f-4288-be8a-0ddd8f23635c
md"""
## Forward Differentiation
Commençons par définir la forward differentiation. Cette différentiation algorithmique se base sur l'observation que la chain rule permet de calculer la dérivée de n'importe quelle fonction dès lors qu'on connait sont gradient et la dérivée de chacun de ses paramètres.
En d'autres mots, supposons qu'on doive calculer
```math
\frac{\partial}{\partial x} f(g(x), h(x))
```
Supposons que la fonction `f` soit une fonction `f(a, b)` simple (telle que `+`, `*`, `-`) dont on connait la formules des dérivée partielles ``\partial f / \partial a`` en fonction de `a` et ``\partial f / \partial b`` en fonction de `b`:
La chain rule nous donne
```math
\frac{\partial}{\partial x} f(g(x), h(x)) = \frac{\partial f}{\partial a}(g(x), h(x)) \frac{\partial g}{\partial x} + \frac{\partial f}{\partial b}(g(x), h(x)) \frac{\partial h}{\partial x}
```
Pour calculer cette expression, ils nous faut les valeurs de ``g(x)`` et ``h(x)`` ainsi que les dérivées ``\partial g / \partial x`` et ``\partial h / \partial x``.
"""
# ╔═╡ 94f2f9ef-9467-4781-9dfb-f0a32141f542
begin
struct Dual{T}
value::T
derivative::T
end
Dual(x, y) = Dual{typeof(x)}(x, convert(typeof(x), y))
end
# ╔═╡ e001e562-901f-4afa-a5ad-9bbecdae1694
md"L'implémentation générique du produit de matrice va appeler `zero`:"
# ╔═╡ ea2a923b-df68-4cb8-a3ff-62b0aadcc4f2
Base.zero(::Dual{T}) where {T} = Dual(zero(T), zero(T))
# ╔═╡ e8d8219d-1119-4f81-bc85-b27e33383fff
md"Par linéarité de la dérivée:"
# ╔═╡ 82ccdf44-5c45-4d55-ac1d-f4ec0a146b29
begin
Base.:*(α::T, x::Dual{T}) where {T} = Dual(α * x.value, α * x.derivative)
Base.:*(x::Dual{T}, α::T) where {T} = Dual(x.value * α, x.derivative * α)
Base.:+(x::Dual{T}, y::Dual{T}) where {T} = Dual(x.value + y.value, x.derivative + y.derivative)
Base.:-(x::Dual{T}, y::T) where {T} = Dual(x.value - y, x.derivative)
Base.:/(x::Dual, α::Number) = Dual(x.value / α, x.derivative / α)
end
# ╔═╡ 80356507-0b92-4c62-8bdf-865e345a29dc
md"Par la product rule ``(fg)' = f'g + fg'``:"
# ╔═╡ c0caef28-d59a-43a1-af4f-6756c3b41903
md"Pour l'exponentiation, on peut juste se rabatter sur le produit qu'on a déjà défini:"
# ╔═╡ a2ac721c-700e-4bbf-8c13-3b06db292c00
Base.:^(x::Dual, n::Integer) = Base.power_by_squaring(x, n)
# ╔═╡ 3f7cfa28-b060-4a3e-b61a-fd42be8e6939
onehot(1, 1:2)
# ╔═╡ 7d805d6a-9077-4d97-a0db-c1bd306cbbb8
float.(onehot(1, 1:2))
# ╔═╡ 7dfc8a90-5a7d-4457-b382-f9552e02fd73
float.(onehot(2, 1:2))
# ╔═╡ 42eacb3a-54b0-43e8-97a1-07a71ac3faf5
Dual.(w, onehot(1, 1:2))
# ╔═╡ 42f15c09-49f7-40e0-8892-0ade61a3c923
function forward_diff(loss, w, X, y, i)
loss(Dual.(w, onehot(i, eachindex(w))), X, y).derivative
end
# ╔═╡ 4dde16bc-2c40-4214-8963-2d7a7287f587
function forward_diff(loss, w, X, y)
[forward_diff(loss, w, X, y, i) for i in eachindex(w)]
end
# ╔═╡ b7303267-3404-4542-a7f8-5960859abc19
md"""
## Gradient descent
### Cauchy-Schwarz inequality
```math
\begin{align}
\left(\sum_i x_i y_i\right)^2 & = \left(\sum_i x_i^2\right)\left(\sum_i y_i^2\right)\cos(\theta)^2\\
\sum_i x_i y_i & = \sqrt{\sum_i x_i^2}\sqrt{\sum_i y_i^2}\cos(\theta)\\
\langle x, y \rangle & = \|x\|_2 \|y\|_2\cos(\theta)\\
-\|x\|_2 \|y\|_2 \le \langle x, y \rangle & \le \|x\|_2 \|y\|_2
\end{align}
```
Minimum atteint lorsque ``x = -y`` et maximum atteint lorsque ``x = y``.
Dans les deux cas, ``x`` est **parallèle** à ``y``, mais ce sont des **sens** différents.
### Rappel dérivée directionnelle
Dérivée dans la direction ``d``:
```math
\langle d, \nabla f \rangle
=
d^\top \nabla f
=
d_1 \cdot \partial f/\partial x_1 + \cdots + d_n \cdot \partial f/\partial x_n
```
Etant donné un gradient ``\nabla f``, la direction ``d`` telle que ``\|d\|_2 = 1`` qui a une dérivée minmale est ``d = -\nabla f``.
"""
# ╔═╡ b025aa9e-1137-4201-b7b8-2b803f8aa17e
md"Gradient:"
# ╔═╡ 0b4d8741-e9fb-40e4-a8be-cf21683b8f79
md"Gradient line search: pendant combien de temps doit-on suivre le gradient ?"
# ╔═╡ 53fe00ff-32d1-4c61-ae89-e54df5efc3a0
@bind num_η Slider(5:50, default=10, show_value = true)
# ╔═╡ 7ad77ed6-de39-430e-9e71-1c3d37ce7f34
step_sizes = range(-1, stop=1, length=num_η)
# ╔═╡ 52a049a9-c50f-426c-83e5-4ec02d5a638c
md"`num_iters` = $(@bind num_iters Slider(1:20, default = 10, show_value = true))"
# ╔═╡ b1aa765e-7a6b-4ab4-af83-ab3d30497866
md"## Kernel trick"
# ╔═╡ f71923d4-fbc9-4ce6-b5be-a00437c3651d
md"`η_lift` = $(@bind η_lift Slider(exp10.(-4:0.25:1), default=0.01, show_value = true))"
# ╔═╡ 18edd949-ce86-431a-a19b-4daf526e57a6
md"`num_iters_lift` = $(@bind num_iters_lift Slider(1:400, default=200, show_value = true))"
# ╔═╡ dc4feb58-d2cf-4a97-aaed-7f4593fc9732
md"""
### L1 norm
La fonction ``|x|`` n'est pas différentiable lorsque ``x = 0``.
Si on s'approche par la gauche (c'est à dire ``x < 0``, la fonction est ``-x``) donc la dérivée vaut ``-1``.
Si on s'approche par la droite (c'est à dire ``x > 0``, la fonction est ``x``) donc la dérivée vaut ``1``.
Il n'y a pas de gradient valide!
Par contre, n'importe quel nombre entre ``-1`` et ``1`` est un **subgradient** valide ! Alors que le gradient est la normale à la tangente **unique**, le subgradient est un élément du **cone tangent**.
"""
# ╔═╡ f5749121-8e75-45de-95b9-63fff584e350
md"`η_L1` = $(@bind η_L1 Slider(exp10.(-4:0.25:1), default=0.1, show_value = true))"
# ╔═╡ 66e36fb8-5a61-49a7-8053-911fd887b0a9
md"`num_iters_L1` = $(@bind num_iters_L1 Slider(1:400, default=200, show_value = true))"
# ╔═╡ db28bb45-3418-4080-a0fc-9136fc0196a5
md"""
## Reverse diff
Le désavantage de la forward differentiation, c'est qu'il faut recommencer tout le calcul pour calculer la dérivée par rapport à chaque variable. La *reverse differentiation*, aussi appelée *backpropagation*, résoud se problème en calculer la dérivée par rapport à toutes les variables en une fois!
### Chain rule
#### Exemple univarié
Commençons par un exemple univarié pour introduire le fait qu'il existe un choix dans l'ordre de la multiplication des dérivées. La liberté introduite par ce choix donne lieu à la différence entre la différentiation *forward* et *reverse*.
Supposions qu'on veuille dériver la fonction ``\tan(\cos(\sin(x)))`` pour ``x = \pi/3``. La Chain Rule nous donne:
```math
\begin{align}
(\tan(\cos(\sin(x))))'
& = \left. (\tan(x))' \right|_{x = \cos(\sin(x)))} (\cos(\sin(x))))'\\
& = \left. (\tan(x))' \right|_{x = \cos(\sin(x)))}
\left. (\cos(x))' \right|_{x = \sin(x))}
(\sin(x)))'\\
& = \frac{1}{\cos^2(\cos(\sin(x)))} (-\sin(\sin(x))) \cos(x)
\end{align}
```
La dérivée pour ``x = \pi/3`` est donc:
```math
\begin{align}
\left. (\tan(\cos(\sin(x))))' \right|_{x = \pi/3}
& =
\left. (\tan(x))' \right|_{x = \cos(\sin(\pi/3)))}
\left. (\cos(x))' \right|_{x = \sin(\pi/3))}
\left. (\sin(x)))' \right|_{x = \pi/3}\\
& = \frac{1}{\cos^2(\cos(\sin(\pi/3)))} (-\sin(\sin(\pi/3))) \cos(\pi/3)
\end{align}
```
Pour calculer ce produit de 3 nombres, on a 2 choix.
La première possibilité (qui correspond à forward diff) est de commencer par calculer le produit
```math
\begin{align}
\left. (\cos(\sin(x)))' \right|_{x = \pi/3}
& =
\left. (\cos(x))' \right|_{x = \sin(\pi/3))}
\left. (\sin(x)))' \right|_{x = \pi/3}\\
& =
(-\sin(\sin(\pi/3))) \cos(\pi/3)
\end{align}
```
puis de le multiplier avec ``\left. (\tan(x))' \right|_{x = \cos(\sin(\pi/3)))} = \frac{1}{\cos^2(\cos(\sin(\pi/3)))}``.
La deuxième possibilité (qui correspond à reverse diff) est de commencer par calculer le produit
```math
\begin{align}
\left. (\tan(\cos(x)))' \right|_{\textcolor{red}{x = \sin(\pi/3)}}
& =
\left. (\tan(x))' \right|_{\textcolor{red}{x = \cos(\sin(\pi/3)))}}
\left. (\cos(x))' \right|_{\textcolor{red}{x = \sin(\pi/3))}}\\
& = \frac{1}{\cos^2(\cos(\sin(\pi/3)))} (-\sin(\sin(\pi/3)))
\end{align}
```
puis de le multiplier avec ``\cos(\pi/3)``.
Vous remarquerez que dans l'équation ci-dessus, comme mis en évidence en rouge, les valeurs auxquelles les dérivées doivent être évaluées dépendent de ``\sin(\pi/3)``.
L'approche utilisée par reverse diff de multiplier de gauche à droite ne peut donc pas être effectuer sans prendre en compte la valeur qui doit être évaluée de droite à gauche.
Pour appliquer reverse diff, il faut donc commencer par une *forward pass* de droite à gauche qui calcule ``\sin(\pi/3)`` puis ``\cos(\sin(\pi/3))`` puis ``\tan(\cos(\sin(\pi/3)))``. On peut ensuite faire la *backward pass* qui multipliée les dérivée de gauche à droite. Afin d'être disponibles pour la backward pass, les valeurs calculées lors de la forward pass doivent être **stockées** ce qui implique un **coût mémoire**.
En revanche, comme forward diff calcule la dérivée dans le même sens que l'évaluation, les dérivées et évaluations peuvent être calculées en même temps afin de ne pas avoir besoin de stocker les évaluations. C'est effectivement ce qu'on a implémenter avec `Dual` précédemment.
Au vu de ce coût mémoire supplémentaire de reverse diff par rapport à forward diff,
ce dernier paraît préférable en pratique.
On va voir maintenant que dans le cas multivarié, dans certains cas, ce désavantage est contrebalancé par une meilleure complexité temporelle qui rend reverse diff indispensable!
"""
# ╔═╡ 494fc7d7-c622-41d1-91d8-3dc1fbd2f244
md"""
#### Exemple multivarié
Prenons maintenant un example multivarié, supposons qu'on veuille calculer le gradient de la fonction ``f(g(h(x_1, x_2)))`` qui compose 3 fonctions ``f``, ``g`` et ``h``.
Le gradient est obtenu via la chain rule comme suit:
```math
\begin{align}
\frac{\partial}{\partial x_1} f(g(h(x_1, x_2)))
& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1}\\
\frac{\partial}{\partial x_2} f(g(h(x_1, x_2)))
& = \frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}\\
\nabla_{x_1, x_2} f(g(h(x_1, x_2)))
& = \begin{bmatrix}
\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_1} &
\frac{\partial f}{\partial g} \frac{\partial g}{\partial h} \frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial g}{\partial h}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial h}{\partial x_1} &
\frac{\partial h}{\partial x_2}
\end{bmatrix}
\end{align}
```
On voit que c'est le produit de 3 matrices. Forward diff va exécuter ce produit de droite à gauche:
```math
\begin{align}
\nabla_{x_1, x_2} f(g(h(x_1, x_2)))
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial g}{\partial h}\frac{\partial h}{\partial x_1} &
\frac{\partial g}{\partial h}\frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial g}{\partial x_1} &
\frac{\partial g}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}\frac{\partial g}{\partial x_1} &
\frac{\partial f}{\partial g}\frac{\partial g}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial x_1} &
\frac{\partial f}{\partial x_2}
\end{bmatrix}
\end{align}
```
L'idée de reverse diff c'est d'effectuer le produit de gauche à droite:
```math
\begin{align}
\nabla_{x_1, x_2} f(g(h(x_1, x_2)))
& =
\begin{bmatrix}
\frac{\partial f}{\partial g}\frac{\partial g}{\partial h}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial h}{\partial x_1} &
\frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial h}
\end{bmatrix}
\begin{bmatrix}
\frac{\partial h}{\partial x_1} &
\frac{\partial h}{\partial x_2}
\end{bmatrix}\\
& =
\begin{bmatrix}
\frac{\partial f}{\partial x_1} &
\frac{\partial f}{\partial x_2}
\end{bmatrix}
\end{align}
```
"""
# ╔═╡ 4e1ac5fc-c684-42e1-9c99-3120021eb19a
md"""
Pour calculer ``\partial f / \partial x_1`` via forward diff, on part donc de ``\partial x_1 / \partial x_1 = 1`` et ``\partial x_2 / \partial x_1 = 0`` et on calcule ensuite ``\partial h / \partial x_1``, ``\partial g / \partial x_1`` puis ``\partial f / \partial x_1``.
Effectuer la reverse diff est un peu moins intuitif. L'idée est de partir de la dérivée du résultat par rapport à lui même ``\partial f / \partial f = 1`` et de calculer ``\partial f / \partial g`` puis ``\partial f / \partial h`` et ensuite ``\partial f / \partial x_1``. L'avantage de reverse diff c'est qu'il n'y a que la dernière étape qui est sécifique à ``x_1``. Tout jusqu'au calcul de ``\partial f / \partial h`` peut être réutilisé pour calculer ``\partial f / \partial x_2``, il n'y a plus cas multiplier ! Reverse diff est donc plus efficace pour calculer le gradient d'une fonction qui a une seul output par rapport à beaucoup de paramètres comme détaillé dans la discussion à la fin de ce notebook.
"""
# ╔═╡ 56b32132-113f-459f-b1d9-abb8f439a40b
md"""
### Forward pass : Construction de l'expression graph
Pour implémenter reverse diff, il faut construire l'expression graph pour garder en mémoire les valeurs des différentes expressions intermédiaires afin de pouvoir calculer les dérivées locales ``\partial f / \partial g`` et ``\partial g / \partial h``. Le code suivant défini un noeud de l'expression graph. Le field `derivative` correspond à la valeur de ``\partial f_{\text{final}} / \partial f_{\text{node}}`` où ``f_\text{final}`` est la dernière fonction de la composition et ``f_{\text{node}}`` est la fonction correspondant au node.
"""
# ╔═╡ 4931adf1-8771-4708-833e-d05c05884969
begin
mutable struct Node
op::Union{Nothing,Symbol}
args::Vector{Node}
value::Float64
derivative::Float64
end
Node(op, args, value) = Node(op, args, value, NaN)
Node(value) = Node(nothing, Node[], value)
end
# ╔═╡ 0b07b9cf-83b4-46e9-9a75-cf2cadbbb011
md"""
L'operateur overloading suivant sera sufficant pour construire l'expression graph dans le cadre de ce notebook, vous l'étendrez pendant la séance d'exercice.
"""
# ╔═╡ b814dc16-37de-45d1-9c7c-4eec45d3f956
begin
Base.zero(x::Node) = Node(0.0)
Base.:*(x::Node, y::Node) = Node(:*, [x, y], x.value * y.value)
Base.:+(x::Node, y::Node) = Node(:+, [x, y], x.value + y.value)
Base.:-(x::Node, y::Node) = Node(:-, [x, y], x.value - y.value)
Base.:/(x::Node, y::Number) = x * Node(inv(y))
Base.:^(x::Node, n::Integer) = Base.power_by_squaring(x, n)
Base.sin(x::Node) = Node(:sin, [x], sin(x.value))
Base.cos(x::Node) = Node(:cos, [x], cos(x.value))
end
# ╔═╡ 851e688f-2b30-44b7-9530-87990adee4b2
Base.:*(x::Dual{T}, y::Dual{T}) where {T} = Dual(x.value * y.value, x.value * y.derivative + x.derivative * y.value)
# ╔═╡ 2d469929-da96-4b07-a5a4-defa3d253c81
mse(w, X, y) = sum((X * w - y).^2 / length(y))
# ╔═╡ 15390f36-bc62-4c25-a866-5641aecc86ed
function train!(diff, loss, w0, X, y, η, num_iters)
w = copy(w0)
training_losses = [loss(w, X, y)]
for _ in 1:num_iters
∇ = diff(loss, w, X, y)
w .= w .- η .* ∇
push!(training_losses, loss(w, X, y))
end
return w, training_losses
end
# ╔═╡ a767d45f-a438-4d87-bdab-7d55ea7458ac
lift(x) = [1.0, x[1], x[2], x[1]^2, x[1] * x[2], x[2]^2, x[1]^3, x[1]^2 * x[2], x[1] * x[2]^2, x[2]^3]
# ╔═╡ 42644265-8f26-4118-9e8f-537078847af7
function Base.abs(d::Dual)
if d.value < 0
return -1.0 * d
else
return d
end
end
# ╔═╡ 607000ef-fb7f-4204-b543-3cb6bb75ed71
let
x = range(-1, stop = 1, length = 11)
p = plot(x, abs, axis = :equal, label = "|x|")
for λ in range(0, stop = 1, length = 11)
plot!([0, λ/2 - (1 - λ)/2], [0, -1/2], arrow = Plots.arrow(:closed), label = "")
end
p
end
# ╔═╡ b899a93f-9bec-48ce-b0ad-4e5157556a31
L1_loss(w, X, y) = sum(abs.(X * w - y)) / length(y)
# ╔═╡ 1572f901-c688-435e-81b9-d6e39bb82201
md"""
On crée les leafs correspondant aux variables ``x_1`` et ``x_2`` de valeur ``1`` et ``2`` respectivement. Les valeurs correspondent aux valeurs de ``x_1`` et ``x_2`` auxquelles on veut dériver la fonction. On a choisi 1 et 2 pour pouvoir les reconnaitre facilement dans le graphe.
"""
# ╔═╡ b7faa3b7-e0b6-4f55-8763-035d8fc5ac93
x_nodes = Node.([1, 2])
# ╔═╡ 194d3a68-6bed-41d9-b3ea-8cfaf4787c54
expr = cos(sin(prod(x_nodes)))
# ╔═╡ 7285c7e8-bce0-42f0-a53b-562a8a6c5894
function _nodes!(nodes, x::Node)
if !(x in keys(nodes))
nodes[x] = length(nodes) + 1
end
for arg in x.args
_nodes!(nodes, arg)
end
end
# ╔═╡ b9f0e9b6-c111-4d69-990f-02c460c8706d
function _edges(g, labels, nodes::Dict{Node}, done, x::Node)
id = nodes[x]
if done[id]
return
end
done[id] = true
labels[id] = @sprintf "%.2f" x.value
if !isnothing(x.op)
labels[id] = "[" * String(x.op) * "] " * labels[id]
end
for arg in x.args
add_edge!(g, id, nodes[arg])
_edges(g, labels, nodes, done, arg)
end
end
# ╔═╡ 114c048f-f619-4e15-8e5a-de852b0a1861
function graph(x::Node)
nodes = Dict{Node,Int}()
_nodes!(nodes, x)
done = falses(length(nodes))
g = DiGraph(length(nodes))
labels = Vector{String}(undef, length(nodes))
_edges(g, labels, nodes, done, x)
return g, labels
end
# ╔═╡ 7578be43-8dbe-4041-adc7-275f06057bfe
md"""
Pour le visualiser, on le converti en graphe utilisant la structure de donnée de Graphs.jl pour pouvoir utiliser `gplot`
"""
# ╔═╡ 69298293-c9fc-432f-9c3c-5da7ce710334
expr_graph, labels = graph(expr)
# ╔═╡ 9b9e5fc3-a8d0-4c42-91ae-25dd01bc7d7e
gplot(expr_graph, nodelabel = labels)
# ╔═╡ 54af5dab-e669-45eb-b6b5-44c46f7258b4
md"""
#### Combinaison des dérivées
Que faire si plusieurs expressions dépendent d'une même variable.
Considérons l'exemple ``f(x) = \sin(x)\cos(x)`` qui correspond à ``f(g,h) = gh``, ``g(x) = \sin(x)`` et ``h(x) = \cos(x)``.
La chain rule donne
```math
\begin{align}
f'(x) & = \frac{\partial f}{\partial g}g'(x) + \frac{\partial f}{\partial h}h'(x)
\end{align}
```
Une fois la valeur ``\partial f / \partial g`` calculée, on peut la multiplier par ``g'(x)`` pour avoir la première partie partie de ``f'(x)``. Idem pour ``h``. Ces deux contributions seront calculée séparément lors de la backward pass sur le noeud ``g`` et ``h``. On voit par la formule de la chain rule que ces deux contributions doivent être sommées.
Lors de la backward pass, on initialise donc toutes les dérivées à 0. Pour chaque contribution, on ajoute la dérivée avec `+=`. On s'assure ensuite qu'on ne procède pas à la backward pass sur un noeud avant qu'il ait fini d'accumuler les contributions de toutes les expressions qui en dépendent via un *tri topologique*.
"""
# ╔═╡ 842df050-e0be-441d-b84e-7f0575eac227
x_node = Node(1)
# ╔═╡ 36fcfd5e-c521-42fc-96cd-93787e657627
sin_cos = sin(x_node) * cos(x_node)
# ╔═╡ f1da5c61-9862-44f6-84e4-96217736e1cb
sin_cos_graph, sin_cos_labels = graph(sin_cos)
# ╔═╡ 363af09c-3cc9-440d-9b70-1da8f6a70913
gplot(sin_cos_graph, nodelabel = sin_cos_labels)
# ╔═╡ bd705ddd-0d00-41a0-aa55-e82daad4133d
md"### Backward pass : Calcul des dérivées"
# ╔═╡ d8052188-f2fa-4ad8-935f-581eea164bda
md"""
La fonction suivante propage la dérivée ``\partial f_{\text{final}} / \partial f_{\text{node}}`` à la dérivée des arguments de la fonction ``f_{\text{node}}``.
Comme les arguments peuvent être utilisés à par d'autres fonction, on somme la dérivée avec `+=`.
"""
# ╔═╡ 1e08b49d-03fe-4fb3-a8ba-3a00e1374b32
function _backward!(f::Node)
if isnothing(f.op)
return
elseif f.op == :+
for arg in f.args
arg.derivative += f.derivative
end
elseif f.op == :- && length(f.args) == 2
f.args[1].derivative += f.derivative
f.args[2].derivative -= f.derivative
elseif f.op == :* && length(f.args) == 2
f.args[1].derivative += f.derivative * f.args[2].value
f.args[2].derivative += f.derivative * f.args[1].value
elseif f.op == :sin
f.args[].derivative += f.derivative * cos(f.args[].value)
elseif f.op == :cos
f.args[].derivative -= f.derivative * sin(f.args[].value)
else
error("Operator `$(f.op)` not supported yet")
end
end
# ╔═╡ 44442c34-e088-493a-bfd6-9c095c499100
md"""
La fonction `_backward!` ne doit être appelée que sur un noeud pour lequel `f.derivative` a déjà été calculé. Pour cela, `_backward!` doit avoir été appelé sur tous les noeuds qui représente des fonctions qui dépendent directement ou indirectement du résultat du noeud.
Pour trouver l'ordre dans lequel appeler `_backward!`, on utilise donc on tri topologique (nous reviendrons sur les tris topologique dans la partie graphe).
"""
# ╔═╡ 86872f35-d62d-40e5-8770-4585d3b0c0d7
function topo_sort!(visited, topo, f::Node)
if !(f in visited)
push!(visited, f)
for arg in f.args
topo_sort!(visited, topo, arg)
end
push!(topo, f)
end
end
# ╔═╡ 26c40cf4-9585-4762-abf4-ff77342a389f
function backward!(f::Node)
topo = typeof(f)[]
topo_sort!(Set{typeof(f)}(), topo, f)
reverse!(topo)
for node in topo
node.derivative = 0
end
f.derivative = 1
for node in topo
_backward!(node)
end
return f
end
# ╔═╡ 86fc7924-2002-4ac6-8e02-d9bf5edde9bf
backward!(expr)
# ╔═╡ d72d9c99-6280-49a2-9f7a-e9628f1069eb
md"On a maintenant l'information sur les dérivées de `x_nodes`:"
# ╔═╡ 0649437a-4198-4556-97dc-1b5cfbe45eed
x_nodes
# ╔═╡ 3928e9f7-9539-4d99-ac5b-6336eff8a523
md"""
### Comparaison avec Forward Diff dans l'exemple moon
Revenons sur l'exemple utilisé pour illustrer la forward diff et essayons de calculer la même dérivée mais à présent en utiliser reverse diff.
"""
# ╔═╡ 5ce7fbad-af38-4ff6-adca-b1991f3be455
w_nodes = Node.(w)
# ╔═╡ 0dd8e1bf-f8c0-4183-a5eb-13eeb5316a7b
mse_expr = backward!(mse(Node.(ones(1)), Node.(2ones(1, 1)), Node.(3ones(1))))
# ╔═╡ 7a320b75-c104-43d1-9129-f7f53910f5bc
function reverse_diff(loss, w, X, y)
w_nodes = Node.(w)
expr = loss(w_nodes, Node.(X), Node.(y))
backward!(expr)
return [w.derivative for w in w_nodes]
end
# ╔═╡ a610dc3c-803a-4489-a84b-8bff415bc0a6
md"We execute it a second time to get rid of the compilation time:"
# ╔═╡ 0e99048d-5696-43ab-8896-301f37a20a5d
md"On remarque que reverse diff est plus lent! IL y a un certain coût mémoire lorsqu'on consruit l'expression graph. Pour cette raison, si on veut calculer plusieurs dérivées consécutives pour différentes valeurs de ``x_1`` et ``x_2``, on a intérêt à garder le graphe et à uniquement changer la valeur des variables plutôt qu'à reconstruire le graphe à chaque fois qu'on change les valeurs. Alternativement, on peut essayer de condenser le graphe en exprimant les opérations sur des large matrices ou même tenseurs, c'est l'approche utilisée par pytorch ou tensorflow."
# ╔═╡ bd012d84-a79f-4043-961e-f7825b7e0d6c
md"`num_data` = $(@bind(num_data, Slider(1:100, default = 32, show_value = true)))"
# ╔═╡ 03f6d241-712d-4a45-b926-09be326c1c7d
md"`num_features` = $(@bind(num_features, Slider(1:100, default = 32, show_value = true)))"
# ╔═╡ e9507958-cefd-4208-896e-860d3e4e9d4b
md"`num_hidden` = $(@bind(num_hidden, Slider(1:100, default = 32, show_value = true)))"
# ╔═╡ 8acebedb-3a97-4efd-bd3c-c267ecd3945c
mse2(W1, W2, X, y) = sum((X * W1 * W2 - y).^2 / length(y))
# ╔═╡ 32fdcdd9-785d-427b-94c3-3c65bf72e673
function bench(num_data, num_features, num_hidden)
X = rand(num_data, num_features)
W1 = rand(num_features, num_hidden)
W2 = rand(num_hidden)
y = rand(num_data)
@time for i in axes(W1, 1)
for j in axes(W1, 2)
mse2(
Dual.(W1, onehot(i, axes(W1, 1)) * onehot(j, axes(W1, 2))'),
W2,
X,
y,
)
end
end
expr = @time mse2(Node.(W1), Node.(W2), Node.(X), Node.(y))
@time backward!(expr)
return
end
# ╔═╡ 8ce88d4d-59b0-49bb-8c0a-3a2961e5fd4a
bench(num_data, num_features, num_hidden)
# ╔═╡ 613269ef-16ba-44ef-ad8e-997cc9aec1fb
md"""
### Comment choisir entre forward et reverse diff ?
Suppose that we need to differentiate a composition of functions:
``(f_n \circ f_{n-1} \circ \cdots \circ f_2 \circ f_1)(w)``.
For each function, we can compute a jacobian given the value of its input.
So, during a forward pass, we can compute all jacobians. We now just need to take the product of these jacobians:
```math
J_n J_{n-1} \cdots J_2 J_1
```
While the product of matrices is associative, its computational complexity depends on the order of the multiplications!
Let ``d_i \times d_{i - 1}`` be the dimension of ``J_i``.
#### Forward diff: from right to left
If the product is computed from right to left:
```math
\begin{align}
J_{1,2} & = J_2 J_1 && \Omega(d_2d_1d_0)\\
J_{1,3} & = J_3 J_{1,2} && \Omega(d_3d_2d_0)\\
J_{1,4} & = J_4 J_{1,3} && \Omega(d_4d_3d_0)\\
\vdots & \quad \vdots\\
J_{1,n} & = J_n J_{1,(n-1)} && \Omega(d_nd_{n-1}d_0)
\end{align}
```
we have a complexity of
```math
\Omega(\sum_{i=2}^n d_id_{i-1}d_0).
```
#### Reverse diff: from left to right
Reverse differentation corresponds to multiplying the adjoint from right to left or equivalently the original matrices from left to right.
This means computing the product in the following order:
```math
\begin{align}
J_{(n-1),n} & = J_n J_{n-1} && \Omega(d_nd_{n-1}d_{n-2})\\
J_{(n-2),n} & = J_{(n-1),n} J_{n-2} && \Omega(d_nd_{n-2}d_{n-3})\\
J_{(n-3),n} & = J_{(n-2),n} J_{n-3} && \Omega(d_nd_{n-3}d_{n-4})\\
\vdots & \quad \vdots\\
J_{1,n} & = J_{2,n} J_1 && \Omega(d_nd_1d_0)\\
\end{align}
```
We have a complexity of
```math
\Omega(\sum_{i=1}^{n-1} d_nd_id_{i-1}).
```
#### Mixed : from inward to outward
Suppose we multiply starting from some ``d_k`` where ``1 < k < n``.
We would then first compute the left side:
```math
\begin{align}
J_{k+1,k+2} & = J_{k+2} J_{k+1} && \Omega(d_{k+2}d_{k+1}d_{k})\\
J_{k+1,k+3} & = J_{k+3} J_{k+1,k+2} && \Omega(d_{k+3}d_{k+2}d_{k})\\
\vdots & \quad \vdots\\
J_{k+1,n} & = J_{n} J_{k+1,n-1} && \Omega(d_nd_{n-1}d_k)
\end{align}
```
then the right side:
```math
\begin{align}
J_{k-1,k} & = J_k J_{k-1} && \Omega(d_kd_{k-1}d_{k-2})\\
J_{k-2,k} & = J_{k-1,k} J_{k-2} && \Omega(d_kd_{k-2}d_{k-3})\\
\vdots & \quad \vdots\\
J_{1,k} & = J_{2,k} J_1 && \Omega(d_kd_1d_0)\\
\end{align}
```
and then combine both sides:
```math
J_{1,n} = J_{k+1,n} J_{1,k} \qquad \Omega(d_nd_kd_0)
```
we have a complexity of
```math
\Omega(d_nd_kd_0 + \sum_{i=1}^{k-1} d_kd_id_{i-1} + \sum_{i=k+2}^{n} d_id_{i-1}d_k).
```
#### Comparison
We see that we should find the minimum ``d_k`` and start from there. If the minimum is attained at ``k = n``, this corresponds mutliplying from left to right, this is reverse differentiation. If the minimum is attained at ``k = 0``, we should multiply from right to left, this is forward differentiation. Otherwise, we should start from the middle, this would mean mixing both forward and reverse diff.
What about neural networks ? In that case, ``d_0`` is equal to the number of entries in ``W_1`` added with the number of entries in ``W_2`` while ``d_n`` is ``1`` since the loss is scalar. We should therefore clearly multiply from left to right hence do reverse diff.
"""
# ╔═╡ fdd28672-5902-474a-8c87-3f6f38bcf54f
md"""
## Pour la séance d'exercices:
"""
# ╔═╡ 54697e82-ee8c-4b65-a633-b29a47fac722
md"### tanh
Passer un réseau de neurone avec fonction d'activation `tanh`
"
# ╔═╡ b5286a91-597d-4b66-86b2-1528708dfa93
W2 = rand(num_hidden)
# ╔═╡ 24b95ecb-0df7-4af2-9d7e-c76ce380c6a6
md"### ReLU
Passer un réseau de neurone avec fonction d'activation ReLU"
# ╔═╡ 21b97f3d-85c8-420b-a884-df4fed98b8d0
function relu(x)
if x < 0
return 0
else
return x
end
end
# ╔═╡ 73cbbbd0-8427-46d3-896b-7cfc732c35f7
md"""
### One-Hot encoding et cross-entropy
"""
# ╔═╡ c1da4130-5936-499f-bb9b-574e01136eca
md"### Acknowledgements and further readings
* `Dual` est inspiré de [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl)
* `Node` est inspiré de [micrograd](https://github.com/karpathy/micrograd)
* Une bonne intro à l'automatic differentiation est disponible [ici](https://gdalle.github.io/AutodiffTutorial/)
"
# ╔═╡ b16f6225-1949-4b6d-a4b0-c5c230eb4c7f
md"## Utils"
# ╔═╡ dad5cba4-9bc6-47c3-a932-f2cc496b0f40
import MLJBase, Colors, Tables
# ╔═╡ 0a6b0db0-03f7-4077-b330-5f27f7c7a9a2
X_table, y_cat = MLJBase.make_moons(100, noise=0.1)
# ╔═╡ 7535debf-13ff-4bef-92fb-cdbf7fef7515
y = 2(float.(y_cat.refs) .- 1.5)
# ╔═╡ f90fd2b4-7aec-40f5-85b7-00674582a902
Y = unique!(sort(y_cat.refs))' .== y_cat.refs
# ╔═╡ d893fb13-fadb-452c-ba93-96f630bab3cd
function plot_w(w = nothing)
col = [Colors.JULIA_LOGO_COLORS.red, Colors.JULIA_LOGO_COLORS.blue]
p = scatter(X_table.x1, X_table.x2, markerstrokewidth=0, color = col[y_cat.refs], label = "")
if isnothing(w)
return p
elseif length(w) == 2
plot!([extrema(X_table.x1)...], x1 -> -w[1] * x1 / w[2], label = "", color = Colors.JULIA_LOGO_COLORS.green, linewidth = 2)
else
x1 = range(minimum(X_table.x1), stop = maximum(X_table.x1), length = 30)
x2 = range(minimum(X_table.x2), stop = maximum(X_table.x2), length = 30)
contour!(x1, x2, (x1, x2) -> w' * lift([x1, x2]), label = "", colorbar_ticks=([1], [0.0]))
end
end
# ╔═╡ 36401e91-1121-4552-bc4e-1b1aac76d1e0
plot_w()
# ╔═╡ 054a167c-1d5c-4c8a-9977-22b4e4d5f05d
plot_w(w)
# ╔═╡ 28349f82-f2e7-46ea-82b5-306e9dbb6daa
X = Tables.matrix(X_table)
# ╔═╡ 76c147ac-73c1-421f-9c9d-681a91e02b91
y_est = X * w
# ╔═╡ 03bc2513-e542-4ac8-8edc-8fdf3793b834
errors = y_est - y
# ╔═╡ 0af91567-90ae-47d2-9d5a-9f33b623a204
mean_squared_error = sum(errors.^2) / length(errors)
# ╔═╡ 96e35052-8697-495c-8ded-5f6348b7e711
mse(w, X, y)
# ╔═╡ 2904c070-d595-4cff-a507-f360f1e774dd
mse(Dual.(w, onehot(1, 1:2)), X, y)
# ╔═╡ d0bc2ab9-9c9f-4c60-bfa7-cbcdffc20ba6
forward_diff(mse, w, X, y, 1)
# ╔═╡ be3404c5-a2b6-4df9-ae31-d5a8b6ec1c92
∇ = forward_diff(mse, w, X, y)
# ╔═╡ 06870328-9a78-4b5f-a405-7654af38035b
losses = [mse(w + η * ∇, X, y) for η in step_sizes]
# ╔═╡ 604cdb54-6060-46a1-b4cb-22fb58394720
best_idx = argmin(losses)
# ╔═╡ 1ad80bdb-84d3-418d-8bda-c3decbe08ae0
η_best = step_sizes[best_idx]
# ╔═╡ b85211fa-bcf4-4b4e-b028-b49361a92c2c
w_improved = w + η_best * ∇
# ╔═╡ 57a78422-066b-4cda-a902-b4b34af375e6
best_loss = losses[best_idx]
# ╔═╡ 92250d6f-3a95-475e-8532-c98e3bea63e2
begin
plot(step_sizes, losses, label = "")
scatter!([0.0], [mse(w, X, y)], markerstrokewidth = 0, label = "")
scatter!([η_best], [best_loss], markerstrokewidth = 0, label = "")
end
# ╔═╡ 33654f91-0b61-4b39-b0fb-ef21776f0dfc
w_trained, training_losses = train!(forward_diff, mse, w, X, y, 0.7, num_iters)
# ╔═╡ 56558e41-e3ae-42e5-8695-3d2077cbc10c
plot(0:num_iters, training_losses, label = "")
# ╔═╡ c1059a32-8d65-4957-989f-2c5b5f50eb81
plot_w(w_trained)
# ╔═╡ 403f2e6c-5dec-4206-a5c0-d4a90fafc029
X_lift = reduce(vcat, transpose.(lift.(eachrow(X))))
# ╔═╡ f56a567b-62db-43b6-8dce-db9d5ca1e348
w_lift = rand(size(X_lift, 2))
# ╔═╡ b4b31659-8dc9-4837-8a6f-8d40b41b2366
plot_w(w_lift)
# ╔═╡ d69b0cfe-ce30-420d-a891-c6e35e11cdde
w_lift_trained, training_losses_lift = train!(forward_diff, mse, w_lift, X_lift, y, η_lift, num_iters_lift)
# ╔═╡ c1e29f52-424b-4cb2-814e-b4aecda86bc6
plot(0:num_iters_lift, training_losses_lift, label = "")
# ╔═╡ 7606f18e-9833-48cd-8c21-b48be30ef6f8
plot_w(w_lift_trained)
# ╔═╡ 91e5840a-8a18-4d36-8fce-510af4c7dcf2
w_trained_L1, training_losses_L1 = train!(forward_diff, L1_loss, w_lift, X_lift, y, η_L1, num_iters_L1)
# ╔═╡ 4a807c44-f33e-4a58-99b3-261c392be891
plot(0:num_iters_L1, training_losses_L1, label = "")
# ╔═╡ b5f3c2b1-c86a-48e4-973b-ee639d784936
plot_w(w_trained_L1)
# ╔═╡ c1b73208-f917-4823-bf45-d896f4ee59e0
@time forward_diff(mse, w, X, y)
# ╔═╡ 4444622b-ccfe-4867-b659-489573099f1e
@time reverse_diff(mse, w, X, y)
# ╔═╡ c1942ee2-c5af-4b2b-986d-6ad563ef27bb
@time reverse_diff(mse, w, X, y)
# ╔═╡ 0af6bce6-bc3b-438e-a57b-0c0c6586c0c5
W1 = rand(size(X, 2), num_hidden)
# ╔═╡ c3442bea-4ba9-4711-8d70-0ddd1745dd58
y_est_tanh = tanh.(X * W1) * W2
# ╔═╡ 83180cd9-21b1-4e70-8d54-4bf03efa31be
y_est_relu = relu.(X * W1) * W2
# ╔═╡ a05e3426-f014-473c-aad1-1cc6351a8911
W = rand(size(X, 2), size(Y, 2))
# ╔═╡ ba09be29-2af6-432e-871b-0637c52c68cf
Y_1 = X * W
# ╔═╡ 7a4156bd-5316-43ca-9cdc-d7e39f3c152f
Y_2 = exp.(X * W)
# ╔═╡ 82ab5eed-eefe-4421-9a1e-7c3b0eea0523
sums = sum(Y_2, dims=2)
# ╔═╡ e16e9464-0a27-49ef-9b17-5ea850d253b4
Y_est = Y_2 ./ sums
# ╔═╡ f761b97e-f1cb-406a-b0e9-75e4ec665097
cross = Y_est .* Y
# ╔═╡ f714edaa-ba9a-46eb-89f0-5bcc4120015e
cross_entropies = -log.(sum(Y_est .* Y, dims=2))
# ╔═╡ 1b7f3578-1de5-498c-b70f-9b4aa20d2edd
-log.(getindex.(Ref(Y_est), axes(Y_est, 1), y_cat.refs))
# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
Colors = "~0.12.11"
GraphPlot = "~0.6.0"
Graphs = "~1.9.0"
MLJBase = "~1.7.0"
OneHotArrays = "~0.2.5"
Plots = "~1.40.8"
PlutoUI = "~0.7.60"
Tables = "~1.12.0"
"""
# ╔═╡ 00000000-0000-0000-0000-000000000002
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised
julia_version = "1.11.1"
manifest_format = "2.0"
project_hash = "647c126887a2ffbf11df74e3cf03bb0463ccc843"
[[deps.AbstractPlutoDingetjes]]
deps = ["Pkg"]
git-tree-sha1 = "6e1d2a35f2f90a4bc7c2ed98079b2ba09c35b83a"
uuid = "6e696c72-6542-2067-7265-42206c756150"
version = "1.3.2"