
Recurrent Neural Network with LSTM Cells, in pure Python

A vanilla implementation of a Recurrent Neural Network (RNN) with Long Short-Term Memory (LSTM) cells, written without any ML libraries.

Background

These networks are particularly good for learning long-term dependencies within data, and can be applied to a variety of problems including language modelling, translation and speech recognition.

An LSTM cell has 4 gates, based on the following formulas:
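The formulas appear as images in the repository; a common formulation (assumed here) gives the forget, input, and output gates plus the candidate cell value at time step t as:

```latex
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c)
\end{aligned}
```

where \sigma is the logistic sigmoid, x_t is the input at step t, and h_{t-1} is the previous hidden state.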

Each gate has its own set of parameters to learn, which makes training vanilla implementations (such as this one) expensive.

These are collected into a single cell state value:
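With f_t and i_t denoting the forget and input gates and \tilde{c}_t the candidate cell value, a common formulation (assumed here, since the repository shows the formula as an image) combines them as:

```latex
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
```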

The cell state is then used to produce the hidden state, just as in a standard RNN cell: from the outside, an LSTM cell can be treated no differently from any other recurrent cell in the network.
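To make the data flow concrete, here is a minimal pure-Python sketch of a single LSTM forward step for scalar states, following the standard formulation h_t = o_t · tanh(c_t). The function name, weight names, and dict layout are illustrative, not the repository's actual API:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One LSTM forward step for scalar input and state (illustrative only).

    W, U, b are dicts keyed by gate name: 'f', 'i', 'o', 'c'.
    """
    f = sigmoid(W['f'] * x + U['f'] * h_prev + b['f'])        # forget gate
    i = sigmoid(W['i'] * x + U['i'] * h_prev + b['i'])        # input gate
    o = sigmoid(W['o'] * x + U['o'] * h_prev + b['o'])        # output gate
    c_hat = math.tanh(W['c'] * x + U['c'] * h_prev + b['c'])  # candidate value
    c = f * c_prev + i * c_hat   # new cell state
    h = o * math.tanh(c)         # new hidden state
    return h, c
```

For example, with all parameters at zero every gate outputs 0.5 and the candidate is 0, so the cell state simply halves each step and the hidden state is 0.5 · tanh(c).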

Training and Initialisation

To initialise the network, create an instance of the class by calling the constructor with the arguments:

rnn = LSTM_RNN(lr, in_dim, h_dim, out_dim)

where lr is the learning rate, in_dim is the dimension of the input layer, h_dim is the dimension of the hidden layer, and out_dim is the dimension of the output layer. These should correspond to your training data.

The training data should be encoded as integers, and given as two lists: a list of inputs and a corresponding one of targets. The RNN can then be trained by calling the function:

rnn.train(iterations, inputs, targets, seq_len)

where iterations is the number of training iterations to run, inputs and targets are the training data, and seq_len is the length of each batch of data.
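As an example, a character-level next-character task can be encoded as integers like this. The encoding scheme and hyperparameter values are assumptions, not part of the repository; the commented-out calls mirror the constructor and train() signatures described above:

```python
# Character-level encoding sketch (assumed scheme; any integer encoding works).
text = "hello world"
vocab = sorted(set(text))
char_to_ix = {ch: ix for ix, ch in enumerate(vocab)}

# Inputs are the characters; targets are the same sequence shifted by one,
# so the network learns to predict the next character.
inputs  = [char_to_ix[ch] for ch in text[:-1]]
targets = [char_to_ix[ch] for ch in text[1:]]

# Hypothetical usage (hyperparameters chosen arbitrarily):
# rnn = LSTM_RNN(0.1, len(vocab), 100, len(vocab))
# rnn.train(1000, inputs, targets, len(inputs))
```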

Planned Features and Improvements

  • A sampling method to view the output of the network as it is training, using a forward pass.
  • Refactor the code to use a graph of computation model.
  • Use a piecewise-linear ("hard") sigmoid approximation to improve speed.
