This repository is the official implementation of Whittle Networks, introduced in "Whittle Networks: A Deep Likelihood Model for Time Series" by Zhongjie Yu, Fabrizio Ventola, and Kristian Kersting, published at ICML 2021.
The following commands clone the repository, create a Python virtual environment (requires Python 3.6), and install the required packages:
git clone https://github.com/ml-research/WhittleNetworks.git
cd WhittleNetworks
./setup.sh
Download the datasets from TU Datalib and unzip them:
wget https://tudatalib.ulb.tu-darmstadt.de/bitstream/handle/tudatalib/2887/data.zip
unzip data.zip
Activate the virtual environment, then train and evaluate WSPNs:
source ./venv_wnet/bin/activate
./run_WSPN.sh Sine
"Sine" can be replaced with "MNIST", "SP", "Stock", or "Billiards".
This trains and evaluates WSPNs with 1d, pair, and 2d Gaussian leaf nodes; see Table 1 of our paper for details.
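Whittle Networks build on the Whittle approximation, under which the Fourier coefficients of a stationary time series are approximately independent complex Gaussians, so a likelihood can be evaluated in the spectral domain. The NumPy sketch below is not code from this repository; it only illustrates the classic Whittle log-likelihood for a hypothetical white-noise spectral density, and the helper names (`periodogram`, `whittle_log_likelihood`, `white_noise_density`) are our own.

```python
import numpy as np


def periodogram(x):
    """Return Fourier frequencies w_k, periodogram I(w_k), and DFT coefficients.

    I(w_k) = |sum_t x_t exp(-i w_k t)|^2 / (2*pi*T), with w_k = 2*pi*k/T.
    Only k = 1..floor((T-1)/2) are kept: k = 0 carries the mean, the Nyquist
    term (even T) is dropped, and higher k mirror these for real-valued x.
    """
    x = np.asarray(x, dtype=float)
    T = x.shape[0]
    coeffs = np.fft.rfft(x - x.mean())  # complex DFT coefficients for k = 0..T//2
    freqs = 2 * np.pi * np.arange(coeffs.shape[0]) / T
    power = np.abs(coeffs) ** 2 / (2 * np.pi * T)
    keep = slice(1, (T - 1) // 2 + 1)
    return freqs[keep], power[keep], coeffs[keep]


def whittle_log_likelihood(x, spectral_density):
    """Whittle approximation: -sum_k [log f(w_k) + I(w_k) / f(w_k)], constants dropped."""
    freqs, power, _ = periodogram(x)
    f = spectral_density(freqs)
    return -np.sum(np.log(f) + power / f)


def white_noise_density(sigma2):
    """Hypothetical model: white noise with variance sigma2 has flat density sigma2 / (2*pi)."""
    return lambda w: np.full_like(w, sigma2 / (2 * np.pi))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(scale=2.0, size=512)  # true variance is 4
    for sigma2 in (1.0, 4.0, 16.0):
        print(sigma2, whittle_log_likelihood(x, white_noise_density(sigma2)))
```

Because each term is maximized when the model density matches the periodogram, the candidate variance closest to the true one (4.0 here) should receive the highest Whittle log-likelihood.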
Run the graph experiments with:
python script_graph.py --data_type=Sine --graph_type=bn
"Sine" can be replaced with "SP", "Stock", or "VAR".
--graph_type can be either "bn" (directed graph) or "mn" (undirected graph).
The Bayesian information criterion (BIC) is enabled with the --BIC flag.
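For reference, the standard BIC for a model with $k$ free parameters, $n$ observations, and maximized likelihood $\hat{L}$ is given below; lower values are preferred. How script_graph.py applies it (for example, to penalize additional edges during graph learning) is an assumption here, not something stated above.

```latex
\mathrm{BIC} = k \ln n - 2 \ln \hat{L}
```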
Pre-trained WSPN models are provided in results/.
python script_wcspn.py
python train_WhittleAE.py
python test_WhittleAE.py
If you find this code useful in your research, please consider citing:
@inproceedings{yu2021wspn,
title = {Whittle Networks: A Deep Likelihood Model for Time Series},
author = {Yu, Zhongjie and Ventola, Fabrizio and Kersting, Kristian},
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
pages = {12177--12186},
year = {2021}
}
- This work is supported by the Federal Ministry of Education and Research (BMBF; project "MADESI", FKZ 01IS18043B, and Competence Center for AI and Labour; "kompAKI", FKZ 02L19C150), the German Science Foundation (DFG, German Research Foundation; GRK 1994/1 "AIPHES"), the Hessian Ministry of Higher Education, Research, Science and the Arts (HMWK; projects "The Third Wave of AI" and "The Adaptive Mind"), and the Hessian research priority programme LOEWE within the project "WhiteBox".