Clean code repository for the Higher-Order Tensor recurrent neural network (HOT-RNN), implemented in TensorFlow. See details in our paper Long-term Forecasting using Tensor-Train RNNs.
Install prerequisites
- TensorFlow >= r1.6
- Python >= 3.0
- Jupyter >= 4.1.1
Import modules
from trnn import TensorLSTMCell
from trnn_imply import tensor_rnn_with_feed_prev
- TensorLSTMCell(num_units, num_lags, rank_vals): creates a TensorTrainLSTM object with num_units hidden nodes and num_lags time lags; rank_vals is the list of tensor-train decomposition ranks.
- tensor_rnn_with_feed_prev: forward pass for a single TensorTrainLSTM cell; returns an output and a hidden state.
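To make the role of a rank list like rank_vals concrete, here is a minimal, framework-free sketch of a tensor-train (TT) factorization in plain Python rather than TensorFlow. It is a generic TT contraction under our own assumptions (random cores, a scalar result, one input vector per tensor mode), not the code in trnn.py. The helper names (`make_tt_cores`, `tt_contract`, etc.) are hypothetical, and treating `ranks` as the equivalent of rank_vals is an assumption. An order-P tensor W is stored as P cores, core k having shape (r_{k-1}, n_k, r_k) with boundary ranks r_0 = r_P = 1, so the internal ranks control both the parameter count and the cost of contracting W with input vectors.

```python
import itertools
import random


def make_tt_cores(mode_sizes, ranks, seed=0):
    """Random TT cores; core k has shape (r_{k-1}, n_k, r_k) with r_0 = r_P = 1.

    `ranks` holds the P-1 internal TT ranks (hypothetical stand-in for rank_vals).
    """
    if len(ranks) != len(mode_sizes) - 1:
        raise ValueError("need exactly len(mode_sizes) - 1 internal ranks")
    rng = random.Random(seed)
    r = [1] + list(ranks) + [1]
    return [
        [[[rng.uniform(-1.0, 1.0) for _ in range(r[k + 1])] for _ in range(n)]
         for _ in range(r[k])]
        for k, n in enumerate(mode_sizes)
    ]


def tt_num_params(mode_sizes, ranks):
    """Number of stored values: sum_k r_{k-1} * n_k * r_k."""
    r = [1] + list(ranks) + [1]
    return sum(r[k] * n * r[k + 1] for k, n in enumerate(mode_sizes))


def tt_entry(cores, index):
    """One entry W[i_1, ..., i_P] as a product of matrix slices G_k[:, i_k, :]."""
    carry = [1.0]
    for core, i in zip(cores, index):
        r_out = len(core[0][0])
        carry = [sum(carry[a] * core[a][i][b] for a in range(len(core)))
                 for b in range(r_out)]
    return carry[0]


def tt_contract(cores, vectors):
    """Sum over all indices of W[i_1..i_P] * x_1[i_1] * ... * x_P[i_P],
    computed core by core without ever forming the full tensor W."""
    carry = [1.0]  # left boundary rank r_0 = 1
    for core, x in zip(cores, vectors):
        r_out = len(core[0][0])
        new = [0.0] * r_out
        for a, row in enumerate(core):
            for i, fiber in enumerate(row):
                w = carry[a] * x[i]  # contract mode index i_k with vector x_k
                for b in range(r_out):
                    new[b] += w * fiber[b]
        carry = new
    return carry[0]  # right boundary rank r_P = 1


def full_contract_bruteforce(cores, vectors):
    """Reference result: enumerate every entry of W (exponential cost)."""
    total = 0.0
    for index in itertools.product(*(range(len(x)) for x in vectors)):
        term = tt_entry(cores, index)
        for x, i in zip(vectors, index):
            term *= x[i]
        total += term
    return total


sizes, ranks = [4, 4, 4], [2, 2]
cores = make_tt_cores(sizes, ranks)
xs = [[0.5, -1.0, 2.0, 0.1], [1.0, 0.0, -0.5, 0.3], [0.2, 0.4, -0.6, 1.0]]
print(tt_num_params(sizes, ranks), "TT values vs", 4 ** 3, "dense")  # 32 TT values vs 64 dense
print(abs(tt_contract(cores, xs) - full_contract_bruteforce(cores, xs)) < 1e-9)  # True
```

The left-to-right sweep only ever carries a vector of length r_k, which is why small ranks keep the contraction cheap even when the dense tensor would be large.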
Run the Jupyter notebook
jupyter notebook test_trnn.ipynb
The notebook is a simple example of using TensorTrainLSTM that covers:
- loading a set of sim sequences
- building a tensor-train Seq2Seq model
- making long-term predictions
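For readers adapting these steps to their own data, here is a rough plain-Python sketch of the preparation a long-term forecasting Seq2Seq setup typically needs: a chronological train/valid/test split, z-score normalization using training statistics, and sliding windows of (encoder input, decoder target) pairs. The split fractions, window lengths, and function names are hypothetical, and the series is univariate for brevity; this does not reproduce reader.py or the notebook's sim data.

```python
def split_and_normalize(seq, train_frac=0.8, valid_frac=0.1):
    """Chronological split, then z-score every part with *training* statistics."""
    n_train = int(len(seq) * train_frac)
    n_valid = int(len(seq) * valid_frac)
    train = seq[:n_train]
    valid = seq[n_train:n_train + n_valid]
    test = seq[n_train + n_valid:]
    mean = sum(train) / len(train)
    std = (sum((v - mean) ** 2 for v in train) / len(train)) ** 0.5 or 1.0  # avoid /0

    def norm(xs):
        return [(v - mean) / std for v in xs]

    return norm(train), norm(valid), norm(test), (mean, std)


def make_windows(seq, input_len, horizon):
    """All (encoder_input, decoder_target) pairs: input_len past steps -> next horizon steps."""
    pairs = []
    for start in range(len(seq) - input_len - horizon + 1):
        encoder_input = seq[start:start + input_len]
        decoder_target = seq[start + input_len:start + input_len + horizon]
        pairs.append((encoder_input, decoder_target))
    return pairs


series = [float(t) for t in range(20)]  # toy stand-in for one loaded sequence
train, valid, test, (mean, std) = split_and_normalize(series)
print(len(train), len(valid), len(test))  # 16 2 2
print(make_windows(list(range(10)), input_len=4, horizon=3)[0])  # ([0, 1, 2, 3], [4, 5, 6])
```

Normalizing with training statistics only keeps information from the validation and test ranges out of the model; predictions can be mapped back with `value * std + mean`.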
- reader.py: reads the data into train/valid/test datasets and normalizes it if needed
- model.py: seq2seq model for sequence prediction
- trnn.py: tensor-train LSTM cell and the corresponding tensor-train contraction
- trnn_imply.py: forward step of the tensor-train RNN, feeding previous predictions back as input
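Feeding previous predictions back as input lets a trained model roll forward over a long horizon without ground-truth inputs. The plain-Python sketch below contrasts that decoding pattern with teacher forcing. The interface `step(x, state) -> (prediction, new_state)` is a hypothetical stand-in for one TensorTrainLSTM step plus an output layer, and `toy_step` is a made-up linear map; none of these functions are the repo's tensor_rnn_with_feed_prev.

```python
def decode_feed_previous(step, state, first_input, horizon):
    """Autoregressive decoding: each prediction is fed back as the next input."""
    outputs, x = [], first_input
    for _ in range(horizon):
        y, state = step(x, state)
        outputs.append(y)
        x = y  # feed the previous prediction, not the ground truth
    return outputs, state


def decode_teacher_forcing(step, state, true_inputs):
    """Training-style decoding: the ground-truth value is fed at every step."""
    outputs = []
    for x in true_inputs:
        y, state = step(x, state)
        outputs.append(y)
    return outputs, state


def toy_step(x, h):
    """Hypothetical stand-in for one cell step: returns (prediction, new_state)."""
    return 0.5 * x + 1.0, h + 1


print(decode_feed_previous(toy_step, 0, first_input=0.0, horizon=3))  # ([1.0, 1.5, 1.75], 3)
print(decode_teacher_forcing(toy_step, 0, [0.0, 0.0, 0.0]))  # ([1.0, 1.0, 1.0], 3)
```

Because each output becomes the next input, any prediction error is also fed back, so errors can compound over a long horizon.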
If you find this repo useful, we kindly ask you to cite our work:
@article{yu2017long,
title={Long-term forecasting using tensor-train RNNs},
author={Yu, Rose and Zheng, Stephan and Anandkumar, Anima and Yue, Yisong},
journal={arXiv preprint arXiv:1711.00073},
year={2017}
}