Recurrent Neural Network with LSTM Cells, in pure Python

A vanilla implementation of a Recurrent Neural Network (RNN) with Long Short-Term Memory (LSTM) cells, without using any ML libraries.

Background

These networks are particularly good for learning long-term dependencies within data, and can be applied to a variety of problems including language modelling, translation and speech recognition.

An LSTM cell has 4 gates, based on the following formulas:
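Stated in the conventional notation (the symbols below follow the standard LSTM formulation and may not match the variable names used in the code), with input $x_t$, previous hidden state $h_{t-1}$, and sigmoid activation $\sigma$:

$$
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \quad &\text{(input gate)} \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \quad &\text{(forget gate)} \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \quad &\text{(output gate)} \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \quad &\text{(candidate cell value)}
\end{aligned}
$$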

Each gate has its own set of parameters to learn, which makes training vanilla implementations (such as this one) expensive.

These are collected into a single cell state value:
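In the standard formulation, the forget gate scales the previous cell state and the input gate scales the new candidate:

$$
c_t = f_t \odot c_{t-1} + i_t \odot g_t
$$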

This is then used to compute the hidden state, just as a normal RNN cell would produce one: LSTM cells can effectively be treated no differently from any other cell within the network.
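Concretely, in the standard notation the output gate is applied to the squashed cell state to produce the hidden state:

$$
h_t = o_t \odot \tanh(c_t)
$$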

Training and Initialisation

To initialise the network, create an instance of the class by calling the constructor with the arguments:

rnn = LSTM_RNN(lr, in_dim, h_dim, out_dim)

Where lr is the learning rate; in_dim is the dimension of the input layer; h_dim is the dimension of the hidden layer; and out_dim is the dimension of the output layer. These should match the dimensions of your training data.
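For example, a character-level model over a 50-symbol vocabulary with a 100-unit hidden layer might be created as follows (the concrete values are illustrative, and it is assumed here that inputs and outputs are drawn from the same vocabulary):

```python
# lr=0.01, in_dim=50, h_dim=100, out_dim=50 -- illustrative values only
rnn = LSTM_RNN(0.01, 50, 100, 50)
```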

The training data should be encoded as integers and given as two lists: a list of inputs and a corresponding list of targets. The RNN can then be trained by calling the function:

rnn.train(iterations, inputs, targets, seq_len)

Where iterations is the number of iterations to run, inputs and targets are the training data, and seq_len is the length of each training sequence (batch).
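A minimal end-to-end sketch for a character-level task might look like the following; the corpus and hyperparameters are illustrative, and only the constructor and train() arguments documented above are assumed:

```python
# Build an integer encoding of the character vocabulary (illustrative corpus).
text = "the quick brown fox jumps over the lazy dog " * 100
vocab = sorted(set(text))
char_to_ix = {ch: i for i, ch in enumerate(vocab)}

# Inputs are the corpus symbols; targets are the same sequence shifted
# by one step, so the network learns to predict the next character.
inputs  = [char_to_ix[ch] for ch in text[:-1]]
targets = [char_to_ix[ch] for ch in text[1:]]

rnn = LSTM_RNN(0.01, len(vocab), 100, len(vocab))
rnn.train(10000, inputs, targets, 25)  # 10000 iterations, sequences of length 25
```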

Planned Features and Improvements

  • A sampling method to view the output of the network as it is training, using a forward pass.
  • Refactor the code to use a computation-graph model.
  • Use a piecewise-linear ("hard") sigmoid approximation to improve speed.