Commit

Replace "Learning PyTorch with Examples" with fitting sine function w…
Browse files Browse the repository at this point in the history
…ith a third order polynomial (#1265)

* Replace tutorial with fitting sine function with third order polynomial

* more

* Save

* save

* save

* fix

* fix

* fix

* fix

* no tensor.data

* fix

* P3

* save

* save

* save

* save

* fix

* fix

* fix
zasdfgbnm authored Dec 3, 2020
1 parent 133e5b6 commit ce58d59
Showing 13 changed files with 416 additions and 465 deletions.
72 changes: 72 additions & 0 deletions beginner_source/examples_autograd/polynomial_autograd.py
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
"""
PyTorch: Tensors and autograd
-------------------------------
A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance.
This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.
A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True`` then ``x.grad`` is another Tensor
holding the gradient of ``x`` with respect to some scalar value.
"""
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss using operations on Tensors.
    # Now loss is a scalar Tensor of shape ().
    # loss.item() gets the scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
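As a quick sanity check, a few extra lines like the sketch below could be appended to the script; it assumes the training loop above has just run and reuses x, y and the trained a, b, c, d:

with torch.no_grad():
    # Evaluate the learned cubic on the training grid and report the worst-case error.
    y_fit = a + b * x + c * x ** 2 + d * x ** 3
    print(f'Max absolute error vs. sin(x): {(y_fit - y).abs().max().item():.4f}')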
104 changes: 104 additions & 0 deletions beginner_source/examples_autograd/polynomial_custom_function.py
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""
PyTorch: Defining New autograd Functions
----------------------------------------
A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance. Instead of writing the
polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as
:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\frac{1}{2}\left(5x^3-3x\right)` is
the `Legendre polynomial`_ of degree three.
.. _Legendre polynomial:
https://en.wikipedia.org/wiki/Legendre_polynomials
This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.
In this implementation we write our own custom autograd function to compute
:math:`P_3`; its backward pass uses the derivative
:math:`P_3'(x)=\frac{3}{2}\left(5x^2-1\right)`.
"""
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
"""
We can implement our own custom autograd Functions by subclassing
torch.autograd.Function and implementing the forward and backward passes
which operate on Tensors.
"""

@staticmethod
def forward(ctx, input):
"""
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method.
"""
ctx.save_for_backward(input)
return 0.5 * (5 * input ** 3 - 3 * input)

@staticmethod
def backward(ctx, grad_output):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
input, = ctx.saved_tensors
return grad_output * 1.5 * (5 * input ** 2 - 1)
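
# The backward formula above is just the derivative of the forward expression:
#   d/dx [0.5 * (5x^3 - 3x)] = 0.5 * (15x^2 - 3) = 1.5 * (5x^2 - 1) = P_3'(x).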


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
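To gain confidence in a hand-written backward like the one in LegendrePolynomial3, torch.autograd.gradcheck can compare it against numerically computed gradients. A minimal sketch, assuming it runs in the same script (gradcheck expects double-precision inputs with requires_grad=True):

from torch.autograd import gradcheck

# Check the analytical backward against finite differences on a small random input.
test_input = torch.randn(20, dtype=torch.double, requires_grad=True)
print(gradcheck(LegendrePolynomial3.apply, (test_input,), eps=1e-6, atol=1e-4))  # True if gradients match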
79 changes: 0 additions & 79 deletions beginner_source/examples_autograd/tf_two_layer_net.py

This file was deleted.

81 changes: 0 additions & 81 deletions beginner_source/examples_autograd/two_layer_net_autograd.py

This file was deleted.

