diff --git a/Ch12_Optimization_Algorithms/RMSProp.ipynb b/Ch12_Optimization_Algorithms/RMSProp.ipynb
new file mode 100644
index 00000000..db613f02
--- /dev/null
+++ b/Ch12_Optimization_Algorithms/RMSProp.ipynb
@@ -0,0 +1,1712 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# RMSProp\n",
+    "\n",
+    "In the experiment in the Adagrad section, the learning rate of each element in the independent variable\n",
+    "of the objective function declines (or remains unchanged) during iteration, because the state variable $s_t$ in\n",
+    "the denominator keeps accumulating the element-wise square of the mini-batch stochastic gradient that\n",
+    "adjusts the learning rate. Therefore, when the learning rate declines very fast during early iterations while\n",
+    "the current solution is still not desirable, Adagrad may have difficulty finding a useful solution, because\n",
+    "the learning rate becomes too small in later stages of iteration. To tackle this problem, the RMSProp\n",
+    "algorithm makes a small modification to Adagrad [1].\n",
+    "\n",
+    "## 8.6.1 The Algorithm\n",
+    "\n",
+    "We introduced EWMA (exponentially weighted moving average) in the Momentum section. Unlike Adagrad,\n",
+    "where the state variable $s_t$ is the sum of the element-wise squares of all mini-batch stochastic gradients\n",
+    "$g_t$ up to time step $t$, RMSProp uses an EWMA of these element-wise squared gradients.\n",
+    "Specifically, given the hyperparameter $0 \leq \gamma < 1$, RMSProp computes, at time step $t > 0$,\n",
+    "\n",
+    "$$ \begin{aligned} \mathbf{s}_t \leftarrow \gamma \mathbf{s}_{t-1} + (1 - \gamma) \mathbf{g}_t \odot \mathbf{g}_t \end{aligned} $$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Like Adagrad, RMSProp re-adjusts the learning rate of each element in the independent variable of the\n",
+    "objective function with element-wise operations and then updates the independent variable:\n",
+    "\n",
+    "$$ \begin{aligned} \mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \frac{\eta}{\sqrt{\mathbf{s}_t + \epsilon}} \odot \mathbf{g}_t \end{aligned} $$\n",
+    "\n",
+    "Here, $\eta$ is the learning rate and $\epsilon$ is a small constant, such as $10^{-6}$, added to maintain numerical stability.\n",
+    "Because the state variable of RMSProp is an EWMA of the squared gradient $g_t \odot g_t$, it can be seen as a\n",
+    "weighted average of the mini-batch stochastic gradient's squared terms from roughly the last $1/(1 - \gamma)$ time steps:\n",
+    "unrolling the recursion with $\mathbf{s}_0 = \mathbf{0}$ gives $\mathbf{s}_t = (1 - \gamma) \sum_{i=0}^{t-1} \gamma^{i} \, \mathbf{g}_{t-i} \odot \mathbf{g}_{t-i}$,\n",
+    "and since $\gamma^{1/(1-\gamma)} \approx 1/e$, terms older than about $1/(1 - \gamma)$ steps contribute little.\n",
+    "Therefore, the learning rate of each element in the independent variable no longer necessarily declines\n",
+    "(or remains unchanged) during iteration.\n",
+    "\n",
+    "As before, we use the objective function $f(\mathbf{x}) = 0.1x_1^2 + 2x_2^2$ to observe the iterative trajectory\n",
+    "of the independent variable under RMSProp. Recall that in the Adagrad section, when we used Adagrad with\n",
+    "a learning rate of 0.4, the independent variable moved less in the later stages of iteration. However, at the\n",
+    "same learning rate, RMSProp can approach the optimal solution faster.\n",
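+    "\n",
+    "As a quick sanity check on the formulas, a single RMSProp update on a scalar parameter might look like the following minimal sketch (the starting values are chosen arbitrarily, just to plug concrete numbers into the two update rules above):\n",
+    "\n",
+    "```python\n",
+    "eta, gamma, eps = 0.4, 0.9, 1e-6\n",
+    "x, s = 1.0, 0.0                       # parameter and its EWMA state\n",
+    "g = 0.2 * x                           # gradient of 0.1 * x ** 2 at x\n",
+    "s = gamma * s + (1 - gamma) * g ** 2  # EWMA of the squared gradient\n",
+    "x -= eta / (s + eps) ** 0.5 * g       # rescaled gradient step\n",
+    "print(x, s)\n",
+    "```"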
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch 20, x1 -0.010599, x2 0.000000\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(figure: iterative trajectory of the independent variable under RMSProp; SVG output omitted)"
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.insert(0, '..')\n", + "\n", + "import d2l\n", + "import math\n", + "import torch\n", + "\n", + "def rmsprop_2d(x1, x2, s1, s2):\n", + " g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6\n", + " s1 = gamma * s1 + (1 - gamma) * g1 ** 2\n", + " s2 = gamma * s2 + (1 - gamma) * g2 ** 2\n", + " x1 -= eta / math.sqrt(s1 + eps) * g1\n", + " x2 -= eta / math.sqrt(s2 + eps) * g2\n", + " return x1, x2, s1, s2\n", + "\n", + "def f_2d(x1, x2):\n", + " return 0.1 * x1 ** 2 + 2 * x2 ** 2\n", + "eta, gamma = 0.4, 0.9\n", + "d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.6.2 Implementation from Scratch\n", + "\n", + "Next, we implement RMSProp with the formula in the algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def init_rmsprop_states():\n", + " s_w = torch.zeros((features.shape[1], 1))\n", + " s_b = torch.zeros(1)\n", + " return (s_w, s_b)\n", + "\n", + "def rmsprop(params, states, hyperparams):\n", + " gamma, eps = hyperparams['gamma'], 1e-6\n", + " for p, s in zip(params, states):\n", + " s[:] = gamma * s + (1 - gamma) * p.grad**2\n", + " p[:] -= hyperparams['lr'] * p.grad / (s + eps).sqrt()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We set the initial learning rate to 0.01 and the hyperparameter $ \\gamma $ to 0.9. Now, the variable $s_t$ can be treated\n", + "as the weighted average of the square term $g_t$ ⊙ $g_t$ from the last 1/(1 − 0.9) = 10 time steps." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.6.3 Concise Implementation\n", + "\n", + "From the *Trainer* instance of the algorithm named rmsprop, we can implement the **RMSProp** algorithm\n", + "with Gluon to train models. Note that the hyperparameter $ \\gamma $ is assigned by *gamma1*." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "loss: 0.242, 0.012 sec/epoch\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(figure: training loss versus epoch; SVG output omitted)"
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data_iter, feature_dim = d2l.get_data_ch10(batch_size=10)\n", + "\n", + "d2l.train_ch10(torch.optim.RMSprop, {'lr': 0.01, 'gamma': 0.9}, data_iter, feature_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "* The difference between RMSProp and Adagrad is that RMSProp uses an EWMA on the squares of elements in the mini-batch stochastic gradient to adjust the learning rate.\n", + "\n", + "## Exercises\n", + "\n", + "* What happens to the experimental results if we set the value of $γ$ to 1? Why?\n", + "\n", + "* Try using other combinations of initial learning rates and γ hyperparameters and observe and ana-lyze the experimental results.\n", + "\n", + "## Reference\n", + "\n", + "[1] Tieleman, T., & Hinton, G. (2012). Lecture 6.5-rmsprop: Divide the gradient by a running average of\n", + "its recent magnitude. COURSERA: Neural networks for machine learning, 4(2), 26-31." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/d2l/train.py b/d2l/train.py index a347323a..8566d4fd 100644 --- a/d2l/train.py +++ b/d2l/train.py @@ -323,7 +323,11 @@ def train_ch10(trainer, hyperparams, data_iter, feature_dim, num_epochs=2): w = Variable(torch.from_numpy(w1), requires_grad=True) b = Variable(torch.from_numpy(b1), requires_grad=True) - optimizer = trainer([w, b], lr=hyperparams['lr'], momentum=hyperparams['momentum']) + if trainer.__name__ == 'SGD': + optimizer = trainer([w, b], lr=hyperparams['lr'], momentum=hyperparams['momentum']) + elif trainer.__name__ == 'RMSprop': + optimizer = trainer([w, b], lr=hyperparams['lr'], alpha=hyperparams['gamma']) + net, loss = lambda X: linreg(X, w, b), squared_loss # Train animator = Animator(xlabel='epoch', ylabel='loss', @@ -345,4 +349,4 @@ def train_ch10(trainer, hyperparams, data_iter, feature_dim, num_epochs=2): evaluate_loss(net, data_iter, loss)) timer.start() print('loss: %.3f, %.3f sec/epoch'%(animator.Y[0][-1], timer.avg())) - return timer.cumsum(), animator.Y[0] \ No newline at end of file + # return timer.cumsum(), animator.Y[0] \ No newline at end of file