A PyTorch implementation of the paper: UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction.
Yuan Yuan, Jingtao Ding, Jie Feng, Depeng Jin, Yong Li
The repo currently includes code implementations for the following tasks:
- Short-term Prediction: We provide all scripts to reproduce the short-term prediction results.
- Long-term Prediction: We provide all scripts to reproduce the long-term prediction results.
- Few-shot Prediction: UniST generalizes well to scenarios with limited training data, making it data-efficient.
- Zero-shot Prediction: UniST is shown to generalize well to unseen spatio-temporal scenarios, making it a promising backbone for spatio-temporal foundation models.
📢: News (2024.06) Introductions to our work are available on 量子位 (QbitAI), 时空探索之旅, and 时序人.
📢: News (2024.05) UniST has been accepted to KDD 2024.
🏆 By capturing the underlying commonalities across multiple spatio-temporal scenarios, UniST breaks with the conventional practice of training separate models for different datasets, and it demonstrates superior performance and strong generalization across diverse urban scenarios.
🌟 The training of UniST consists of two stages: (i) large-scale spatio-temporal pre-training, and (ii) spatio-temporal knowledge-guided prompt tuning.
The pseudo-code of UniST is as simple as the following:
Model | Data Format | Data Scalability | Few-shot | Zero-shot | Computation Cost | Memory Cost |
---|---|---|---|---|---|---|
PromptST [1] | Grid | ✗ | ✗ | ✗ | Low | Low |
GPT-ST [2] | Graph | ✗ | ✗ | ✗ | Low | Low |
STEP [3] | Graph | ✗ | ✗ | ✗ | Low | Low |
ST-SSL [4] | Graph | ✗ | ✗ | ✗ | Low | Low |
TrafficBERT [5] | Grid/Graph | ✓ | ✗ | ✗ | Low | Low |
TFM [6] | Graph | ✗ | ✗ | ✗ | Low | Low |
UrbanGPT [7] | Grid | ✓(a) | ✓(a) | ✓(a) | High | High |
STG-LLM [8] | Graph | ✗ | ✗ | ✗ | High | High |
UniST | Grid/Graph | ✓ | ✓ | ✓ | Low | Low |
(a) Still restricted to the same city.
[1] PromptST: Prompt-Enhanced Spatio-Temporal Multi-Attribute Prediction, CIKM 2023
[2] GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks, NeurIPS 2023
[3] Pre-training enhanced spatial-temporal graph neural network for multivariate time series forecasting, KDD 2022
[4] Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction, AAAI 2023
[5] TrafficBERT: Pre-trained model with large-scale data for long-range traffic flow forecasting, Expert Systems with Applications
[6] Building transportation foundation model via generative graph transformer, ITSC 2023
[7] UrbanGPT: Spatio-Temporal Large Language Models, KDD 2024
[8] How can large language models understand spatial-temporal data?, arXiv 2024
We evaluate UniST on multiple datasets that span various cities and domains. To access the datasets, please refer to the data README.
- Tested OS: Linux
- Python >= 3.9
- torch == 2.0.0
- TensorBoard
- Install PyTorch with the correct CUDA version.
- Use the `pip install -r requirements.txt` command to install all of the Python modules and packages used in this project.
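Before running the scripts, it can help to confirm that the environment matches the versions listed above. The snippet below is a minimal sketch that uses only the Python standard library; `check_environment` is a hypothetical helper, not part of this repo, and its checks simply mirror the list above (Python >= 3.9, torch == 2.0.0, TensorBoard installed).

```python
import sys
from importlib import metadata


def check_environment() -> list:
    """Return a list of problems; an empty list means all checks passed.

    Hypothetical helper (not part of the UniST repo) that mirrors the
    requirements listed in this README.
    """
    problems = []
    # Python >= 3.9 is required.
    if sys.version_info < (3, 9):
        problems.append(f"Python >= 3.9 required, found {sys.version.split()[0]}")
    # torch == 2.0.0; CUDA builds carry a local suffix such as '+cu118'.
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("torch is not installed")
    else:
        if torch_version.split("+")[0] != "2.0.0":
            problems.append(f"torch == 2.0.0 expected, found {torch_version}")
    # TensorBoard is used to record training.
    try:
        metadata.version("tensorboard")
    except metadata.PackageNotFoundError:
        problems.append("tensorboard is not installed")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment looks OK")
```

This only inspects installed package versions; on a GPU machine, `torch.cuda.is_available()` returning `True` is a quick extra check that the CUDA build matches your driver.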
First, navigate to the `src` directory: `cd src`. Then create a folder named `experiments` to record the training process: `mkdir experiments`.
We provide the pre-training script at `./scripts/pretrain.sh`. For example, you can pre-train UniST on the Crowd dataset as follows:
python main.py --device_id 3 --machine machine --dataset Crowd --task short --size middle --mask_strategy_random 'batch' --lr 3e-4 --used_data 'single' --prompt_ST 0
Once your model is trained, you will find the logs recording the training process in the `./logs/` directory, in a folder named `Pretrain_Dataset_<dataset>_task_<task>`. The trained model, `model_best.pkl`, is saved in `./experiments/Pretrain_Dataset_<dataset>_task_<task>/model_save/`.
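The pre-trained checkpoint is later passed to prompt tuning through `--file_load_path`. The sketch below builds that path from the folder naming described above; `pretrain_checkpoint_path` is a hypothetical helper, not a function in this repo, and it assumes you run it from the `src` directory.

```python
from pathlib import Path


def pretrain_checkpoint_path(dataset: str, task: str, root: str = "./experiments") -> Path:
    """Path of the best pre-trained model, following the naming
    Pretrain_Dataset_<dataset>_task_<task>/model_save/model_best.pkl
    described in this README (hypothetical helper)."""
    return Path(root) / f"Pretrain_Dataset_{dataset}_task_{task}" / "model_save" / "model_best.pkl"


ckpt = pretrain_checkpoint_path("Crowd", "short")
print(ckpt)  # on Linux: experiments/Pretrain_Dataset_Crowd_task_short/model_save/model_best.pkl
print("exists:", ckpt.exists())
```

Checking `exists()` before launching prompt tuning catches a mistyped dataset or task name early.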
In our experiments, we leverage multiple datasets to enhance UniST. To use multiple datasets, separate their names with an asterisk (*), e.g., `--dataset 'Crowd*Cellular*TaxiNYC*TaxiBike*TrafficSH'` (quote the value so the shell does not treat the asterisks as wildcards).
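The asterisk-separated value is a single string that is presumably split into individual dataset names. How `main.py` parses it internally is not shown here; the function below is a hypothetical sketch of that convention only.

```python
def parse_dataset_list(value: str) -> list:
    """Split an asterisk-separated --dataset value into dataset names.

    Hypothetical sketch of the convention described above, not the
    parser actually used in main.py.
    """
    names = [name.strip() for name in value.split("*")]
    # Reject empty names such as those produced by "Crowd**Cellular".
    if not all(names):
        raise ValueError(f"empty dataset name in {value!r}")
    return names


print(parse_dataset_list("Crowd*Cellular*TaxiNYC*TaxiBike*TrafficSH"))
# ['Crowd', 'Cellular', 'TaxiNYC', 'TaxiBike', 'TrafficSH']
print(parse_dataset_list("Crowd"))  # a single dataset: ['Crowd']
```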
We provide the prompt-tuning script at `./scripts/prompt_tuning.sh`. For example, you can fine-tune UniST on the Crowd dataset as follows:
python main.py --device_id 2 --machine machine --task short --size middle --prompt_ST 1 --pred_len 6 --his_len 6 --num_memory_spatial 512 --num_memory_temporal 512 --prompt_content 's_p_c' --dataset Crowd --lr 3e-4 --used_data 'single' --file_load_path pretrained_model_path
This command introduces several new parameters:
- `his_len`: the input sequence length.
- `pred_len`: the prediction horizon.
- `file_load_path`: the save path of the pre-trained model; the default is `./experiments/Dataset_<dataset>_task_<task>/model_save/model_best.pkl`.
- `num_memory_spatial` and `num_memory_temporal`: the number of embeddings in the spatial and temporal memory pools, respectively.
- `prompt_ST`: whether to perform prompt tuning (0 for no prompt, 1 for prompt tuning).
- `prompt_content`: the type of prompt, selected from ['s_p_c','s','c','p','s_c','s_p','p_c'].
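To make these flags concrete, here is a minimal `argparse` sketch that declares them with the allowed values listed above. It is not the repo's actual argument parser: `build_prompt_parser` is hypothetical, the defaults are assumptions taken from the example command, and since the meaning of each letter in `prompt_content` is not specified here, the sketch only splits the value into its components.

```python
import argparse

# Allowed values for --prompt_content, as listed in this README.
PROMPT_CONTENTS = ["s_p_c", "s", "c", "p", "s_c", "s_p", "p_c"]


def build_prompt_parser() -> argparse.ArgumentParser:
    """Hypothetical parser for the prompt-tuning flags (not main.py's)."""
    parser = argparse.ArgumentParser(description="UniST prompt-tuning flags (sketch)")
    # Defaults below are assumptions copied from the example command.
    parser.add_argument("--his_len", type=int, default=6, help="input sequence length")
    parser.add_argument("--pred_len", type=int, default=6, help="prediction horizon")
    # The README's real default depends on dataset/task; left empty here.
    parser.add_argument("--file_load_path", type=str, default="", help="path to the pre-trained model")
    parser.add_argument("--num_memory_spatial", type=int, default=512, help="spatial memory pool size")
    parser.add_argument("--num_memory_temporal", type=int, default=512, help="temporal memory pool size")
    parser.add_argument("--prompt_ST", type=int, choices=[0, 1], default=0, help="0: no prompt, 1: prompt tuning")
    parser.add_argument("--prompt_content", type=str, choices=PROMPT_CONTENTS, default="s_p_c", help="prompt type")
    return parser


args = build_prompt_parser().parse_args(["--prompt_ST", "1", "--prompt_content", "s_c", "--pred_len", "12"])
print(args.prompt_ST, args.pred_len, args.prompt_content.split("_"))  # 1 12 ['s', 'c']
```

Declaring `choices` means a value outside the list, such as `--prompt_content x`, makes `argparse` print an error and exit instead of failing later during training.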
Once your model is trained, you will find the logs recording the training process in the `./logs/` directory, in a folder named `Prompt_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>`. The fine-tuned model, `model_best.pkl`, is saved in `./experiments/Prompt_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>/model_save/`.
The evaluation results on the test set are saved in `./experiments/Prompt_Mode_finetuning_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>/result.txt`.
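Following the naming described above, the fine-tuned model and the test-set results live in differently named folders. The hypothetical helper below (not part of this repo) builds both paths from `dataset`, `his_len`, and `pred_len`, and prints `result.txt` if it exists; the file's internal format is not described here, so it is shown verbatim.

```python
from pathlib import Path


def prompt_tuning_paths(dataset: str, his_len: int, pred_len: int, root: str = "./experiments"):
    """Return (model_path, result_path) following the folder names in this README.

    Hypothetical helper; note the results folder carries an extra
    'Mode_finetuning_' segment compared with the model folder.
    """
    suffix = f"Dataset_{dataset}_His_{his_len}_Pred_{pred_len}"
    model_path = Path(root) / f"Prompt_{suffix}" / "model_save" / "model_best.pkl"
    result_path = Path(root) / f"Prompt_Mode_finetuning_{suffix}" / "result.txt"
    return model_path, result_path


model_path, result_path = prompt_tuning_paths("Crowd", 6, 6)
print(model_path)
print(result_path)
if result_path.exists():
    # Print the saved evaluation results as-is.
    print(result_path.read_text())
```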
Downloads of the model weights will be provided on xxx (coming soon).
If you find this repo helpful, please cite our paper.
@article{yuan2024unist,
title={UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction},
author={Yuan, Yuan and Ding, Jingtao and Feng, Jie and Jin, Depeng and Li, Yong},
journal={arXiv preprint arXiv:2402.11838},
year={2024}
}
We greatly appreciate the following GitHub repos for their valuable code and efforts.
- Spatio-temporal prediction benchmark: https://github.com/chengtan9907/OpenSTL
- Spatio-temporal data: https://github.com/aptx1231/NYC-Dataset
- MAE: https://github.com/facebookresearch/mae
- PatchTST: https://github.com/PatchTST/PatchTST
- iTransformer: https://github.com/thuml/iTransformer
If you have any questions or want to use the code, feel free to contact:
- Yuan Yuan (y-yuan20@mails.tsinghua.edu.cn)