Skip to content

Commit

Permalink
Merge pull request #9 from liyaguang/v2
Browse files Browse the repository at this point in the history
Upgrading to DCRNN v2.
  • Loading branch information
liyaguang authored Oct 1, 2018
2 parents 80e156c + 9520e6c commit 2e4b8c8
Show file tree
Hide file tree
Showing 40 changed files with 1,066 additions and 1,400 deletions.
33 changes: 18 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent


## Requirements
- hyperopt>=0.1
- scipy>=0.19.0
- numpy>=1.12.1
- pandas>=0.19.2
Expand All @@ -22,34 +21,38 @@ Dependency can be installed using the following command:
pip install -r requirements.txt
```


## Traffic Data
## Data Preparation
The traffic data file for Los Angeles, i.e., `df_highway_2012_4mon_sample.h5`, is available [here](https://drive.google.com/open?id=1tjf5aXCgUoimvADyxKqb-YUlxP8O46pb), and should be
put into the `data/` folder.
put into the `data/METR-LA` folder.
Besides, the locations of sensors are available at [data/sensor_graph/graph_sensor_locations.csv](https://github.com/liyaguang/DCRNN/blob/master/data/sensor_graph/graph_sensor_locations.csv).
```bash
python -m scripts.generate_training_data --output_dir=data/METR-LA
```
The generated train/val/test dataset will be saved at `data/METR-LA/{train,val,test}.npz`.

## Graph Construction
As the currently implementation is based on pre-calculated road network distances between sensors, it currently only
supports sensor ids in Los Angeles (see `data/sensor_graph/sensor_info_201206.csv`).

## Run the Pre-trained Model

```bash
python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\
--output_pkl_filename=data/sensor_graph/adj_mx.pkl
python run_demo.py
```
The generated prediction of DCRNN is in `data/results/dcrnn_predictions_[1-12].h5`.

## Train the Model

## Model Training
```bash
python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml
```
Each epoch takes about 5min with a single GTX 1080 Ti. There is a chance that train/val loss will explode, gradient explosion,


## Run the Pre-trained Model
## Graph Construction
As the currently implementation is based on pre-calculated road network distances between sensors, it currently only
supports sensor ids in Los Angeles (see `data/sensor_graph/sensor_info_201206.csv`).

```bash
python run_demo.py
python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\
--output_pkl_filename=data/sensor_graph/adj_mx.pkl
```
The generated prediction of DCRNN is in `data/results/dcrnn_predictions_[1-12].h5`.


More details are being added ...

Expand Down

This file was deleted.

Binary file not shown.
Binary file not shown.
68 changes: 36 additions & 32 deletions data/model/dcrnn_config.yaml
Original file line number Diff line number Diff line change
@@ -1,34 +1,38 @@
---
base_dir: data/model
batch_size: 64
cl_decay_steps: 2000
data_type: ALL
dropout: 0
epoch: 0
epochs: 100
filter_type: dual_random_walk
global_step: 0
graph_pkl_filename: data/sensor_graph/adj_mx.pkl
horizon: 12
l1_decay: 0
learning_rate: 0.01
loss_func: MAE
lr_decay: 0.1
lr_decay_epoch: 20
lr_decay_interval: 10
max_diffusion_step: 2
max_grad_norm: 5
min_learning_rate: 2.0e-06
null_val: 0
num_rnn_layers: 2
output_dim: 1
patience: 50
rnn_units: 64
seq_len: 12
test_every_n_epochs: 10
test_ratio: 0.2
use_cpu_only: false
use_curriculum_learning: true
validation_ratio: 0.1
verbose: 0
write_db: false
data:
batch_size: 64
dataset_dir: data/METR-LA
test_batch_size: 64
val_batch_size: 64
graph_pkl_filename: data/sensor_graph/adj_mx.pkl

model:
cl_decay_steps: 2000
filter_type: dual_random_walk
horizon: 12
input_dim: 2
l1_decay: 0
max_diffusion_step: 2
num_nodes: 207
num_rnn_layers: 2
output_dim: 1
rnn_units: 64
seq_len: 12
use_curriculum_learning: true

train:
base_lr: 0.01
dropout: 0
epoch: 0
epochs: 100
epsilon: 1.0e-3
global_step: 0
lr_decay_ratio: 0.1
max_grad_norm: 5
max_to_keep: 100
min_learning_rate: 2.0e-06
optimizer: adam
patience: 50
steps: [20, 30, 40, 50]
test_every_n_epochs: 10
36 changes: 36 additions & 0 deletions data/model/dcrnn_config_u16_lap.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
base_dir: data/model
data:
batch_size: 64
dataset_dir: data/METR-LA
test_batch_size: 64
val_batch_size: 64
graph_pkl_filename: data/sensor_graph/adj_mx.pkl

model:
cl_decay_steps: 2000
filter_type: laplacian
horizon: 12
input_dim: 2
l1_decay: 0
max_diffusion_step: 2
max_grad_norm: 5
num_nodes: 207
num_rnn_layers: 2
output_dim: 1
rnn_units: 16
seq_len: 12
use_curriculum_learning: true

train:
base_lr: 0.01
dropout: 0
epoch: 0
epochs: 100
global_step: 0
lr_decay_ratio: 0.1
steps: [20, 30, 40, 50]
max_to_keep: 100
min_learning_rate: 2.0e-06
patience: 50
test_every_n_epochs: 10
67 changes: 34 additions & 33 deletions data/model/dcrnn_test_config.yaml
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
---
base_dir: data/model
batch_size: 64
cl_decay_steps: 2000
data_type: ALL
dropout: 0
epoch: 0
epochs: 100
filter_type: random_walk
global_step: 0
graph_pkl_filename: data/sensor_graph/adj_mx.pkl
horizon: 3
l1_decay: 0
learning_rate: 0.01
loss_func: MAE
lr_decay: 0.1
lr_decay_epoch: 20
lr_decay_interval: 10
max_diffusion_step: 2
max_grad_norm: 5
method_type: GCRNN
min_learning_rate: 2.0e-06
null_val: 0
num_rnn_layers: 2
output_dim: 1
patience: 50
rnn_units: 16
seq_len: 3
test_every_n_epochs: 10
test_ratio: 0.2
use_cpu_only: false
use_curriculum_learning: true
validation_ratio: 0.1
verbose: 0
write_db: false
data:
batch_size: 64
dataset_dir: data/METR-LA
test_batch_size: 64
val_batch_size: 64
graph_pkl_filename: data/sensor_graph/adj_mx.pkl

model:
cl_decay_steps: 2000
filter_type: dual_random_walk
horizon: 12
input_dim: 2
l1_decay: 0
max_diffusion_step: 2
max_grad_norm: 5
num_nodes: 207
num_rnn_layers: 2
output_dim: 1
rnn_units: 64
seq_len: 12
use_curriculum_learning: true

train:
base_lr: 0.01
dropout: 0
epoch: 0
epochs: 100
global_step: 0
lr_decay_ratio: 0.1
steps: [20, 30, 40, 50]
max_to_keep: 100
min_learning_rate: 2.0e-06
patience: 50
test_every_n_epochs: 10
40 changes: 40 additions & 0 deletions data/model/pretrained/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
base_dir: data/model
data:
batch_size: 64
dataset_dir: data/METR-LA
graph_pkl_filename: data/sensor_graph/adj_mx.pkl
test_batch_size: 64
model:
cl_decay_steps: 2000
filter_type: dual_random_walk
horizon: 12
input_dim: 2
l1_decay: 0
max_diffusion_step: 2
num_nodes: 207
num_rnn_layers: 2
output_dim: 1
rnn_units: 64
seq_len: 12
use_curriculum_learning: true
train:
base_lr: 0.01
dropout: 0
epoch: 64
epochs: 100
epsilon: 0.001
global_step: 24375
log_dir: data/model/pretrained/
lr_decay_ratio: 0.1
max_grad_norm: 5
max_to_keep: 100
min_learning_rate: 2.0e-06
model_filename: data/model/pretrained/models-2.7422-24375
optimizer: adam
patience: 50
steps:
- 20
- 30
- 40
- 50
test_every_n_epochs: 10
Binary file not shown.
Binary file added data/model/pretrained/models-2.7422-24375.index
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_1.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_10.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_11.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_12.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_2.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_3.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_4.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_5.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_6.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_7.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_8.h5
Binary file not shown.
Binary file removed data/results/dcrnn_prediction_9.h5
Binary file not shown.
Loading

0 comments on commit 2e4b8c8

Please sign in to comment.