Commit
Replace "Learning PyTorch with Examples" with fitting sine function w…
…ith a third order polynomial (#1265) * Replace tutorial with fitting sine function with third order polynomial * more * Save * save * save * fix * fix * fix * fix * no tensor.data * fix * P3 * save * save * save * save * fix * fix * fix
Showing 13 changed files with 416 additions and 465 deletions.
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
"""
PyTorch: Tensors and autograd
-------------------------------

A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True`` then ``x.grad`` is another Tensor
holding the gradient of some scalar value with respect to ``x``.
"""
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss using operations on Tensors.
    # Now loss is a zero-dimensional Tensor; loss.item() gets the scalar value
    # held in the loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
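The manual weight update and gradient zeroing at the end of this loop can also be delegated to an optimizer. Below is a minimal sketch, not part of this commit, of the same fit using torch.optim.SGD; apart from torch.optim.SGD itself, the names and hyperparameters simply mirror the file above.

# Sketch only: the same third order polynomial fit, with an optimizer doing
# the update and gradient-zeroing steps that the file above performs by hand.
import math
import torch

x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Same four scalar weights, still leaf Tensors with requires_grad=True.
params = [torch.randn((), requires_grad=True) for _ in range(4)]
a, b, c, d = params

optimizer = torch.optim.SGD(params, lr=1e-6)
for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = (y_pred - y).pow(2).sum()

    optimizer.zero_grad()  # replaces setting each .grad to None by hand
    loss.backward()
    optimizer.step()       # replaces the manual `w -= learning_rate * w.grad` updates

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')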
beginner_source/examples_autograd/polynomial_custom_function.py: 104 changes (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""
PyTorch: Defining New autograd Functions
----------------------------------------

A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance. Instead of writing the
polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as
:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is
the `Legendre polynomial`_ of degree three.

.. _Legendre polynomial:
    https://en.wikipedia.org/wiki/Legendre_polynomials

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to compute
:math:`P_3(x)`; its derivative, needed for the backward pass, is
:math:`P_3'(x)=\frac{3}{2}\left(5x^2-1\right)`.
"""
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for the backward computation. You can cache tensors
        for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x); these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
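A hand-written backward like the one in LegendrePolynomial3 is easy to get subtly wrong. One way to sanity-check it, not included in this commit, is torch.autograd.gradcheck, which compares the analytical gradient against a finite-difference estimate; the sketch below assumes the LegendrePolynomial3 class defined above is in scope.

# Sketch only: numerically verify LegendrePolynomial3.backward.
# gradcheck expects double-precision inputs with requires_grad=True.
test_input = torch.randn(20, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(LegendrePolynomial3.apply, (test_input,),
                              eps=1e-6, atol=1e-4)
print('gradcheck passed:', ok)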
This file was deleted.
beginner_source/examples_autograd/two_layer_net_autograd.py: 81 changes (0 additions, 81 deletions)
This file was deleted.