Add PatchTST as an imputation model #431
Workflow file for this run
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Continuous-integration workflow.
name: CI

# Run on pushes to the listed branches and on pull requests that target `dev`.
on:
  push:
    branches:
      - main
      - dev
      - temp_test_branch  # if in need, create such a temporary branch to test some functions
  pull_request:
    branches:
      - dev
jobs:
  CI-testing:
    # One job is started per (os, python-version) combination of the matrix below.
    runs-on: ${{ matrix.os }}
    defaults:
      run:
        # Use bash for every `run` step on all runners, Windows included.
        shell: bash
    strategy:
      # Let the remaining matrix jobs finish even if one combination fails.
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-latest, macOS-latest]
        # Keep the versions quoted so YAML reads them as strings, not floats.
        python-version: ["3.7", "3.11"]
    steps:
      - name: Check out the repo code
        uses: actions/checkout@v3

      - name: Determine the PyTorch version
        # The selected version is read by later steps as
        # `steps.determine_pytorch_ver.outputs.value`.
        uses: haya14busa/action-cond@v1
        id: determine_pytorch_ver
        with:
          # Compare against a string literal: the matrix entries are strings, and a
          # numeric literal would force a string-to-number coercion of the version.
          cond: ${{ matrix.python-version == '3.7' }}
          if_true: "1.13.1"
          if_false: "2.1.0"

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          check-latest: true
          # Cache pip downloads, keyed on the dependency file listed below.
          cache: pip
          cache-dependency-path: |
            setup.cfg

      - name: Install PyTorch ${{ steps.determine_pytorch_ver.outputs.value }}+cpu
        # we have to install torch in advance because torch_sparse needs it for compilation,
        # refer to https://github.com/rusty1s/pytorch_sparse/issues/156#issuecomment-1304869772 for details
        run: |
          which python
          which pip
          pip install torch==${{ steps.determine_pytorch_ver.outputs.value }} -f https://download.pytorch.org/whl/cpu
          python -c "import torch; print('PyTorch:', torch.__version__)"

      - name: Install other dependencies
        # After `pypots[dev]` is installed from the package index, the installed
        # `pypots` directory is deleted and replaced with a copy of the `pypots`
        # directory from this checkout.
        # NOTE(review): presumably so the tests import the code under review - confirm.
        run: |
          pip install -r requirements.txt
          pip install torch-geometric torch-scatter torch-sparse -f "https://data.pyg.org/whl/torch-${{ steps.determine_pytorch_ver.outputs.value }}+cpu.html"
          pip install "pypots[dev]"
          python_site_path=$(python -c "import site; print(site.getsitepackages()[0])")
          echo "python site-packages path: $python_site_path"
          rm -rf "$python_site_path/pypots"
          python -c "import shutil;import site;shutil.copytree('pypots',site.getsitepackages()[0]+'/pypots')"

      - name: Fetch the test environment details
        # Log the interpreter in use and the installed packages for debugging.
        run: |
          which python
          pip list

      - name: Test with pytest
        # Run the test config script, remove earlier results and bytecode caches,
        # then run the test suite with coverage collected for the `pypots` package.
        run: |
          python tests/global_test_config.py
          rm -rf testing_results && rm -rf tests/__pycache__ && rm -rf tests/*/__pycache__
          python -m pytest -rA tests/*/* -s -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc

      - name: Generate the LCOV report
        # Convert the collected coverage data into the LCOV file uploaded below.
        run: |
          python -m coverage lcov

      - name: Submit the report
        # Pinned to the v2 major tag instead of the moving `master` branch.
        uses: coverallsapp/github-action@v2
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          path-to-lcov: "coverage.lcov"