diff --git a/.circleci/config.yml b/.circleci/config.yml old mode 100755 new mode 100644 index 28de0a75bdd123..c95f89ec36587d --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,191 +1,133 @@ -# Python CircleCI 2.0 configuration file -# -# Check https://circleci.com/docs/2.0/language-python/ for more details -# -version: 2.0 +# Python CircleCI 2.1 configuration file. +version: 2.1 +orbs: + gcp-gke: circleci/gcp-gke@1.0.4 + go: circleci/go@1.3.0 + codecov: codecov/codecov@1.1.0 references: - install_deps: &install_deps + make_docs: &make_docs run: - name: Install Dependences + name: Make Documentation command: | - sudo apt-get update && sudo apt-get install -y cmake - pip install "$TORCH_VERSION" - pip install -r requirements.txt -q - sudo pip install pytest pytest-cov pytest-flake8 -q - pip install -r ./tests/requirements-devel.txt -q - - tests: &tests + # First run the same pipeline as Read-The-Docs + # apt-get update && apt-get install -y cmake + # using: https://hub.docker.com/r/readthedocs/build + # we need to use py3.7 ot higher becase of an issue with metaclass inheritence + pyenv global 3.7.3 + python --version + pip install -r requirements/docs.txt + pip list + cd docs + make clean + make html --jobs 2 SPHINXOPTS="-W" + + checkout_ml_testing: &checkout_ml_testing run: - name: Testing + name: Checkout ml-testing-accelerators command: | - python --version ; pip --version ; pip list - py.test pytorch_lightning tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml - no_output_timeout: 30m + git clone https://github.com/GoogleCloudPlatform/ml-testing-accelerators.git + cd ml-testing-accelerators + git fetch origin 5e88ac24f631c27045e62f0e8d5dfcf34e425e25:stable + git checkout stable + cd .. - examples: &examples - run: - name: PL Examples - command: | - pip install -r ./pl_examples/requirements.txt --user - python --version ; pip --version ; pip list - py.test pl_examples -v --doctest-modules --junitxml=test-reports/pytest_junit.xml - no_output_timeout: 20m - - install_pkg: &install_pkg + build_push_docker: &build_push_docker run: - name: Install package + name: Build and push Docker image command: | - virtualenv vEnv ; source vEnv/bin/activate - pip install --editable . ; cd .. & python -c "import pytorch_lightning ; print(pytorch_lightning.__version__)" - deactivate ; rm -rf vEnv - - create_pkg: &create_pkg - run: - name: Create package - command: | - sudo pip install twine==1.13.0 - python setup.py sdist - twine check dist/* - python setup.py clean - - format: &format + gcloud --quiet auth configure-docker + #cd dockers/tpu-tests + export PYTHON_VER=$(python -c "import random ; print('3.6' if random.random() > 0.5 else '3.7')" 2>&1) + echo $PYTHON_VER + docker build --tag "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID" -f ./dockers/tpu-tests/Dockerfile --build-arg "PYTHON_VERSION=$PYTHON_VER" --build-arg "PYTORCH_VERSION=$XLA_VER" . + docker push "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID" + + deploy_cluster: &deploy_cluster run: - name: Formatting + name: Deploy the job on the kubernetes cluster command: | - python --version ; pip --version - sudo pip install flake8 -q - pip list - flake8 . - - make_docs: &make_docs + go get github.com/google/go-jsonnet/cmd/jsonnet + export PATH=$PATH:$HOME/go/bin + python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; fff = open(fname).read().replace('pytorch-VERSION', 'pytorch-$XLA_VER') ; open(fname, 'w').write(fff)" + job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet --ext-str image=$GCR_IMAGE_PATH --ext-str image-tag=$CIRCLE_WORKFLOW_JOB_ID | kubectl create -f -) + job_name=${job_name#job.batch/} + job_name=${job_name% created} + echo "Waiting on kubernetes job: $job_name" + i=0 && \ + # N checks spaced 30s apart = 900s total. + status_code=2 && \ + # Check on the job periodically. Set the status code depending on what + # happened to the job in Kubernetes. If we try MAX_CHECKS times and + # still the job hasn't finished, give up and return the starting + # non-zero status code. + printf "Waiting for job to finish: " && \ + while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \ + echo "Done waiting. Job status code: $status_code" && \ + pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \ + echo "GKE pod name: $pod_name" && \ + kubectl logs -f $pod_name --container=train > /tmp/full_output.txt + if grep -q '' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '//'; else mv /tmp/full_output.txt xx00; fi && \ + # First portion is the test logs. Print these to Github Action stdout. + cat xx00 && \ + echo "Done with log retrieval attempt." && \ + gcloud container images delete "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID" --force-delete-tags && \ + exit $status_code + + stats: &stats run: - name: Make Documentation + name: Statistics command: | - # sudo apt-get install pandoc - sudo apt-get update && sudo apt-get install -y cmake - pip install -r requirements.txt --user - sudo pip install -r docs/requirements.txt - pip install -r requirements-extra.txt --user # for doctesting loggers etc. - # sphinx-apidoc -o ./docs/source ./pytorch_lightning **/test_* --force --follow-links - cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W" - make doctest; make coverage + mv ./xx01 coverage.xml + # TODO: add human readable report + cat coverage.xml + sudo pip install pycobertura + pycobertura show coverage.xml jobs: - Build-Docs: - docker: - - image: circleci/python:3.7 - steps: - - checkout - - *make_docs - - store_artifacts: - # allows us to preview the generated html pages - path: docs/build/html/ - destination: html - - Formatting: + TPU-tests: docker: - image: circleci/python:3.7 environment: - - TORCH_VERSION: "torch" + - XLA_VER: 1.7 + - MAX_CHECKS: 240 + - CHECK_SPEEP: 5 steps: - checkout - - *format + - go/install + - *checkout_ml_testing + - gcp-gke/install + - gcp-gke/update-kubeconfig-with-credentials: + cluster: $GKE_CLUSTER + perform-login: true + - setup_remote_docker + - *build_push_docker + - *deploy_cluster + - *stats + - codecov/upload: + file: coverage.xml + flags: tpu,pytest + upload_name: TPU-coverage - PyTorch: - docker: - - image: circleci/python:3.6 - environment: - - TORCH_VERSION: "torch" - steps: &steps - - checkout - #- restore_cache: - # keys: - # # when lock file changes, use increasingly general patterns to restore cache - # - pip-packages--{{ .Environment.CIRCLE_JOB }} - # - pip-packages-- - - *install_deps - #- save_cache: - # key: pip-packages--{{ .Environment.CIRCLE_JOB }} - # paths: - # # this path depends on where pipenv creates a virtualenv - # - "~/.cache/pip" - # - "/usr/local/lib/python3.6/site-packages" - # - "/usr/local/lib/site-python" - - *tests - - store_test_results: - path: test-reports - store_artifacts: - path: test-reports - - PyTorch-v1_1: - docker: - - image: circleci/python:3.6 - environment: - - TORCH_VERSION: "torch>=1.1, <1.2" - steps: *steps - - PyTorch-v1_2: - docker: - - image: circleci/python:3.6 - environment: - - TORCH_VERSION: "torch>=1.2, <1.3" - steps: *steps - - PyTorch-v1_3: - docker: - - image: circleci/python:3.6 - environment: - - TORCH_VERSION: "torch>=1.3, <1.4" - steps: *steps - - PyTorch-v1_4: - docker: - - image: circleci/python:3.6 - environment: - - TORCH_VERSION: "torch>=1.4, <1.5" - steps: *steps - - PyTorch-v1_5: - docker: - - image: circleci/python:3.6 - environment: - - TORCH_VERSION: "torch>=1.5, <1.6" - steps: *steps + path: coverage.xml - Examples: + build-Docs: docker: - - image: circleci/python:3.7 - environment: - - TORCH_VERSION: "torch" - steps: - - checkout - - *install_deps - - *examples - - Install-pkg: - docker: - - image: circleci/python:3.7 + - image: readthedocs/build:latest steps: - checkout - - *create_pkg - - *install_pkg - -#orbs: -# python: circleci/python@0.2.1 + - *make_docs + - store_artifacts: + # allows us to preview the generated html pages + path: docs/build/html/ + destination: html workflows: version: 2 - build: + tpu-tests: jobs: - - Formatting - - Build-Docs - - PyTorch-v1_1 - - PyTorch-v1_2 - - PyTorch-v1_3 - - PyTorch-v1_4 - - PyTorch-v1_5 - - Install-pkg - - Examples + - build-Docs + - TPU-tests diff --git a/.codecov.yml b/.codecov.yml index 726a198d5a3a8a..cc6a5e6a2b7b3e 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # see https://docs.codecov.io/docs/codecov-yaml # Validation check: # $ curl --data-binary @.codecov.yml https://codecov.io/validate @@ -9,8 +23,10 @@ codecov: strict_yaml_branch: "yaml-config" require_ci_to_pass: yes notify: - # after_n_builds: 2 + after_n_builds: 23 wait_for_ci: yes + # https://docs.codecov.io/docs/codecov-yaml#section-expired-reports + max_report_age: off coverage: precision: 0 # 2 = xx.xx%, 0 = xx% @@ -48,5 +64,4 @@ comment: layout: header, diff require_changes: false behavior: default # update if exists else create new - # branches: * - + after_n_builds: 23 diff --git a/.drone.yml b/.drone.yml deleted file mode 100644 index 88e2d76a525032..00000000000000 --- a/.drone.yml +++ /dev/null @@ -1,58 +0,0 @@ -# https://docs.drone.io/pipeline/docker/examples/languages/python/#python-example - -kind: pipeline -type: docker -name: torch-GPU - -steps: -- name: testing - image: pytorchlightning/pytorch_lightning:devel-pt_1_4 - - environment: - SLURM_LOCALID: 0 - CODECOV_TOKEN: - from_secret: codecov_token - HOROVOD_GPU_ALLREDUCE: NCCL - HOROVOD_GPU_BROADCAST: NCCL - HOROVOD_WITH_PYTORCH: 1 - HOROVOD_WITHOUT_TENSORFLOW: 1 - HOROVOD_WITHOUT_MXNET: 1 - HOROVOD_WITH_GLOO: 1 - HOROVOD_WITHOUT_MPI: 1 - - #volumes: - # # Mount pip cache from host - # - name: pip_cache - # path: /opt/conda/lib/python3.7/site-packages - - commands: - - export PATH="$PATH:/root/.local/bin" - - python --version - - pip install pip -U - - pip --version - - nvidia-smi - #- bash ./tests/install_AMP.sh - - apt-get update && apt-get install -y cmake - - pip install -r requirements.txt --user -q - - pip install -r ./tests/requirements-devel.txt --user -q - #- pip install -r ./docs/requirements.txt --user -q - - pip list - - python -c "import torch ; print(' & '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) if torch.cuda.is_available() else 'only CPU')" - - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests benchmarks -v --doctest-modules # --flake8 - #- cd docs; make doctest; make coverage - - coverage report - - codecov --token $CODECOV_TOKEN # --pr $DRONE_PULL_REQUEST --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG - - python tests/collect_env_details.py - -trigger: - branch: - - master - event: - include: - - push - - pull_request - -#volumes: -# - name: pip_cache -# host: -# path: /tmp/cache/drone/pip diff --git a/.github/BECOMING_A_CORE_CONTRIBUTOR.md b/.github/BECOMING_A_CORE_CONTRIBUTOR.md index 46db2111b0d1f6..828f45aedbecc8 100644 --- a/.github/BECOMING_A_CORE_CONTRIBUTOR.md +++ b/.github/BECOMING_A_CORE_CONTRIBUTOR.md @@ -1,17 +1,17 @@ # How to become a core contributor -Thanks for your interest in joining the Lightning team! We’re a rapidly growing project which is poised to become the go-to framework for DL researchers! -We're currently recruiting for a team of 5 core maintainers. +Thanks for your interest in joining the Lightning team! We’re a rapidly growing project which is poised to become the go-to framework for DL researchers! +We're currently recruiting for a team of 5 core maintainers. As a core maintainer you will have a strong say in the direction of the project. Big changes will require a majority of maintainers to agree. -### Code of conduct +### Code of conduct First and foremost, you'll be evaluated against [these core values](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md). Any code we commit or feature we add needs to align with those core values. -### The bar for joining the team +### The bar for joining the team Lightning is being used to solve really hard problems at the top AI labs in the world. As such, the bar for adding team members is extremely high. Candidates must have solid engineering skills, have a good eye for user experience, and must be a power user of Lightning and PyTorch. -With that said, the Lightning team will be diverse and a reflection of an inclusive AI community. You don't have to be an engineer to conntribute! Scientists with great usability intuition and PyTorch ninja skills are welcomed! +With that said, the Lightning team will be diverse and a reflection of an inclusive AI community. You don't have to be an engineer to contribute! Scientists with great usability intuition and PyTorch ninja skills are welcomed! ### Responsibilities: The responsibilities mainly revolve around 3 things. @@ -36,10 +36,10 @@ Pleasant/helpful tone. - Code is NOT overly engineered or hard to read - Ask yourself, could a non-engineer understand what’s happening here? - Make sure new tests are written -- Is this NECESSARY for Lightning? There are some PRs which are just purely about adding engineering complexity which have no place in Lightning. +- Is this NECESSARY for Lightning? There are some PRs which are just purely about adding engineering complexity which have no place in Lightning. Guidance - Some other PRs are for people who are wanting to get involved and add something unnecessary. We do want their help though! So don’t approve the PR, but direct them to a Github issue that they might be interested in helping with instead! -- To be considered for core contributor, please review 10 PRs and help the authors land it on master. Once you've finished the review, ping me +- To be considered for core contributor, please review 10 PRs and help the authors land it on master. Once you've finished the review, ping me for a sanity check. At the end of 10 PRs if your PR reviews are inline with expectations described above, then you can merge PRs on your own going forward, otherwise we'll do a few more until we're both comfortable :) @@ -47,13 +47,16 @@ otherwise we'll do a few more until we're both comfortable :) There are some big decisions which the project must make. For these I expect core contributors to have something meaningful to add if it’s their area of expertise. #### Diversity -Lightning should reflect the broader community it serves. As such we should have scientists/researchers from -different fields contributing! +Lightning should reflect the broader community it serves. As such we should have scientists/researchers from +different fields contributing! The first 5 core contributors will fit this profile. Thus if you overlap strongly with experiences and expertise as someone else on the team, you might have to wait until the next set of contributors are added. #### Summary: Requirements to apply -- Solve 10 Github issues. The goal is to be inline with expectations for solving issues by the last one so you can do them on your own. If not, I might ask you to solve a few more specific ones. -- Do 10 PR reviews. The goal is to be inline with expectations for solving issues by the last one so you can do them on your own. If not, I might ask you to solve a few more specific ones. +The goal is to be inline with expectations for solving issues by the last one so you can do them on your own. If not, I might ask you to solve a few more specific ones. -If you want to be considered, ping me on gitter and start [tracking your progress here](https://docs.google.com/spreadsheets/d/15D58gp8DvI0Z6qbbYVRuaWioiwzafcP58-UlbuO_CMU/edit?usp=sharing). +- Solve 10+ Github issues. +- Create 5+ meaningful PRs which solves some reported issue - bug, +- Perform 10+ PR reviews from other contributors. + +If you want to be considered, ping me on [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A). diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000000000..4ac6944c7a31ad --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,53 @@ +# This is a comment. +# Each line is a file pattern followed by one or more owners. + +# These owners will be the default owners for everything in +# the repo. Unless a later match takes precedence, +# @global-owner1 and @global-owner2 will be requested for +# review when someone opens a pull request. +* @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock + +# CI/CD and configs +/.github/ @borda @tchaton +/dockers/ @borda @tchaton +*.yml @borda @tchaton + +# Docs +/docs/ @edenlightning @tchaton @borda @awaelchli +/.github/*.md @edenlightning @williamfalcon @borda +/.github/ISSUE_TEMPLATE/ @edenlightning @borda @tchaton +/docs/source/conf.py @borda @awaelchli + +# Packages +/pytorch_lightning/accelerators @williamfalcon @tchaton @SeanNaren @awaelchli @justusschock +/pytorch_lightning/callbacks @williamfalcon @tchaton @carmocca @borda +/pytorch_lightning/cluster_environments @borda @tchaton @SeanNaren @carmocca +/pytorch_lightning/core @tchaton @SeanNaren @borda @carmocca @justusschock +/pytorch_lightning/distributed @williamfalcon @tchaton @awaelchli +/pytorch_lightning/loggers @tchaton @awaelchli @borda +/pytorch_lightning/overrides @tchaton @SeanNaren @borda +/pytorch_lightning/plugins @tchaton @SeanNaren @awaelchli @justusschock +/pytorch_lightning/profiler @williamfalcon @tchaton @borda +/pytorch_lightning/trainer @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock +/pytorch_lightning/trainer/connectors @tchaton @SeanNaren @carmocca @borda +/pytorch_lightning/tuner @SkafteNicki @borda @awaelchli +/pytorch_lightning/utilities @borda @tchaton @SeanNaren @carmocca + +# Metrics +/pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock +/tests/metrics/ @SkafteNicki @ananyahjha93 @justusschock +/docs/source/metrics.rst @SkafteNicki @ananyahjha93 @justusschock + +# API +/pytorch_lightning/callbacks/base.py @williamfalcon +/pytorch_lightning/core/datamodule.py @williamFalcon +/pytorch_lightning/trainer/trainer.py @williamfalcon @tchaton +/pytorch_lightning/core/hooks.py @williamfalcon +/pytorch_lightning/core/lightning.py @williamfalcon @tchaton + +# Testing +/tests/helpers/boring_model.py @williamfalcon @tchaton @borda + +/.github/CODEOWNERS @williamfalcon +/README.md @williamfalcon @edenlightning @borda +/setup.py @williamfalcon @borda diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index f2cd71f3a83690..278cd72512e746 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -1,90 +1,136 @@ -# Contributing -Welcome to the PyTorch Lightning community! We're building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out! +# Contributing + +Welcome to the PyTorch Lightning community! We're building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out! ## Main Core Value: One less thing to remember Simplify the API as much as possible from the user perspective. - Any additions or improvements should minimize things the user needs to remember. +Any additions or improvements should minimize the things the user needs to remember. For example: One benefit of the validation_step is that the user doesn't have to remember to set the model to .eval(). - This avoids all sorts of subtle errors the user could make. +This helps users avoid all sorts of subtle errors. ## Lightning Design Principles -We encourage all sorts of contributions you're interested in adding! When coding for lightning, please follow these principles. - + +We encourage all sorts of contributions you're interested in adding! When coding for lightning, please follow these principles. + #### No PyTorch Interference + We don't want to add any abstractions on top of pure PyTorch. - This gives researchers all the control they need without having to learn yet another framework. +This gives researchers all the control they need without having to learn yet another framework. #### Simple Internal Code + It's useful for users to look at the code and understand very quickly what's happening. - Many users won't be engineers. Thus we need to value clear, simple code over condensed ninja moves. - While that's super cool, this isn't the project for that :) +Many users won't be engineers. Thus we need to value clear, simple code over condensed ninja moves. +While that's super cool, this isn't the project for that :) #### Force User Decisions To Best Practices -There are 1,000 ways to do something. However, something eventually becomes standard practice that everyone does. - Thus we pick one way of doing it and force everyone to do it this way. - A good example is accumulated gradients. - There are many ways to implement, we just pick one and force users to use that one. - A bad forced decision would be to make users use a specific library to do something. -When something becomes a best practice, we add it to the framework. This likely looks like code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, bring that code inside the trainer and add a flag for it. +There are 1,000 ways to do something. However, eventually one popular solution becomes standard practice, and everyone follows. +We try to find the best way to solve a particular problem, and then force our users to use it for readability and simplicity. +A good example is accumulated gradients. +There are many different ways to implement it, we just pick one and force users to use it. +A bad forced decision would be to make users use a specific library to do something. + +When something becomes a best practice, we add it to the framework. This is usually something like bits of code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, bring that code inside the trainer and add a flag for it. #### Simple External API -What makes sense to you may not make sense to others. Create an issue with an API change suggestion and validate that it makes sense for others. - Treat code changes how you treat a startup: validate that it's a needed feature, then add if it makes sense for many people. + +What makes sense to you may not make sense to others. When creating an issue with an API change suggestion, please validate that it makes sense for others. +Treat code changes the way you treat a startup: validate that it's a needed feature, then add if it makes sense for many people. #### Backward-compatible API -We all hate updating our deep learning packages because we don't want to refactor a bunch of stuff. In Lightning, we make sure every change we make which could break an API is backwards compatible with good deprecation warnings. + +We all hate updating our deep learning packages because we don't want to refactor a bunch of stuff. In Lightning, we make sure every change we make which could break an API is backward compatible with good deprecation warnings. **You shouldn't be afraid to upgrade Lightning :)** #### Gain User Trust -As a researcher you can't have any part of your code going wrong. So, make thorough tests that ensure an implementation of a new trick or subbtle change is correct. + +As a researcher, you can't have any part of your code going wrong. So, make thorough tests to ensure that every implementation of a new trick or subtle change is correct. #### Interoperability + Have a favorite feature from other libraries like fast.ai or transformers? Those should just work with lightning as well. Grab your favorite model or learning rate scheduler from your favorite library and run it in Lightning. --- ## Contribution Types -Currently looking for help implementing new features or adding bug fixes. -A lot of good work has already been done in project mechanics (requirements.txt, setup.py, pep8, badges, ci, etc...) we're in a good state there thanks to all the early contributors (even pre-beta release)! +We are always looking for help implementing new features or fixing bugs. + +A lot of good work has already been done in project mechanics (requirements.txt, setup.py, pep8, badges, ci, etc...) so we're in a good state there thanks to all the early contributors (even pre-beta release)! ### Bug Fixes: -1. Submit a github issue - try to decried what happen so other can reproduce it too. -2. Try to ix it or recommend a solution... + +1. If you find a bug please submit a github issue. + + - Make sure the title explains the issue. + - Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples. + - Add details on how to reproduce the issue - a minimal test case is always best, colab is also great. + Note, that the sample code shall be minimal and if needed with publicly available data. + +2. Try to fix it or recommend a solution. We highly recommend to use test-driven approach: + + - Convert your minimal code example to a unit/integration test with assert on expected results. + - Start by debugging the issue... You can run just this particular test in your IDE and draft a fix. + - Verify that your test case fails on the master branch and only passes with the fix applied. + 3. Submit a PR! +_**Note**, even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution and we can help you or finish it with you :]_ ### New Features: -1. Submit a github issue - describe what is motivation of such feature (plus an use-case). -2. Let's discuss to agree on the feature scope. -3. Submit a PR! (with updated docs and tests 🙃). + +1. Submit a github issue - describe what is the motivation of such feature (adding the use case or an example is helpful). +2. Let's discuss to determine the feature scope. +3. Submit a PR! We recommend test driven approach to adding new features as well: + + - Write a test for the functionality you want to add. + - Write the functional code until the test passes. + +4. Add/update the relevant tests! + +- [This PR](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671) is a good example for adding a new metric, and [this one for a new logger](https://github.com/PyTorchLightning/pytorch-lightning/pull/2721). + +### Test cases: + +Want to keep Lightning healthy? Love seeing those green tests? So do we! How to we keep it that way? We write tests! We value tests contribution even more than new features. + +Most of the tests in PyTorch Lightning train a trial MNIST model under various trainer conditions (ddp, ddp2+amp, etc...). The tests expect the model to perform to a reasonable degree of testing accuracy to pass. Want to add a new test case and not sure how? [Talk to us!](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A) --- ## Guidelines +### Developments scripts +To build the documentation locally, simply execute the following commands from project root (only for Unix): +- `make clean` cleans repo from temp/generated files +- `make docs` builds documentation under _docs/build/html_ +- `make test` runs all project's tests with coverage + +### Original code + +All added or edited code shall be the own original work of the particular contributor. +If you use some third-party implementation, all such blocks/functions/modules shall be properly referred and if possible also agreed by code's author. For example - `This code is inspired from http://...`. +In case you adding new dependencies, make sure that they are compatible with the actual PyTorch Lightning license (ie. dependencies should be _at least_ as permissive as the PyTorch Lightning license). + ### Coding Style -1. Use f-strings for output formation (except logging when we stay with lazy `logging.info("Hello %s!`, name). -2. Test the code with flake8, run locally PEP8 fixes: - ``` - autopep8 -v -r --max-line-length 120 --in-place . - ``` +1. Use f-strings for output formation (except logging when we stay with lazy `logging.info("Hello %s!", name)`. +2. You can use `pre-commit` to make sure your code style is correct. ### Documentation -We are using Sphinx with Napoleon extension. -Moreover we set Google style to follow with type convention. +We are using Sphinx with Napoleon extension. +Moreover, we set Google style to follow with type convention. - [Napoleon formatting with Google style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) - [ReStructured Text (reST)](https://docs.pylonsproject.org/projects/docs-style-guide/) - [Paragraph-level markup](https://www.sphinx-doc.org/en/1.5/markup/para.html) -See following short example of a sample function taking one position string and optional +See following short example of a sample function taking one position string and optional ```python from typing import Optional @@ -95,7 +141,7 @@ def my_func(param_a: int, param_b: Optional[float] = None) -> str: Args: param_a: first parameter param_b: second parameter - + Return: sum of both numbers @@ -110,77 +156,241 @@ def my_func(param_a: int, param_b: Optional[float] = None) -> str: return str(param_a + p) ``` -When updating the docs make sure to build them first locally and visually inspect the html files (in the browser) for -formatting errors. In certain cases, a missing blank line or a wrong indent can lead to a broken layout. -Run these commands +When updating the docs make sure to build them first locally and visually inspect the html files (in the browser) for +formatting errors. In certain cases, a missing blank line or a wrong indent can lead to a broken layout. +Run these commands + ```bash +pip install -r requirements/docs.txt cd docs -pip install -r requirements.txt make html ``` + and open `docs/build/html/index.html` in your browser. -When you send a PR the continuous integration will run tests and build the docs. You can access a preview of the html pages in the +Notes: + +- You need to have LaTeX installed for rendering math equations. You can for example install TeXLive by doing one of the following: + - on Ubuntu (Linux) run `apt-get install texlive` or otherwise follow the instructions on the TeXLive website + - use the [RTD docker image](https://hub.docker.com/r/readthedocs/build) +- with PL used class meta you need to use python 3.7 or higher + +When you send a PR the continuous integration will run tests and build the docs. You can access a preview of the html pages in the _Artifacts_ tab in CircleCI when you click on the task named _ci/circleci: Build-Docs_ at the bottom of the PR page. ### Testing -Test your work locally to speed up your work since so you can focus only in particular (failing) test-cases. - To setup a local development environment, install both local and test dependencies: +**Local:** Testing your work locally will help you speed up the process since it allows you to focus on particular (failing) test-cases. +To setup a local development environment, install both local and test dependencies: + ```bash -pip install -r requirements.txt -pip install -r tests/requirements-devel.txt -``` +python -m pip install ".[dev, examples]" +python -m pip install pre-commit +``` -You can run the full test-case in your terminal via this bash script: +You can run the full test-case in your terminal via this make script: ```bash -bash .run_local_tests.sh +make test ``` Note: if your computer does not have multi-GPU nor TPU these tests are skipped. -For convenience, you can use also your own CircleCI building which will be triggered with each commit. -This is useful if you do not test against all required dependencies version. -To do so, login to [CircleCI](https://app.circleci.com/) and enable your forked project in the dashboard. It will just work after that. +**GitHub Actions:** For convenience, you can also use your own GHActions building which will be triggered with each commit. +This is useful if you do not test against all required dependency versions. -### Pull Request +**Docker:** Another option is utilize the [pytorch lightning cuda base docker image](https://hub.docker.com/repository/docker/pytorchlightning/pytorch_lightning/tags?page=1&name=cuda). You can then run: + +```bash +python -m pytest pytorch_lightning tests pl_examples -v +``` -We welcome any useful contribution! For convinece here's a recommended workflow: +You can also run a single test as follows: + +```bash +python -m pytest -v tests/trainer/test_trainer_cli.py::test_default_args +``` + +### Pull Request -0. Think about what you want to do - fix a bug, repair docs, etc.  -1. Start your work locally (usually until you need our CI testing) - - create a branch and prepare your changes - - hint: do not work with your master directly, it may become complicated when you need to rebase - - hint: give your PR a good name! it will be useful later when you may work on multiple tasks/PRs -2. Create a "Draft PR" which is clearly marked which lets us know you don't need feedback yet. -3. When you feel like you are ready for integrating your work, turn your PR to "Ready for review". -4. Use tags in PR name for following cases: - - **[blocked by #]** if you work is depending on others changes - - **[wip]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime +We welcome any useful contribution! For your convenience here's a recommended workflow: + +0. Think about what you want to do - fix a bug, repair docs, etc. If you want to implement a new feature or enhance an existing one, start by opening a GitHub issue to explain the feature and the motivation. Members from core-contributors will take a look (it might take some time - we are often overloaded with issues!) and discuss it. Once an agreement was reached - start coding. +1. Start your work locally (usually until you need our CI testing). + - Create a branch and prepare your changes. + - Tip: do not work with your master directly, it may become complicated when you need to rebase. + - Tip: give your PR a good name! It will be useful later when you may work on multiple tasks/PRs. +2. Test your code! + - It is always good practice to start coding by creating a test case, verifying it breaks with current behaviour, and passes with your new changes. + - Make sure your new tests cover all different edge cases. + - Make sure all exceptions are handled. +3. Create a "Draft PR" which is clearly marked, to let us know you don't need feedback yet. +4. When you feel ready for integrating your work, mark your PR "Ready for review". + - Your code should be readable and follow the project's design principles. + - Make sure all tests are passing. + - Make sure you add a GitHub issue to your PR. +5. Use tags in PR name for following cases: + - **[blocked by #]** if you work is depending on others changes. + - **[wip]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. ### Question & Answer -1. **How can I help/contribute?** +#### How can I help/contribute? - All help is very welcome - reporting bug, solving issues and preparing bug fixes. To solve some issues you can start with label [good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) or chose something close to your domain with label [help wanted](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22). Before you start to implement anything check that the issue description that it is clear and self-assign the task to you (if it is not possible, just comment that you take it and we assign it to you...). +All types of contributions are welcome - reporting bugs, fixing documentation, adding test cases, solving issues, and preparing bug fixes. +To get started with code contributions, look for issues marked with the label [good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) or chose something close to your domain with the label [help wanted](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22). Before coding, make sure that the issue description is clear and comment on the issue so that we can assign it to you (or simply self-assign if you can). -2. **Is there a recommendation for branch names?** - - We do not rely on the name convention so far you are working with your own fork. Anyway it would be nice to follow this convention `/_` where the types are: `bugfix`, `feaure`, `docs`, `tests`, ... +#### Is there a recommendation for branch names? -3. **How to rebase my PR?** - - We recommend to create a PR in separate branch different from `master`, especially if you plan to submit several changes and do not want to wait until the fist one is resolved (we can work on them in parallel). Update your master with upstream (assuming you have already set [upstream](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/configuring-a-remote-for-a-fork)) - ```bash +We recommend you follow this convention `/_` where the types are: `bugfix`, `feature`, `docs`, or `tests` (but if you are using your own fork that's optional). + +#### How to rebase my PR? + +We recommend creating a PR in a separate branch other than `master`, especially if you plan to submit several changes and do not want to wait until the first one is resolved (we can work on them in parallel). + +First, make sure you have set [upstream](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/configuring-a-remote-for-a-fork) by running: + +```bash +git remote add upstream https://github.com/PyTorchLightning/pytorch-lightning.git +``` + +You'll know its set up right if you run `git remote -v` and see something similar to this: + +```bash +origin https://github.com/{YOUR_USERNAME}/pytorch-lightning.git (fetch) +origin https://github.com/{YOUR_USERNAME}/pytorch-lightning.git (push) +upstream https://github.com/PyTorchLightning/pytorch-lightning.git (fetch) +upstream https://github.com/PyTorchLightning/pytorch-lightning.git (push) +``` + +Checkout your feature branch and rebase it with upstream's master before pushing up your feature branch: + +```bash +git fetch --all --prune +git rebase upstream/master +# follow git instructions to resolve conflicts +git push -f +``` + +#### How to add new tests?** + +We are using [pytest](https://docs.pytest.org/en/stable/) in Pytorch Lightning. + +Here are tutorials: +* (recommended) [Visual Testing with pytest](https://www.youtube.com/playlist?list=PLCTHcU1KoD99Rim2tzg-IhYY2iu9FFvNo) from JetBrains on YouTube +* [Effective Python Testing With Pytest](https://realpython.com/pytest-python-testing/) article on realpython.com + +Here is the process to create a new test + +* 0. Optional: Follow tutorials ! +* 1. Find a file in tests/ which match what you want to test. If none, create one. +* 2. Use this template to get started ! +* 3. Use `BoringModel and derivates to test out your code`. + +```python +# TEST SHOULD BE IN YOUR FILE: tests/..../...py +# TEST CODE TEMPLATE + +# [OPTIONAL] pytest decorator +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_explain_what_is_being_tested(tmpdir): + """ + Test description about text reason to be + """ + + # os.environ["PL_DEV_DEBUG"] = '1' # [OPTIONAL] When activated, you can use internal trainer.dev_debugger + + class ExtendedModel(BoringModel): + ... + + model = ExtendedModel() + + # BoringModel is a functional model. You might want to set methods to None to test your behaviour + # Example: model.training_step_end = None + + trainer = Trainer( + default_root_dir=tmpdir, # will save everything within a tmpdir generated for this test + ... + ) + trainer.fit(model) + trainer.test() # [OPTIONAL] + + # assert the behaviour is correct. + assert ... +``` +run our/your test with +```bash +python -m pytest tests/..../...py::test_explain_what_is_being_tested --verbose --capture=no +``` + + +#### How to fix PR with mixed base and target branches? + +Sometimes you start your PR as a bug-fix but it turns out to be more of a feature (or the other way around). +Do not panic, the solution is very straightforward and quite simple. +All you need to do are these two steps in arbitrary order: + - Ask someone from Core to change the base/target branch to the correct one + - Rebase or cherry-pick your commits onto the correct base branch... + +Let's show how to deal with the git... +the sample case is moving a PR from `master` to `release/1.2-dev` assuming my branch name is `my-branch` +and the last true master commit is `ccc111` and your first commit is `mmm222`. + * **Cherry-picking** way + ```bash + git checkout my-branch + # create a local backup of your branch + git checkout -b my-branch-backup + # reset your branch to the correct base + git reset release/1.2-dev --hard + # ACTION: this step is much easier to do with IDE + # so open one and cherry-pick your last commits from `my-branch-backup` + # resolve all eventual conflict as the new base may contain different code + # when all done, push back to the open PR + git push -f + ``` + * **Rebasing way**, see more about [rebase onto usage](https://womanonrails.com/git-rebase-onto) + ```bash + git checkout my-branch + # rebase your commits on the correct branch + git rebase --onto release/1.2-dev ccc111 + # if there is no collision you shall see just success + # eventually you would need to resolve collision and in such case follow the instruction in terminal + # when all done, push back to the open PR + git push -f + ``` + + +### Bonus Workflow Tip + +If you don't want to remember all the commands above every time you want to push some code/setup a Lightning Dev environment on a new VM, you can set up bash aliases for some common commands. You can add these to one of your `~/.bashrc`, `~/.zshrc`, or `~/.bash_aliases` files. + +NOTE: Once you edit one of these files, remember to `source` it or restart your shell. (ex. `source ~/.bashrc` if you added these to your `~/.bashrc` file). + +```bash +plclone (){ + git clone https://github.com/{YOUR_USERNAME}/pytorch-lightning.git + cd pytorch-lightning + git remote add upstream https://github.com/PyTorchLightning/pytorch-lightning.git + # This is just here to print out info about your remote upstream/origin + git remote -v +} + +plfetch (){ git fetch --all --prune git checkout master git merge upstream/master - ``` - checkout your feature branch - ```bash - git checkout my-PR-branch +} + +# Rebase your branch with upstream's master +# plrebase +plrebase (){ + git checkout $@ git rebase master - # follow git instructions to resolve conflists - git push -f - ``` +} +``` + +Now, you can: + +- clone your fork and set up upstream by running `plclone` from your terminal +- fetch upstream and update your local master branch with it by running `plfetch` +- rebase your feature branch (after running `plfetch`) by running `plrebase your-branch-name` diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 8d0ab36957dccb..cef062516b0eb4 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -2,52 +2,43 @@ name: Bug report about: Create a report to help us improve title: '' -labels: bug, help wanted +labels: bug / fix, help wanted assignees: '' --- - - - ## 🐛 Bug -### To Reproduce +## Please reproduce using the BoringModel -Steps to reproduce the behavior: -1. Go to '...' -2. Run '....' -3. Scroll down to '....' -4. See error + - +### To Reproduce +Use following [**BoringModel**](https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3?usp=sharing) and post here -#### Code sample - + ### Expected behavior - + ### Environment -Please copy and paste the output from our -[environment collection script](https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py) -(or fill out the checklist below manually). +**Note**: `Bugs with code` are solved faster ! `Colab Notebook` should be made `public` ! + +* `IDE`: Please, use our python [bug_report_model.py](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py +) template. + +* `Colab Notebook`: Please copy and paste the output from our [environment collection script](https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py) (or fill out the checklist below manually). You can get the script and run it with: ``` wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py -# For security purposes, please check the contents of collect_env.py before running it. -python collect_env.py +# For security purposes, please check the contents of collect_env_details.py before running it. +python collect_env_details.py ``` - PyTorch Version (e.g., 1.0): diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000000000..9c6b05e68b2062 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Ask a Question + url: https://github.com/PyTorchLightning/pytorch-lightning/discussions/new + about: Ask and answer Lightning related questions + - name: 💬 Slack + url: https://app.slack.com/client/TR9DVT48M/CQXV8BRH9/thread/CQXV8BRH9-1591382895.254600 + about: Chat with our community diff --git a/.github/ISSUE_TEMPLATE/documentation.md b/.github/ISSUE_TEMPLATE/documentation.md index 2b249089657c8c..e78df92a18bab2 100644 --- a/.github/ISSUE_TEMPLATE/documentation.md +++ b/.github/ISSUE_TEMPLATE/documentation.md @@ -12,7 +12,7 @@ assignees: '' For typos and doc fixes, please go ahead and: 1. Create an issue. -2. Fix the typo. +2. Fix the typo. 3. Submit a PR. Thanks! diff --git a/.github/ISSUE_TEMPLATE/how-to-question.md b/.github/ISSUE_TEMPLATE/how-to-question.md deleted file mode 100644 index 87e0f3ec2efc4f..00000000000000 --- a/.github/ISSUE_TEMPLATE/how-to-question.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -name: How to question -about: Asking how-to questions -title: '' -labels: question -assignees: '' - ---- - -## ❓ Questions and Help - -### Before asking: -1. search the issues. -2. search the docs. - - - -#### What is your question? - -#### Code - - - -#### What have you tried? - -#### What's your environment? - - - OS: [e.g. iOS, Linux, Win] - - Packaging [e.g. pip, conda] - - Version [e.g. 0.5.2.1] diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0bda363228b1c0..28c035fcff8672 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,19 +1,35 @@ -# Before submitting +## What does this PR do? -- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) -- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), Pull Request section? -- [ ] Did you make sure to update the docs? -- [ ] Did you write any new necessary tests? -- [ ] If you made a notable change (that affects users), did you update the [CHANGELOG](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/CHANGELOG.md)? + +If we didn't discuss your PR in Github issues there's a high chance it will not be merged. -## What does this PR do? -Fixes # (issue). +The following links the related issue to the PR (https://docs.github.com/en/free-pro-team@latest/github/managing-your-work-on-github/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) +--> +Fixes # -## PR review -Anyone in the community is free to review the PR once the tests have passed. -If we didn't discuss your PR in Github issues there's a high chance it will not be merged. +## Before submitting +- [ ] Was this discussed/approved via a GitHub issue? (not for typos and docs) +- [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md), **Pull Request** section? +- [ ] Did you make sure your PR does only one thing, instead of bundling different changes together? +- [ ] Did you make sure to update the documentation with your changes? (if necessary) +- [ ] Did you write any new necessary tests? (not for typos and docs) +- [ ] Did you verify new and existing tests pass locally with your changes? +- [ ] Did you update the [CHANGELOG](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/CHANGELOG.md)? (not for typos, docs, test updates, or internal minor changes/refactorings) + + + +## PR review +Anyone in the community is free to review the PR once the tests have passed. +Before you start reviewing make sure you have read [Review guidelines](https://github.com/PyTorchLightning/pytorch-lightning/wiki/Review-guidelines). In short, see the following bullet-list: + + - [ ] Is this pull request ready for review? (if not, please submit in draft mode) + - [ ] Check that all items from **Before submitting** are resolved + - [ ] Make sure the title is self-explanatory and the description concisely explains the PR + - [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified ## Did you have fun? Make sure you had fun coding 🙃 diff --git a/.github/mergify.yml b/.github/mergify.yml new file mode 100644 index 00000000000000..8b321378a98eef --- /dev/null +++ b/.github/mergify.yml @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pull_request_rules: + + - name: warn on conflicts + conditions: + - conflict + - -draft # filter-out GH draft PRs + - -label="has conflicts" + actions: + # comment: + # message: This pull request is now in conflict... :( + label: + add: [ "has conflicts" ] + + - name: resolved conflicts + conditions: + - -conflict + - label="has conflicts" + - -draft # filter-out GH draft PRs + - -merged # not merged yet + - -closed + actions: + label: + remove: [ "has conflicts" ] + + #- name: update PR + # conditions: + # - -conflict + # - -draft # filter-out GH draft PRs + # - base=master # apply only on master + # - -title~=(?i)wip # skip all PR that title contains “WIP” (ignoring case) + # - "#approved-reviews-by>=3" # number of review approvals + # actions: + # update: {} + + - name: add core reviewer + conditions: + - -conflict # skip if conflict + - -draft # filter-out GH draft PRs + - label="0:] Ready-To-Go" + - "#approved-reviews-by<3" # number of review approvals + - "#review-requested<3" # number of requested reviews + actions: + request_reviews: + teams: + - "@PyTorchLightning/core-contributors" diff --git a/.github/prepare-nightly_version.py b/.github/prepare-nightly_version.py new file mode 100644 index 00000000000000..61710e73d69fd3 --- /dev/null +++ b/.github/prepare-nightly_version.py @@ -0,0 +1,18 @@ +import datetime +import os +import re + +# set paths +_PATH_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) +_PATH_INIT = os.path.join(_PATH_ROOT, 'pytorch_lightning', '__init__.py') + +# get today date +now = datetime.datetime.now() +now_date = now.strftime("%Y%m%d") + +print(f"prepare init '{_PATH_INIT}' - replace version by {now_date}") +with open(_PATH_INIT, 'r') as fp: + init = fp.read() +init = re.sub(r'__version__ = [\d\.\w\'"]+', f'__version__ = "{now_date}"', init) +with open(_PATH_INIT, 'w') as fp: + fp.write(init) diff --git a/.github/stale.yml b/.github/stale.yml index 365edce362c829..84049394d3aab5 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,19 +1,49 @@ # https://github.com/marketplace/stale +# https://github.com/probot/stale + +issues: + # Number of days of inactivity before an issue becomes stale + daysUntilStale: 30 + # Number of days of inactivity before a stale issue is closed + daysUntilClose: 7 + # Issues with these labels will never be considered stale + exemptLabels: + - Important + - Priority + # Comment to post when marking an issue as stale. Set to `false` to disable + markComment: > + This issue has been automatically marked as stale because it hasn't had any recent activity. + This issue will be closed in 7 days if no further activity occurs. + Thank you for your contributions, Pytorch Lightning Team! + # Comment to post when closing a stale issue. Set to `false` to disable + closeComment: false + +pulls: + # Number of days of inactivity before an pulls becomes stale + daysUntilStale: 14 + # Number of days of inactivity before a stale pull is closed + daysUntilClose: 5 + # Label to use when marking an issue as stale + staleLabel: "won't fix" + # Comment to post when marking an issue as stale. Set to `false` to disable + markComment: > + This pull request has been automatically marked as stale because it has not had recent activity. + It will be closed in 7 days if no further activity occurs. If you need further help see our docs: + https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request + or ask the assistance of a core contributor here or on Slack. + Thank you for your contributions. + # Comment to post when closing a stale issue. Set to `false` to disable + closeComment: > + This pull request is going to be closed. Please feel free to reopen it create a new from the actual master. -# Number of days of inactivity before an issue becomes stale -daysUntilStale: 60 -# Number of days of inactivity before a stale issue is closed -daysUntilClose: 9 -# Issues with these labels will never be considered stale -exemptLabels: - - pinned - - security # Label to use when marking an issue as stale -staleLabel: wontfix -# Comment to post when marking an issue as stale. Set to `false` to disable -markComment: > - This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. -# Comment to post when closing a stale issue. Set to `false` to disable -closeComment: false +staleLabel: "won't fix" +# Limit the number of actions per hour, from 1-30. Default is 30 +limitPerRun: 10 + +# Set to true to ignore issues in a project (defaults to false) +exemptProjects: true +# Set to true to ignore issues in a milestone (defaults to false) +exemptMilestones: true +# Set to true to ignore issues with an assignee (defaults to false) +exemptAssignees: true diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml deleted file mode 100644 index ac24dcee0a1e1b..00000000000000 --- a/.github/workflows/ci-testing.yml +++ /dev/null @@ -1,140 +0,0 @@ -name: CI testing - -# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows -on: - # Trigger the workflow on push or pull request, - # but only for the master branch - push: - branches: - - master - pull_request: - branches: - - master - -jobs: - build: - - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - # max-parallel: 6 - matrix: - os: [ubuntu-18.04, windows-2019, macOS-10.15] - python-version: [3.6, 3.7, 3.8] - requires: ['minimal', 'latest'] - exclude: - # excludes node 4 on macOS - - python-version: 3.8 - requires: 'minimal' - - # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 15 - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 - - name: Setup macOS - if: runner.os == 'macOS' - run: | - brew install libomp # https://github.com/pytorch/pytorch/issues/20030 - brew install openmpi # Horovod on macOS requires OpenMPI, Gloo not currently supported - - - name: Setup Windows - if: runner.os == 'windows' - run: | - python -c "lines = [line for line in open('requirements-extra.txt').readlines() if not line.startswith('horovod')] ; open('requirements-extra.txt', 'w').writelines(lines)" - - # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved - - name: Setup Windows on Latest - if: runner.os == 'windows' && matrix.requires == 'latest' - run: | - python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)" - - - name: Set min. dependencies - if: matrix.requires == 'minimal' - run: | - python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)" - python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)" - - # Note: This uses an internal pip API and may not always work - # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - - name: Get pip cache - id: pip-cache - run: | - python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" - - - name: Cache pip - uses: actions/cache@v1 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-extra.txt') }} - restore-keys: | - ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip- - - - name: Install dependencies - run: | - # python -m pip install --upgrade --user pip - pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q - HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install -r ./tests/requirements-devel.txt -q - # pip install tox coverage - python --version - pip --version - pip list - shell: bash - - - name: Reinstall Horovod if necessary - if: runner.os != 'windows' && matrix.python-version != '3.8' - run: | - HOROVOD_BUILT=$(python -c "import horovod.torch; horovod.torch.nccl_built(); print('SUCCESS')") - if [[ $HOROVOD_BUILT != "SUCCESS" ]]; then - pip uninstall -y horovod - HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install --no-cache-dir $(grep "horovod" requirements-extra.txt) - fi - horovodrun --check-build - shell: bash - - - name: Cache datasets - uses: actions/cache@v1 - with: - path: tests/Datasets # This path is specific to Ubuntu - # Look to see if there is a cache hit for the corresponding requirements file - key: mnist-dataset - - - name: Tests - # env: - # TOXENV: py${{ matrix.python-version }} - run: | - # tox --sitepackages - # flake8 . - coverage run --source pytorch_lightning -m py.test pytorch_lightning tests -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - coverage report - - - name: Upload pytest test results - uses: actions/upload-artifact@master - with: - name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} - path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() - - - name: Package Setup - run: | - check-manifest - python setup.py check --metadata --strict - python setup.py sdist - twine check dist/* - #- name: Try install package - # if: ! startsWith(matrix.os, 'windows') - # run: | - # virtualenv vEnv ; source vEnv/bin/activate - # pip install --editable . ; cd .. & python -c "import pytorch_lightning ; print(pytorch_lightning.__version__)" - # deactivate ; rm -rf vEnv - - - name: Statistics - if: success() - run: | - coverage report \ No newline at end of file diff --git a/.github/workflows/ci_dockers.yml b/.github/workflows/ci_dockers.yml new file mode 100644 index 00000000000000..a22be81b398fb7 --- /dev/null +++ b/.github/workflows/ci_dockers.yml @@ -0,0 +1,149 @@ +name: CI build Docker +# https://www.docker.com/blog/first-docker-github-action-is-here +# https://github.com/docker/build-push-action +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] # include release branches like release/1.0.x + pull_request: + branches: [master, "release/*"] + paths: + - "dockers/**" + - "!dockers/README.md" + - "requirements/*.txt" + - "environment.yml" + - "requirements.txt" + - ".github/workflows/*docker*.yml" + - ".github/workflows/events-nightly.yml" + - "setup.py" + +jobs: + build-PL: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python_version: [3.6] + pytorch_version: [1.4, 1.7] + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Build PL Docker + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + file: dockers/release/Dockerfile + push: false + timeout-minutes: 50 + + build-XLA: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python_version: [3.7] + xla_version: [1.6, 1.7, "nightly"] + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Build XLA Docker + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + XLA_VERSION=${{ matrix.xla_version }} + file: dockers/base-xla/Dockerfile + push: false + timeout-minutes: 50 + + build-cuda: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + include: + # todo: see notes in Dockerfile + - python_version: 3.6 + pytorch_version: 1.4 + - python_version: 3.7 + pytorch_version: 1.6 + - python_version: 3.8 + pytorch_version: 1.7 + # - python_version: 3.9 + # pytorch_version: 1.7 + steps: + - name: Checkout + uses: actions/checkout@v2 + + # for PT 1.4 we need to use CUDA 10.1 + - run: | + cuda=$(python -c "print(10.2 if float(${{matrix.pytorch_version}}) >= 1.5 else 10.1)" 2>&1) + echo "::set-output name=CUDA::$cuda" + id: extend + + - name: Build CUDA Docker + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ steps.extend.outputs.CUDA }} + file: dockers/base-cuda/Dockerfile + push: false + timeout-minutes: 50 + + build-conda: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + include: + - python_version: 3.6 + pytorch_version: 1.4 + - python_version: 3.7 + pytorch_version: 1.7 + - python_version: 3.8 + pytorch_version: 1.8 + # - python_version: 3.9 + # pytorch_version: 1.8 + steps: + - name: Checkout + uses: actions/checkout@v2 + + # for PT 1.3 and 1.4 we need to use CUDA 10.1 + - run: | + cuda=$(python -c "print(10.2 if float(${{matrix.pytorch_version}}) > 1.4 else 10.1)" 2>&1) + echo "::set-output name=CUDA::$cuda" + id: extend + + - name: Build CUDA Docker + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ steps.extend.outputs.CUDA }} + file: dockers/base-conda/Dockerfile + push: false + timeout-minutes: 50 + + build-nvidia: + runs-on: ubuntu-20.04 + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Build NVIDIA Docker + uses: docker/build-push-action@v2 + with: + file: dockers/nvidia/Dockerfile + push: false + timeout-minutes: 50 diff --git a/.github/workflows/ci_pkg-install.yml b/.github/workflows/ci_pkg-install.yml new file mode 100644 index 00000000000000..b4557e5ed75aa6 --- /dev/null +++ b/.github/workflows/ci_pkg-install.yml @@ -0,0 +1,63 @@ +name: Install pkg + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + +jobs: + + pkg-install: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + # max-parallel: 6 + matrix: + # PyTorch 1.5 is failing on Win and bolts requires torchvision>=0.5 + os: [ubuntu-20.04, macOS-10.15 , windows-2019] # + python-version: [3.6, 3.9] + + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Prepare env + run: | + pip install check-manifest "twine==3.2" setuptools wheel + + - name: Create package + run: | + check-manifest + # python setup.py check --metadata --strict + python setup.py sdist bdist_wheel + + - name: Check package + run: | + twine check dist/* + python setup.py clean + + - name: Setup Windows + if: runner.os == 'windows' + run: | + # this is just a hotfix because of Win cannot install it directly + pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + + - name: Install | Uninstall package - archive + run: | + # install as archive + pip install dist/*.tar.gz + cd .. + python -c "import pytorch_lightning as pl ; print(pl.__version__)" + pip uninstall -y pytorch-lightning + + - name: Install | Uninstall package - wheel + run: | + # install as wheel + pip install dist/*.whl + cd .. + python -c "import pytorch_lightning as pl ; print(pl.__version__)" + pip uninstall -y pytorch-lightning \ No newline at end of file diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml new file mode 100644 index 00000000000000..cf92c64f1feece --- /dev/null +++ b/.github/workflows/ci_test-base.yml @@ -0,0 +1,96 @@ +name: CI basic testing + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + +jobs: + doctest: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + # max-parallel: 6 + matrix: + os: [ubuntu-20.04, windows-2019, macOS-10.15] + python-version: [3.8] + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 20 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 + - name: Setup macOS + if: runner.os == 'macOS' + run: | + brew install libomp # https://github.com/pytorch/pytorch/issues/20030 + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Get pip cache + id: pip-cache + run: | + python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" + + - name: Cache pip + uses: actions/cache@v2 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}- + + - name: Install dependencies + run: | + python -m pip install --upgrade --user pip + pip install --requirement ./requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + pip install "pytest>6.0" "pytest-cov>2.10" --upgrade-strategy only-if-needed + python --version + pip --version + pip list + shell: bash + + - name: Cache datasets + uses: actions/cache@v2 + with: + path: Datasets # This path is specific to Ubuntu + # Look to see if there is a cache hit for the corresponding requirements file + key: PL-dataset + + - name: Test Package [only] + run: | + # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 + coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + + - name: Upload pytest test results + uses: actions/upload-artifact@v2 + with: + name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} + path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + if: failure() + + - name: Statistics + if: success() + run: | + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + if: always() + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: cpu,pytest + name: Base-coverage + fail_ci_if_error: false diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml new file mode 100644 index 00000000000000..da853bf623d1bf --- /dev/null +++ b/.github/workflows/ci_test-conda.yml @@ -0,0 +1,73 @@ +name: PyTorch & Conda + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [ master, "release/*" ] + pull_request: + branches: [master, "release/*"] + +jobs: + conda: + runs-on: ubuntu-20.04 + container: pytorchlightning/pytorch_lightning:base-conda-py${{ matrix.python-version }}-torch${{ matrix.pytorch-version }} + strategy: + fail-fast: false + matrix: + # os: [ubuntu-20.04] + python-version: [3.7] + pytorch-version: [1.4, 1.5, 1.6, 1.7, 1.8] + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 35 + steps: + - uses: actions/checkout@v2 + + - name: Update dependencies + run: | + conda info + conda list + # adjust versions according installed Torch version + python ./requirements/adjust_versions.py requirements/extra.txt + python ./requirements/adjust_versions.py requirements/examples.txt + pip install --requirement requirements/devel.txt --upgrade-strategy only-if-needed + pip list + + - name: Pull checkpoints from S3 + run: | + # enter legacy and update checkpoints from S3 + cd legacy + curl https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip --output checkpoints.zip + unzip -o checkpoints.zip + ls -l checkpoints/ + + - name: Tests + run: | + # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + shell: bash -l {0} + + - name: Upload pytest results + uses: actions/upload-artifact@v2 + with: + name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} + path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + if: failure() + + - name: Statistics + if: success() + run: | + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + if: always() + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: cpu,pytest,torch${{ matrix.pytorch-version }} + name: CPU-coverage + fail_ci_if_error: false diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml new file mode 100644 index 00000000000000..ec9e71c5b83b29 --- /dev/null +++ b/.github/workflows/ci_test-full.yml @@ -0,0 +1,191 @@ +name: CI complete testing + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + +jobs: + pytest: + + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-18.04, windows-2019, macOS-10.15] + python-version: [3.6, 3.7, 3.8, 3.9] + requires: ['minimal', 'latest'] + exclude: + - python-version: 3.9 + requires: 'minimal' + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + # TODO: the macOS is taking too long, probably caching did not work... + timeout-minutes: 40 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Update pip + run: | + # todo: unfreeze PIP after resolving minimal dependencies + pip install --quiet "pip==20.1" --upgrade --user # needed for get pip cacher folder + + # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 + - name: Setup macOS + if: runner.os == 'macOS' + run: | + brew install libomp # https://github.com/pytorch/pytorch/issues/20030 + brew install openmpi libuv # Horovod on macOS requires OpenMPI, Gloo not currently supported + + - name: Setup Windows + if: runner.os == 'windows' + run: | + # remove Horovod from requirements + fname = 'requirements/extra.txt' + lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] + open(fname, 'w').writelines(lines) + shell: python + + # todo: re-enable when allow testing py 3.9 with min config, atm some Hydra issues + #- name: Adjust minimal for Python 3.9 + # if: matrix.requires == 'minimal' && matrix.python-version == 3.9 + # run: | + # import re + # def _req(fname, ptn, ver): + # req = re.sub(ptn, ver, open(fname).read()) + # open(fname, 'w').write(req) + # + # _req('requirements.txt', r'torch>=[\d\.]+', 'torch>=1.8.0') + # _req('requirements/extra.txt', r'onnxruntime>=[\d\.]+', 'onnxruntime>=1.7.0') + # shell: python + + - name: Set min. dependencies + if: matrix.requires == 'minimal' + run: | + files = ( + 'requirements.txt', + 'requirements/extra.txt', + 'requirements/loggers.txt', + 'requirements/test.txt', + 'requirements/examples.txt', + ) + for fname in files: + req = open(fname).read().replace('>=', '==') + open(fname, 'w').write(req) + + # remove Fairscale from requirements + fname = 'requirements/extra.txt' + lines = [line for line in open(fname).readlines() if 'fairscale' not in line] + open(fname, 'w').writelines(lines) + shell: python + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Get pip cache dir + id: pip-cache + run: | + echo "::set-output name=dir::$(pip cache dir)" + + - name: pip cache + uses: actions/cache@v2 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements/extra.txt') }} + restore-keys: | + ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}- + + - name: Pull checkpoints from S3 + run: | + cd legacy + # wget is simpler but does not work on Windows + python -c "from urllib.request import urlretrieve ; urlretrieve('https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip', 'checkpoints.zip')" + ls -l . + unzip -o checkpoints.zip + ls -l checkpoints/ + + # todo: re-enable testing with Horovod + - name: py3.9 - temp skip Horovod + if: matrix.python-version == 3.9 + run: | + # pip uninstall -y horovod + python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" + + - name: Install dependencies + env: + # MAKEFLAGS: "-j2" + HOROVOD_BUILD_ARCH_FLAGS: "-mfma" + HOROVOD_WITHOUT_MXNET: 1 + HOROVOD_WITHOUT_TENSORFLOW: 1 + run: | + python --version + pip --version + # python -m pip install --upgrade --user pip + pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + # adjust versions according installed Torch version + python ./requirements/adjust_versions.py requirements/extra.txt + python ./requirements/adjust_versions.py requirements/examples.txt + pip install --requirement ./requirements/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + pip list + shell: bash + + - name: Reinstall Horovod if necessary + # todo: re-enable horovod on py3.9 when it will be supported + if: runner.os != 'windows' && matrix.python-version != 3.9 + env: + HOROVOD_BUILD_ARCH_FLAGS: "-mfma" + run: | + HOROVOD_BUILT=$(python -c "import horovod.torch; horovod.torch.nccl_built(); print('SUCCESS')" || true) + if [[ $HOROVOD_BUILT != "SUCCESS" ]]; then + pip uninstall -y horovod + echo $(grep "horovod" requirements/extra.txt) > requirements/horovod.txt + pip install --no-cache-dir -r requirements/horovod.txt + fi + horovodrun --check-build + shell: bash + + - name: Cache datasets + uses: actions/cache@v2 + with: + path: Datasets + key: pl-dataset + + - name: Tests + run: | + # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml + + - name: Examples + run: | + python -m pytest pl_examples -v --durations=10 + + - name: Upload pytest results + uses: actions/upload-artifact@v2 + with: + name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} + path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + if: failure() + + - name: Statistics + if: success() + run: | + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + if: always() + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: cpu,pytest,python${{ matrix.python-version }} + name: CPU-coverage + fail_ci_if_error: false diff --git a/.github/workflows/ci_test-mnodes.yml b/.github/workflows/ci_test-mnodes.yml new file mode 100644 index 00000000000000..a0525834131106 --- /dev/null +++ b/.github/workflows/ci_test-mnodes.yml @@ -0,0 +1,209 @@ +name: Multi Nodes GPU Tests + +# Workflow Steps: +# 1. Checkout Pytorch Lightning +# 2. Set up Python +# 3. Configure AWS Credentials +# 4. Install AWS Client +# 5. Get Current Sha Commit +# 6. Create Job Name +# 7. Update Test Configuration File +# 8. Install EKSClient +# 9. Create Gpu Node Pool +# 10. Check Current Node Pool | Current Elatic Pods +# 11. Apply Elastic +# 12. Wait 5 sec +# 13. Find ETCD TCP Address +# 14. Update Test Configuration File +# 15. Apply Multi Node Testing +# 16. Wait 120 secs +# 17. Listen to Jobs Logging +# 18. Statistics +# 19. Upload coverage results +# 20. Upload coverage to Codecov +# 21. Delete Group Node + +on: + push: + branches: + - never-ever-run- + #pull_request: + # types: [closed] + +env: + AWS_CLUSTER: pl-lightning-torchelastic + NODE_TYPE: g4dn.xlarge + NODES: 2 + NUM_GPUS: 1 + REGION: us-east-2 + MAX_CHECKS: 300 + CHECK_SPEEP: 2 + +jobs: + multi-nodes-gpu-testing: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python-version: [3.7] + pytorch-version: [1.5] + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 50 + + # runs only when merged happened. + # if: github.event.pull_request.merged == true + steps: + + - name: Checkout Pytorch Lightning + uses: actions/checkout@v2 + with: + repository: PyTorchLightning/pytorch-lightning + ref: ${{ github.event.base_ref }} + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Cache pip + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-multi-node + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + pip install awscli coverage + # todo + pip install git+https://${{ secrets.PL_GHOST_TOKEN }}@github.com/PyTorchLightning/lightning-dtrun.git@v0.0.3 -q --no-cache-dir + #pip install git+https://${{ secrets.PL_GHOST_TOKEN }}@github.com/PyTorchLightning/lightning-dtrun.git@mnodes -q --no-cache-dir + + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_KEY_ID }} + aws-region: us-east-2 + + - name: Get Current Sha Commit + id: vars + shell: bash + run: | + echo "::set-output name=SHA::$(git rev-parse --short HEAD)" + echo $PWD + + - name: Create Job Name + id: job + shell: bash + run: | + echo "::set-output name=ID::$(echo '${{ steps.vars.outputs.SHA }}-${{ matrix.python-version }}-${{ matrix.pytorch-version }}' | tr . - )" + echo "::set-output name=ID_NAME::$(echo 's-${{ steps.vars.outputs.SHA }}-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-e' | tr . - )" + + - name: Install EKSClient + run: | + curl --silent --location "https://github.com/weaveworks/eksctl/releases/latest/download/eksctl_$(uname -s)_amd64.tar.gz" | tar xz -C /tmp + sudo mv /tmp/eksctl /usr/local/bin + shell: bash + + - name: Create Gpu Node Pool + run: | + aws eks --region $REGION update-kubeconfig --name $AWS_CLUSTER + eksctl create nodegroup --name=${{ steps.job.outputs.ID }} --cluster=$AWS_CLUSTER --node-type=$NODE_TYPE --nodes=$NODES + # eksctl create nodegroup --name=${{ steps.job.outputs.ID }} --cluster=$AWS_CLUSTER --managed --spot --node-type=$NODE_TYPE --nodes=$NODES + shell: bash + + - name: Check Current Node Pool | Current Elatic Pods + run: | + eksctl get nodegroups --cluster $AWS_CLUSTER + kubectl get pods -n elastic-job + + - name: Apply Elastic + run: | + git clone https://github.com/pytorch/elastic.git + cd elastic/kubernetes + + kubectl apply -k config/default + + kubectl apply -f https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/master/nvidia-device-plugin.yml + kubectl apply -f https://raw.githubusercontent.com/pytorch/elastic/master/kubernetes/config/samples/etcd.yaml + + - name: Wait + # todo: this shall be dynamic + if: always() + shell: bash + run: | + sleep 5 + + - name: Find ETCD TCP Address + id: tcp + shell: bash + run: | + echo "::set-output name=TCP_ADDRESS::$(kubectl logs etcd -n elastic-job | grep -Eo '[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}:[0-9]{1,4}' | head -1)" + + - name: Update Test Config. File + run: | + import os + from dtrun.configs import prepare_multi_nodes_gpu_config + + assert os.path.isfile('./tests/mnode_tests.txt') + prepare_multi_nodes_gpu_config( + './.github/multi-nodes-gpu.yaml', + './tests/mnode_tests.txt', + sha="${{ steps.vars.outputs.SHA }}", + tcp_address="${{ steps.tcp.outputs.TCP_ADDRESS }}", + python_version="${{ matrix.python-version }}", + torch_version="${{ matrix.pytorch-version }}", + num_gpus=1, + ) + shell: python + + - name: Apply Multi Node Testing + run: | + # cat ./.github/multi-nodes-gpu.yaml + kubectl apply -f ./.github/multi-nodes-gpu.yaml + shell: bash + + - name: Wait + # todo: this shall be dynamic + if: always() + shell: bash + run: | + sleep 400 + + - name: Listen to Jobs Logging + shell: bash + run: | + # todo: Enable automatic checking. + # while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl logs ${{ steps.job.outputs.ID_NAME }}-worker-0 -n elastic-job | grep -i "error\|failed"; then status_code=1 && break; elif kubectl logs ${{ steps.job.outputs.ID }}-worker-0 -n elastic-job | grep "TEST END"; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \ + # echo "Done waiting. Job status code: $status_code" && \ + kubectl logs ${{ steps.job.outputs.ID_NAME }}-worker-0 -n elastic-job > /tmp/full_output.txt + if grep -q 'END_TOKEN' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/END_TOKEN/'; else mv /tmp/full_output.txt xx00; fi && \ + cat xx00 + + - name: Statistics + if: success() + run: | + cat ./xx01 | tail -n +2 | base64 --decode > /home/runner/work/pytorch-lightning/pytorch-lightning/.coverage + cd /home/runner/work/pytorch-lightning/pytorch-lightning && coverage report && coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + if: always() + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: multi-nodes,pytest + name: multi-nodes-coverage + fail_ci_if_error: false + + - name: Delete Group Node + if: always() + run: | + kubectl delete ElasticJob ${{ steps.job.outputs.ID_NAME }} -n elastic-job + eksctl delete nodegroup ${{ steps.job.outputs.ID }} --cluster=$AWS_CLUSTER diff --git a/.github/workflows/ci_test-tpu.yml b/.github/workflows/ci_test-tpu.yml new file mode 100644 index 00000000000000..62814b2f26f813 --- /dev/null +++ b/.github/workflows/ci_test-tpu.yml @@ -0,0 +1,144 @@ +name: TPU tests + +on: + push: + branches: [master, "release/*"] +# TODO: temporal disable TPU testing until we find way how to pass credentials to forked PRs +# pull_request: +# branches: +# - master + +env: + GKE_CLUSTER: lightning-cluster + GKE_ZONE: us-central1-a + IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image + MAX_CHECKS: 360 + CHECK_SPEEP: 5 + +jobs: + setup-build-publish-deploy: + name: tpu-testing-job + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7] + xla-version: [1.6, 1.7] + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 50 + + steps: + - name: Set IMAGETAG + run: echo "IMAGETAG=$(date +%s)_${{ matrix.python-version }}" >> $GITHUB_ENV + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: 1.14.x + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Checkout Pytorch Lightning + uses: actions/checkout@v2 + with: + repository: PyTorchLightning/pytorch-lightning + ref: ${{ github.event.pull_request.head.sha }} + + - name: Checkout ml-testing-accelerators + uses: actions/checkout@v2 + with: + repository: GoogleCloudPlatform/ml-testing-accelerators + path: ml-testing-accelerators + ref: 5e88ac24f631c27045e62f0e8d5dfcf34e425e25 + + - name: Setup gcloud CLI + uses: GoogleCloudPlatform/github-actions/setup-gcloud@master + with: + version: '290.0.1' + service_account_key: ${{ secrets.GKE_SA_KEY_BASE64 }} + project_id: ${{ secrets.GKE_PROJECT }} + export_default_credentials: true + + # Configure Docker to use the gcloud command-line tool as a credential helper for authentication. + - name: Configure Docker + run: |- + gcloud --quiet auth configure-docker + shell: bash + - name: Build and Push Docker Image + env: + PYTHON_VER: ${{ matrix.python-version }} + XLA_VER: ${{ matrix.xla-version }} + run: | + #cd dockers/tpu-tests + docker build --tag "$IMAGE:$IMAGETAG" -f ./dockers/tpu-tests/Dockerfile --build-arg "PYTHON_VERSION=$PYTHON_VER" --build-arg "PYTORCH_VERSION=$XLA_VER" . + docker push "$IMAGE:$IMAGETAG" + shell: bash + + - name: Install jsonnet + run: |- + go get github.com/google/go-jsonnet/cmd/jsonnet + shell: bash + # Get the GKE credentials so we can deploy to the cluster + # Use either zone or region depending on cluster setup. + - run: |- + gcloud container clusters get-credentials "$GKE_CLUSTER" --zone "$GKE_ZONE" + shell: bash + + - name: Deploy the job on the kubernetes cluster + env: + XLA_VER: ${{ matrix.xla-version }} + run: |- + python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; ttt = open(fname).read().replace('pytorch-VERSION', 'pytorch-$XLA_VER') ; open(fname, 'w').write(ttt)" + job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet --ext-str image=$IMAGE --ext-str image-tag=$IMAGETAG | kubectl create -f -) && \ + job_name=${job_name#job.batch/} && \ + job_name=${job_name% created} && \ + echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \ + i=0 && \ + # 60 checks spaced 30s apart = 900s total. + status_code=2 && \ + # Check on the job periodically. Set the status code depending on what + # happened to the job in Kubernetes. If we try MAX_CHECKS times and + # still the job hasn't finished, give up and return the starting + # non-zero status code. + printf "Waiting for job to finish: " && \ + while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \ + echo "Done waiting. Job status code: $status_code" && \ + pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \ + echo "GKE pod name: $pod_name" && \ + kubectl logs -f $pod_name --container=train > /tmp/full_output.txt + if grep -q '' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '//'; else mv /tmp/full_output.txt xx00; fi && \ + # First portion is the test logs. Print these to Github Action stdout. + cat xx00 && \ + echo "Done with log retrieval attempt." && \ + gcloud container images delete "$IMAGE:$IMAGETAG" --force-delete-tags && \ + echo "Status code: $status_code" + exit $status_code + shell: bash + + - name: Statistics + if: success() + run: | + mv ./xx01 coverage + # TODO: add human readable report + cat coverage + # sudo pip install pycobertura + # pycobertura show coverage.xml + + - name: Upload coverage results + uses: actions/upload-artifact@v2 + with: + name: coverage-TPU + path: coverage + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + if: always() + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage + flags: tpu,pytest + name: TPU-coverage + fail_ci_if_error: true diff --git a/.github/workflows/code-formatting.yml b/.github/workflows/code-formatting.yml new file mode 100644 index 00000000000000..bc03905ab2bbd8 --- /dev/null +++ b/.github/workflows/code-formatting.yml @@ -0,0 +1,76 @@ +name: "Check Code Format" + +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + +jobs: + #imports-check-isort: + # name: Check valid import formatting with isort + # runs-on: ubuntu-20.04 + # steps: + # - name: Checkout + # uses: actions/checkout@v2 + # - name: Set up Python 3.8 + # uses: actions/setup-python@v2 + # with: + # python-version: 3.8 + # - name: Install isort + # run: pip install isort==5.6.4 + # - name: Run isort + # run: isort --settings-path=./pyproject.toml --check-only --diff . + + #format-check-yapf: + # runs-on: ubuntu-20.04 + # steps: + # - uses: actions/checkout@master + # - uses: actions/setup-python@v2 + # with: + # python-version: 3.8 + # - name: Install dependencies + # run: | + # pip install --upgrade pip + # pip install yapf + # pip list + # shell: bash + # - name: yapf + # run: yapf --diff --parallel --recursive . + + python-pep8: + name: Python formatting PEP8 + runs-on: ubuntu-20.04 + + # Timeout: https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Install dependencies + run: | + pip install flake8 + + - name: Run checking + run: | + flake8 . + + python-typing-mypy: + name: Python typing check [mypy] + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@master + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install mypy + run: | + pip install mypy==0.790 + pip list + - name: mypy check + run: | + mypy diff --git a/.github/workflows/docker_builds.yml b/.github/workflows/docker_builds.yml deleted file mode 100644 index 736ff72460d749..00000000000000 --- a/.github/workflows/docker_builds.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: Publish Docker -on: - push: - branches: - - master - release: - types: - - created - -jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python_version: [3.6, 3.7, 3.8] - pytorch_version: [1.1, 1.2, 1.3, 1.4, 1.5] - steps: - - name: Extract Current Tag - if: contains(github.ref, 'refs/tags/') - id: get_version - run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//} - - uses: actions/checkout@v2 - - name: Publish Releases to Docker - # only on releases - uses: elgohr/Publish-Docker-Github-Action@2.14 - if: contains(github.ref, 'refs/tags/') && !contains(${{ steps.get_version.outputs.VERSION }}, 'rc') %% !contains(${{ steps.get_version.outputs.VERSION }}, 'dev') - with: - name: pytorchlightning/pytorch_lightning - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - dockerfile: docker/Dockerfile - buildargs: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.VERSION }} - tags: "${{ steps.get_version.outputs.VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},stable-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" - - name: Publish Master - # publish master - uses: elgohr/Publish-Docker-Github-Action@2.14 - if: github.event_name == 'push' - with: - name: pytorchlightning/pytorch_lightning - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - dockerfile: docker/Dockerfile - buildargs: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.VERSION }} - tags: "nightly-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml deleted file mode 100644 index cd49a01e8e569e..00000000000000 --- a/.github/workflows/docs-check.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: "Docs check" -# https://github.com/marketplace/actions/sphinx-build - -on: -- pull_request - -jobs: - docs: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: ammaraskar/sphinx-action@master - with: - # git is requried to clone the docs theme - pre-build-command: "apt-get update -y && apt-get install -y git" - docs-folder: "docs/" - repo-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml new file mode 100644 index 00000000000000..4488c598c8ac75 --- /dev/null +++ b/.github/workflows/docs-checks.yml @@ -0,0 +1,109 @@ +name: "Docs check" +# https://github.com/marketplace/actions/sphinx-build + +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + +jobs: + sphinx-check: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + - uses: ammaraskar/sphinx-action@master + with: + # git is required to clone the docs theme + # before custom requirement are resolved https://github.com/ammaraskar/sphinx-action/issues/16 + pre-build-command: "apt-get update -y && apt-get install -y git && pip install -r requirements/docs.txt" + docs-folder: "docs/" + repo-token: "${{ secrets.GITHUB_TOKEN }}" + + test-docs: + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Cache pip + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python --version + pip --version + # remove Horovod from requirements + python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" + # python -m pip install --upgrade --user pip + pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet + pip install --requirement requirements/extra.txt + pip install --requirement requirements/loggers.txt + pip install --requirement requirements/docs.txt + pip list + shell: bash + + - name: Test Documentation + env: + SPHINX_MOCK_REQUIREMENTS: 0 + run: | + # First run the same pipeline as Read-The-Docs + apt-get update && sudo apt-get install -y cmake + cd docs + make doctest + make coverage + + make-docs: + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Cache pip + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python --version + pip --version + # pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet + pip install --requirement requirements/docs.txt + # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux + sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures + pip list + shell: bash + + - name: Make Documentation + run: | + # First run the same pipeline as Read-The-Docs + cd docs + make clean + make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" + + - name: Upload built docs + uses: actions/upload-artifact@v2 + with: + name: docs-results-${{ github.sha }} + path: docs/build/html/ + # Use always() to always run this step to publish test results when there are test failures + if: success() diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml new file mode 100644 index 00000000000000..5ad4396a006f7a --- /dev/null +++ b/.github/workflows/events-nightly.yml @@ -0,0 +1,151 @@ +name: Nightly events + +# https://jasonet.co/posts/scheduled-actions/ +# https://github.community/t/distinct-job-for-each-schedule/17811/2 +on: + schedule: + - cron: "0 0 * * *" # At the end of every day + +# based on https://github.com/pypa/gh-action-pypi-publish +jobs: + pypi-release: + runs-on: ubuntu-20.04 + + steps: + # does nightly releases from feature branch + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Install dependencies + run: >- + python -m pip install --user --upgrade setuptools wheel + + - name: Build packages + run: | + python .github/prepare-nightly_version.py + python setup.py sdist bdist_wheel + ls -lh dist/ + + - name: Delay releasing + uses: juliangruber/sleep-action@v1 + with: + time: 5m + + # We do this, since failures on test.pypi aren't that bad + - name: Publish to Test PyPI + uses: pypa/gh-action-pypi-publish@v1.4.1 + with: + user: __token__ + password: ${{ secrets.test_pypi_password }} + repository_url: https://test.pypi.org/legacy/ + verbose: true + + docker-XLA: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python_version: [3.6, 3.7] + xla_version: [1.6, 1.7] # todo: , "nightly" + steps: + - name: Checkout + uses: actions/checkout@v2 + + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 + - name: Login to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Publish XLA to Docker Hub + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + XLA_VERSION=${{ matrix.xla_version }} + file: dockers/base-xla/Dockerfile + push: true + tags: pytorchlightning/pytorch_lightning:base-xla-py${{ matrix.python_version }}-torch${{ matrix.xla_version }} + timeout-minutes: 55 + + docker-cuda-conda: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python_version: [3.6, 3.7, 3.8] + pytorch_version: [1.4, 1.5, 1.6, 1.7, 1.8] + + steps: + - name: Checkout + uses: actions/checkout@v2 + + # https://github.com/docker/setup-buildx-action + # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command + - uses: docker/setup-buildx-action@v1 + - name: Login to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + # for PT 1.3 and 1.4 we need to use CUDA 10.1 + - run: | + cuda=$(python -c "print(10.2 if float(${{matrix.pytorch_version}}) > 1.4 else 10.1)" 2>&1) + echo "::set-output name=CUDA::$cuda" + id: extend + + - name: Publish CUDA to Docker Hub + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ steps.extend.outputs.CUDA }} + file: dockers/base-cuda/Dockerfile + push: true + tags: pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} + timeout-minutes: 55 + + - name: Publish Conda to Docker Hub + # publish master/release + uses: docker/build-push-action@v2 + with: + build-args: | + PYTHON_VERSION=${{ matrix.python_version }} + PYTORCH_VERSION=${{ matrix.pytorch_version }} + CUDA_VERSION=${{ steps.extend.outputs.CUDA }} + file: dockers/base-conda/Dockerfile + push: true + tags: pytorchlightning/pytorch_lightning:base-conda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }} + timeout-minutes: 55 + +# docker-nvidia: +# runs-on: ubuntu-20.04 +# steps: +# - name: Checkout +# uses: actions/checkout@v2 +# +# # https://github.com/docker/setup-buildx-action +# # Set up Docker Buildx - to use cache-from and cache-to argument of buildx command +# - uses: docker/setup-buildx-action@v1 +# - name: Login to DockerHub +# uses: docker/login-action@v1 +# with: +# username: ${{ secrets.DOCKER_USERNAME }} +# password: ${{ secrets.DOCKER_PASSWORD }} +# +# - name: Publish NVIDIA to Docker Hub +# uses: docker/build-push-action@v2 +# with: +# file: dockers/nvidia/Dockerfile +# push: true +# tags: nvcr.io/pytorchlightning/pytorch_lightning:nvidia +# timeout-minutes: 55 diff --git a/.github/workflows/events-recurrent.yml b/.github/workflows/events-recurrent.yml new file mode 100644 index 00000000000000..4d647ca0fb44d7 --- /dev/null +++ b/.github/workflows/events-recurrent.yml @@ -0,0 +1,42 @@ +name: Recurrent events + +# https://jasonet.co/posts/scheduled-actions/ +# https://github.community/t/distinct-job-for-each-schedule/17811/2 +on: + push: + branches: [ master ] + schedule: + - cron: "*/20 * * * *" # At every 20 minutes + +env: + GKE_CLUSTER: lightning-cluster + GKE_ZONE: us-central1-a + +jobs: + tpu-cleanup: + name: TPU cleaning + runs-on: ubuntu-20.04 + + steps: + - name: Setup gcloud CLI + uses: GoogleCloudPlatform/github-actions/setup-gcloud@master + with: + version: '290.0.1' + service_account_key: ${{ secrets.GKE_SA_KEY_BASE64 }} + project_id: ${{ secrets.GKE_PROJECT }} + export_default_credentials: true + # Get the GKE credentials so we can deploy to the cluster; Use either zone or region depending on cluster setup. + - run: |- + gcloud container clusters get-credentials "$GKE_CLUSTER" --zone "$GKE_ZONE" + shell: bash + + - name: Clean all mong hanging jobs + run: | + # Match jobs whose age matches patterns like '1h' or '1d', i.e. any job + # that has been around longer than 1hr. First print all columns for + # matches, then execute the delete. + jobs_to_delete=$(kubectl get job | awk 'match($4,/[0-9]+[dh]/) {print $0}') + echo $jobs_to_delete + if [ ${#jobs_to_delete} -gt 1 ]; + then kubectl delete job $(kubectl get job | awk 'match($4,/[0-9]+[dh]/) {print $1}'); + fi \ No newline at end of file diff --git a/.github/workflows/greetings.yml b/.github/workflows/greetings.yml deleted file mode 100644 index 0b4be6726f2d52..00000000000000 --- a/.github/workflows/greetings.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Greetings -# https://github.com/marketplace/actions/first-interaction - -on: [issues] # pull_request - -jobs: - greeting: - runs-on: ubuntu-latest - steps: - - uses: actions/first-interaction@v1 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - issue-message: 'Hi! thanks for your contribution!, great first issue!' - pr-message: 'Hey thanks for the input! Please give us a bit of time to review it!' diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml deleted file mode 100644 index 041f67bc05b383..00000000000000 --- a/.github/workflows/pypi-release.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: PyPI Release - -# https://help.github.com/en/actions/reference/events-that-trigger-workflows -on: - # Trigger the workflow on push or pull request, - # but only for the master branch - push: - branches: - - master - release: - types: - - created - -# based on https://github.com/pypa/gh-action-pypi-publish - -jobs: - build: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.7 - uses: actions/setup-python@v1 - with: - python-version: 3.7 - - - name: Install dependencies - run: >- - python -m pip install --user --upgrade setuptools wheel - - name: Build - run: >- - python setup.py sdist bdist_wheel - - # We do this, since failures on test.pypi aren't that bad - - name: Publish to Test PyPI - if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.test_pypi_password }} - repository_url: https://test.pypi.org/legacy/ - - - name: Publish distribution 📦 to PyPI - if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.pypi_password }} diff --git a/.github/workflows/rebase.yml b/.github/workflows/rebase.yml deleted file mode 100644 index 06d20652c6b5a4..00000000000000 --- a/.github/workflows/rebase.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: Automatic Rebase -# https://github.com/marketplace/actions/automatic-rebase - -on: - issue_comment: - types: [created] - -jobs: - rebase: - name: Rebase - if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/rebase') - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - name: Automatic Rebase - uses: cirrus-actions/rebase@1.2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml new file mode 100644 index 00000000000000..4838ac915bd47c --- /dev/null +++ b/.github/workflows/release-docker.yml @@ -0,0 +1,62 @@ +name: Publish Docker Releases +# https://www.docker.com/blog/first-docker-github-action-is-here +# https://github.com/docker/build-push-action +on: + push: + branches: [master, "release/*"] + release: + types: [created] + +jobs: + cuda-PL: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python_version: [3.6, 3.7, 3.8] + pytorch_version: [1.4, 1.5, 1.6, 1.7, 1.8] + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Get release version + if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' + id: get_version + run: echo "::set-output name=RELEASE_VERSION::$(echo ${GITHUB_REF##*/})" + + - name: Publish Releases to Docker + # only on releases + uses: docker/build-push-action@v1.1.0 + if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' + with: + repository: pytorchlightning/pytorch_lightning + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + dockerfile: dockers/release/Dockerfile + build_args: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} + tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},latest-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}" + timeout-minutes: 55 + +# nvidia-PL: +# runs-on: ubuntu-20.04 +# steps: +# - name: Checkout +# uses: actions/checkout@v2 +# +# - name: Get release version +# if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' +# id: get_version +# run: echo "::set-output name=RELEASE_VERSION::$(echo ${GITHUB_REF##*/})" +# +# - name: Publish Releases to Docker +# # only on releases +# uses: docker/build-push-action@v1.1.0 +# if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'release' +# with: +# repository: nvcr.io/pytorchlightning/pytorch_lightning +# username: ${{ secrets.DOCKER_USERNAME }} +# password: ${{ secrets.DOCKER_PASSWORD }} +# dockerfile: dockers/nvidia/Dockerfile +# build_args: LIGHTNING_VERSION=${{ steps.get_version.outputs.RELEASE_VERSION }} +# tags: "${{ steps.get_version.outputs.RELEASE_VERSION }}-nvidia" +# timeout-minutes: 55 diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml new file mode 100644 index 00000000000000..3a502d713971da --- /dev/null +++ b/.github/workflows/release-pypi.yml @@ -0,0 +1,141 @@ +name: PyPI Release + +# https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on push or pull request, but only for the master branch + push: + branches: [master, "release/*"] + release: + types: [created] + + +jobs: + # based on https://github.com/pypa/gh-action-pypi-publish + build-package: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Install dependencies + run: >- + python -m pip install --user --upgrade setuptools wheel + + - name: Build packages + run: | + python setup.py sdist bdist_wheel + ls -lh dist/ + + - uses: actions/upload-artifact@v2 + with: + name: pypi-packages-${{ github.sha }} + path: dist + + upload-package: + runs-on: ubuntu-20.04 + needs: build-package + steps: + - uses: actions/checkout@v2 + - uses: actions/download-artifact@v2 + with: + name: pypi-packages-${{ github.sha }} + path: dist + - run: ls -lh dist/ + + - name: Upload to release + if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' + uses: AButler/upload-release-assets@v2.0 + with: + files: 'dist/*' + repo-token: ${{ secrets.GITHUB_TOKEN }} + + publish-package: + runs-on: ubuntu-20.04 + needs: build-package + steps: + - uses: actions/checkout@v2 + - uses: actions/download-artifact@v2 + with: + name: pypi-packages-${{ github.sha }} + path: dist + - run: ls -lh dist/ + + - name: Delay releasing + if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' + uses: juliangruber/sleep-action@v1 + with: + time: 10m + + # We do this, since failures on test.pypi aren't that bad + - name: Publish to Test PyPI + if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' + uses: pypa/gh-action-pypi-publish@v1.4.1 + with: + user: __token__ + password: ${{ secrets.test_pypi_password }} + repository_url: https://test.pypi.org/legacy/ + verbose: true + + - name: Publish distribution 📦 to PyPI + if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' + uses: pypa/gh-action-pypi-publish@v1.4.1 + with: + user: __token__ + password: ${{ secrets.pypi_password }} + + create-legacy-ckpt: + runs-on: ubuntu-20.04 + needs: [build-package, publish-package] + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + # Note: This uses an internal pip API and may not always work + # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Cache pip + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet + pip install awscli + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_KEY_ID }} + aws-region: us-east-1 + + - uses: actions/download-artifact@v2 + with: + name: pypi-packages-${{ github.sha }} + path: dist + + - name: Pull files from S3 + run: | + aws s3 cp --recursive s3://pl-public-data/legacy/checkpoints/ legacy/checkpoints/ # --acl public-read + ls -l legacy/checkpoints/ + + - name: Generate checkpoint + # if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' + run: | + ls -lh dist/ + pip install dist/*.whl + + pl_ver=$(python -c "import pytorch_lightning as pl ; print(pl.__version__)" 2>&1) + # generate checkpoint to this version + bash legacy/generate_checkpoints.sh $pl_ver + + - name: Push files to S3 + run: | + aws s3 sync legacy/checkpoints/ s3://pl-public-data/legacy/checkpoints/ + cd legacy + zip -r checkpoints.zip checkpoints + aws s3 cp checkpoints.zip s3://pl-public-data/legacy/ --acl public-read diff --git a/.gitignore b/.gitignore index cb8fd278c5c4f2..99939ff7fce0cd 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ test_tube_exp/ # Documentations docs/source/api docs/source/*.md +docs/source/generated +docs/source/*/generated # Byte-compiled / optimized / DLL files __pycache__/ @@ -26,12 +28,14 @@ timit_data/ # C extensions *.so +# PyCharm .idea/ # Distribution / packaging .Python ide_layouts/ build/ +_build/ develop-eggs/ dist/ downloads/ @@ -124,13 +128,33 @@ ENV/ # mypy .mypy_cache/ +# pytest +.pytest_cache/ # data .data/ -datasets/ +Datasets/ mnist/ +legacy/checkpoints/ # pl tests ml-runs/ +mlruns/ *.zip -pytorch\ lightning \ No newline at end of file +*.ckpt +pytorch\ lightning +test-reports/ +wandb +.forked/ +*.prof +*.tar.gz + +# dataset generated from bolts in examples. +cifar-10-batches-py +*.pt +# ctags +tags +data +MNIST +runs +*trace* diff --git a/.mergify.yml b/.mergify.yml deleted file mode 100644 index aa75c74dfa59dc..00000000000000 --- a/.mergify.yml +++ /dev/null @@ -1,49 +0,0 @@ -pull_request_rules: - - - name: Automatic merge on approval - conditions: - - base=master - # number of review approvals - - "#approved-reviews-by>=3" - # no waiting or assigned review - - "#review-requested=0" - # no requested chnages from any reviewer - - "#changes-requested-reviews-by=0" - # this serves as ALL check has to pass as we have actually 27 tests in total - - "#status-success>=30" - # this is just in case since we rely on GPU tests (note: redundand to the above) - - status-success=continuous-integration/drone/pr - # this is patter-like, unofrunatly serves as `any(...)` (note: redundand to the above) - - "status-success~=^ci/circleci:" - # no conflict with master branch - - -conflict - # was not closed yet - - -closed - actions: - delete_head_branch: {} - merge: - # https://doc.mergify.io/merge-action.html#strict-merge - # (on head branch) $ git merge --no-ff base - # (on head branch) # Wait for CI to go green - # (on head branch) # Squash all commits - # (on base branch) $ git merge --ff head - strict: true - method: squash - comment: - message: Great job! =) - - - name: warn on conflicts - conditions: - - conflict - actions: - comment: - message: This pull request is now in conflict... :( - - - name: add core reviewer - conditions: - # number of review approvals - - "#approved-reviews-by<3" - actions: - request_reviews: - teams: - - core-contributors diff --git a/.pep8speaks.yml b/.pep8speaks.yml index 2b92bded0dbf0a..e08b6fb7b55e19 100644 --- a/.pep8speaks.yml +++ b/.pep8speaks.yml @@ -4,16 +4,8 @@ scanner: diff_only: True # If False, the entire file touched by the Pull Request is scanned for errors. If True, only the diff is scanned. linter: pycodestyle # Other option is flake8 -pycodestyle: # Same as scanner.linter value. Other option is flake8 - max-line-length: 110 # Default is 79 in PEP 8 - ignore: # Errors and warnings to ignore - - W504 # line break after binary operator - - E402 # module level import not at top of file - - E731 # do not assign a lambda expression, use a def - - C406 # Unnecessary list literal - rewrite as a dict literal. - - E741 # ambiguous variable name - - F401 - - F841 +#pycodestyle: # this is dropped in favor od unified config +# see: https://github.com/OrkoHunter/pep8speaks/issues/95#issuecomment-470887715 no_blank_comment: True # If True, no comment is made on PR without any errors. descending_issues_order: False # If True, PEP 8 issues in message will be displayed in descending order of line numbers in the file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000000..45eca43de93ac8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +default_language_version: + python: python3.8 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + + - repo: https://github.com/PyCQA/isort + rev: 5.7.0 + hooks: + - id: isort + args: [--settings-path, ./pyproject.toml] + + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.30.0 + hooks: + - id: yapf + args: [--parallel, --in-place] diff --git a/.readthedocs.yml b/.readthedocs.yml index 312b50e6067c02..32a5a16248b91c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # .readthedocs.yml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details @@ -6,19 +20,23 @@ version: 2 # Build documentation in the docs/ directory with Sphinx +# reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx sphinx: configuration: docs/source/conf.py + fail_on_warning: true # Build documentation with MkDocs #mkdocs: # configuration: mkdocs.yml # Optionally build your docs in additional formats such as PDF and ePub -formats: all +formats: + - htmlzip + - pdf # Optionally set the version of Python and requirements required to build your docs python: version: 3.7 install: - - requirements: docs/requirements.txt + - requirements: requirements/docs.txt #- requirements: requirements.txt diff --git a/.run_local_tests.sh b/.run_local_tests.sh deleted file mode 100644 index 20fe84ff22fcf5..00000000000000 --- a/.run_local_tests.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash - -# install APEX, see https://github.com/NVIDIA/apex#linux -# to imitate SLURM set only single node -export SLURM_LOCALID=0 - -# use this to run tests -rm -rf _ckpt_* -rm -rf ./tests/save_dir* -rm -rf ./tests/mlruns_* -rm -rf ./tests/cometruns* -rm -rf ./tests/wandb* -rm -rf ./tests/tests/* -rm -rf ./lightning_logs -python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8 -python -m coverage report -m diff --git a/.update.sh b/.update.sh deleted file mode 100644 index 40fcc22d6b79ba..00000000000000 --- a/.update.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -version=$1 - -git commit -am "release v$version" -git tag $version -m "test_tube v$version" -git push --tags origin master - -# push to pypi -rm -rf ./dist/* -python3 setup.py sdist -twine upload dist/* - -# to update docs -# cd to root dir -# mkdocs gh-deploy - diff --git a/.yapfignore b/.yapfignore new file mode 100644 index 00000000000000..48c75600b1fa24 --- /dev/null +++ b/.yapfignore @@ -0,0 +1 @@ +.git/* diff --git a/CHANGELOG.md b/CHANGELOG.md index f3190f9d7353db..375d6ee060d0c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,41 +4,1510 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [unreleased] - YYYY-MM-DD + +## [UnReleased] - 2021-MM-DD ### Added -- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498)) -- Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564)) +- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667)) + + +- Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417)) + + +- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + + +- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + + +- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) + + +- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948)) + + +- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) + + +- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673)) + + +- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) + + +- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) + + +- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) + + +- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633)) + + +- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) + + +- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) + + +- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) + + +- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) + + +- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) + + +- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595)) + + +- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) -- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). ### Changed -- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609)) +- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) + + +- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498)) -- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577)) ### Deprecated +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + + +- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), + [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), + [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), + [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547), + [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515), + [#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572), + [#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573), + [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584), + [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), + [#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637), + [#6649](https://github.com/PyTorchLightning/pytorch-lightning/pull/6649), + [#6659](https://github.com/PyTorchLightning/pytorch-lightning/pull/6659), +) + + ### Removed +- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) + + +- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) + + +- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166)) + + +- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163)) + + +- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161)) + * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve` + * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` + + +- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162)) + + +- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167)) + + +- Removed legacy references for magic keys in the `Result` object ([#6016](https://github.com/PyTorchLightning/pytorch-lightning/pull/6016)) + + +- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207)) + + +- Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682)) + + +- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) + + ### Fixed -- Fixed ModelCheckpoint not None checking filepath ([1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654)) -- Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666)) +- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) -- Fixed Horovod distributed backend to set the `root_gpu` property ([#1669](https://github.com/PyTorchLightning/pytorch-lightning/pull/1669)) -- Fixed wandb logger `global_step` affects other loggers ([#1492](https://github.com/PyTorchLightning/pytorch-lightning/issues/1485)) +- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070)) -- Fixed disabling progress bar on non-zero ranks using Horovod backend ([#1709](https://github.com/PyTorchLightning/pytorch-lightning/pull/1709)) -- Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676)) +- Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109)) + + +- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) + + +- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) + + +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509)) + + +- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) + + +- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) + + +- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) + + +- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) + + +- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) + + + +## [1.2.5] - 2021-03-23 + +### Changed + +- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) +- Refactored setup for typing friendly ([#6590](https://github.com/PyTorchLightning/pytorch-lightning/pull/6590)) + +### Fixed + +- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) +- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) +- Fixed duplicate logs appearing in console when using the python logging module ([#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) +- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) + + +## [1.2.4] - 2021-03-16 + +### Changed + +- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438)) + +### Fixed + +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) +- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) +- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) +- Fixed broadcast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410)) +- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) +- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460)) +- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) +- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968)) +- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511)) +- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) + + +## [1.2.3] - 2021-03-09 + +### Fixed + +- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073)) +- Fixed when `_stable_1d_sort` to work when `n >= N` ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177)) +- Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221)) +- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260)) +- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272)) +- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296)) +- Ensure we check deepspeed/sharded in multinode DDP ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297) +- Check `LightningOptimizer` doesn't delete optimizer hooks ([#6305](https://github.com/PyTorchLightning/pytorch-lightning/pull/6305) +- Resolve memory leak for evaluation ([#6326](https://github.com/PyTorchLightning/pytorch-lightning/pull/6326) +- Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330) +- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372)) + + +## [1.2.2] - 2021-03-02 + +### Added + +- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) + +### Changed + +- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147)) +- Changed default for DeepSpeed CPU Offload to False, due to prohibitively slow speeds at smaller scale ([#6262](https://github.com/PyTorchLightning/pytorch-lightning/pull/6262)) + +### Fixed + +- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075)) +- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197)) +- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216)) +- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147)) +- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931)) +- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297) + + +## [1.2.1] - 2021-02-23 + +### Fixed + +- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080)) +- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089)) +- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) +- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) + + +## [1.2.0] - 2021-02-18 + +### Added + +- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)) +- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590)) +- Added support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959)) +- Added `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) +- Added `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) +- Added `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) +- Added `max_fpr` parameter to `auroc` metric for computing partial auroc metric ([#3790](https://github.com/PyTorchLightning/pytorch-lightning/pull/3790)) +- Added `StatScores` metric to compute the number of true positives, false positives, true negatives and false negatives ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) +- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241)) +- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347)) +- Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377)) +- Accelerator `all_gather` supports collection ([#5221](https://github.com/PyTorchLightning/pytorch-lightning/pull/5221)) +- Added `image_gradients` functional metric to compute the image gradients of a given input image. ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056)) +- Added `MetricCollection` ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318)) +- Added `.clone()` method to metrics ([#4318](https://github.com/PyTorchLightning/pytorch-lightning/pull/4318)) +- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- Support to tie weights after moving model to TPU via `on_post_move_to_device` hook +- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467)) +- The `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842)) +- Added `ModelPruning` Callback ([#5618](https://github.com/PyTorchLightning/pytorch-lightning/pull/5618), + [#5825](https://github.com/PyTorchLightning/pytorch-lightning/pull/5825), + [#6045](https://github.com/PyTorchLightning/pytorch-lightning/pull/6045)) +- Added `PyTorchProfiler` ([#5560](https://github.com/PyTorchLightning/pytorch-lightning/pull/5560)) +- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464)) +- Added Trainer method `predict(...)` for high performence predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579)) +- Added `on_before_batch_transfer` and `on_after_batch_transfer` data hooks ([#3671](https://github.com/PyTorchLightning/pytorch-lightning/pull/3671)) +- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479)) +- Added `PredictLoop` object ([#5752](https://github.com/PyTorchLightning/pytorch-lightning/pull/5752)) +- Added `QuantizationAwareTraining` callback ([#5706](https://github.com/PyTorchLightning/pytorch-lightning/pull/5706), + [#6040](https://github.com/PyTorchLightning/pytorch-lightning/pull/6040)) +- Added `LightningModule.configure_callbacks` to enable the definition of model-specific callbacks ([#5621](https://github.com/PyTorchLightning/pytorch-lightning/pull/5621)) +- Added `dim` to `PSNR` metric for mean-squared-error reduction ([#5957](https://github.com/PyTorchLightning/pytorch-lightning/pull/5957)) +- Added promxial policy optimization template to pl_examples ([#5394](https://github.com/PyTorchLightning/pytorch-lightning/pull/5394)) +- Added `log_graph` to `CometLogger` ([#5295](https://github.com/PyTorchLightning/pytorch-lightning/pull/5295)) +- Added possibility for nested loaders ([#5404](https://github.com/PyTorchLightning/pytorch-lightning/pull/5404)) +- Added `sync_step` to Wandb logger ([#5351](https://github.com/PyTorchLightning/pytorch-lightning/pull/5351)) +- Added `StochasticWeightAveraging` callback ([#5640](https://github.com/PyTorchLightning/pytorch-lightning/pull/5640)) +- Added `LightningDataModule.from_datasets(...)` ([#5133](https://github.com/PyTorchLightning/pytorch-lightning/pull/5133)) +- Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981)) +- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038)) +- Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954), + [#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042)) + +### Changed + +- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) +- Changed `computer_vision_fine_tunning` example to use `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377)) +- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218)) +- Changed `iou` [func] to allow float input ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- Metric `compute()` method will no longer automatically call `reset()` ([#5409](https://github.com/PyTorchLightning/pytorch-lightning/pull/5409/)) +- Set PyTorch 1.4 as min requirements, also for testing and examples `torchvision>=0.5` and `torchtext>=0.5` ([#5418](https://github.com/PyTorchLightning/pytorch-lightning/pull/5418)) +- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446)) +- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185)) +- Changed `ModelCheckpoint` version suffixes to start at 1 ([#5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008)) +- Progress bar metrics tensors are now converted to float ([#5692](https://github.com/PyTorchLightning/pytorch-lightning/pull/5692)) +- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516)) +- Extended support for purely iteration-based training ([#5726](https://github.com/PyTorchLightning/pytorch-lightning/pull/5726)) +- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730)) +- Forced `ModelCheckpoint` callbacks to run after all others to guarantee all states are saved to the checkpoint ([#5731](https://github.com/PyTorchLightning/pytorch-lightning/pull/5731)) +- Refactored Accelerators and Plugins ([#5743](https://github.com/PyTorchLightning/pytorch-lightning/pull/5743)) + * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715)) + * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714)) + * Precision Plugins ([#5718](https://github.com/PyTorchLightning/pytorch-lightning/pull/5718)) + * Added new Accelerators for CPU, GPU and TPU ([#5719](https://github.com/PyTorchLightning/pytorch-lightning/pull/5719)) + * Added Plugins for TPU training ([#5719](https://github.com/PyTorchLightning/pytorch-lightning/pull/5719)) + * Added RPC and Sharded plugins ([#5732](https://github.com/PyTorchLightning/pytorch-lightning/pull/5732)) + * Added missing `LightningModule`-wrapper logic to new plugins and accelerator ([#5734](https://github.com/PyTorchLightning/pytorch-lightning/pull/5734)) + * Moved device-specific teardown logic from training loop to accelerator ([#5973](https://github.com/PyTorchLightning/pytorch-lightning/pull/5973)) + * Moved accelerator_connector.py to the connectors subfolder ([#6033](https://github.com/PyTorchLightning/pytorch-lightning/pull/6033)) + * Trainer only references accelerator ([#6039](https://github.com/PyTorchLightning/pytorch-lightning/pull/6039)) + * Made parallel devices optional across all plugins ([#6051](https://github.com/PyTorchLightning/pytorch-lightning/pull/6051)) + * Cleaning ([#5948](https://github.com/PyTorchLightning/pytorch-lightning/pull/5948), + [#5949](https://github.com/PyTorchLightning/pytorch-lightning/pull/5949), + [#5950](https://github.com/PyTorchLightning/pytorch-lightning/pull/5950)) +- Enabled `self.log` in callbacks ([#5094](https://github.com/PyTorchLightning/pytorch-lightning/pull/5094)) +- Renamed xxx_AVAILABLE as protected ([#5082](https://github.com/PyTorchLightning/pytorch-lightning/pull/5082)) +- Unified module names in Utils ([#5199](https://github.com/PyTorchLightning/pytorch-lightning/pull/5199)) +- Separated utils: imports & enums ([#5256](https://github.com/PyTorchLightning/pytorch-lightning/pull/5256) + [#5874](https://github.com/PyTorchLightning/pytorch-lightning/pull/5874)) +- Refactor: clean trainer device & distributed getters ([#5300](https://github.com/PyTorchLightning/pytorch-lightning/pull/5300)) +- Simplified training phase as LightningEnum ([#5419](https://github.com/PyTorchLightning/pytorch-lightning/pull/5419)) +- Updated metrics to use LightningEnum ([#5689](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)) +- Changed the seq of `on_train_batch_end`, `on_batch_end` & `on_train_epoch_end`, `on_epoch_end hooks` ([#5688](https://github.com/PyTorchLightning/pytorch-lightning/pull/5688)) +- Refactored `setup_training` and remove `test_mode` ([#5388](https://github.com/PyTorchLightning/pytorch-lightning/pull/5388)) +- Disabled training with zero `num_training_batches` when insufficient `limit_train_batches` ([#5703](https://github.com/PyTorchLightning/pytorch-lightning/pull/5703)) +- Refactored `EpochResultStore` ([#5522](https://github.com/PyTorchLightning/pytorch-lightning/pull/5522)) +- Update `lr_finder` to check for attribute if not running `fast_dev_run` ([#5990](https://github.com/PyTorchLightning/pytorch-lightning/pull/5990)) +- LightningOptimizer manual optimizer is more flexible and expose `toggle_model` ([#5771](https://github.com/PyTorchLightning/pytorch-lightning/pull/5771)) +- `MlflowLogger` limit parameter value length to 250 char ([#5893](https://github.com/PyTorchLightning/pytorch-lightning/pull/5893)) +- Re-introduced fix for Hydra directory sync with multiple process ([#5993](https://github.com/PyTorchLightning/pytorch-lightning/pull/5993)) + +### Deprecated + +- Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) +- Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645)) +- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185)) +- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670)) +- Renamed utils modules ([#5199](https://github.com/PyTorchLightning/pytorch-lightning/pull/5199)) + * `argparse_utils` >> `argparse` + * `model_utils` >> `model_helpers` + * `warning_utils` >> `warnings` + * `xla_device_utils` >> `xla_device` +- Deprecated using `'val_loss'` to set the `ModelCheckpoint` monitor ([#6012](https://github.com/PyTorchLightning/pytorch-lightning/pull/6012)) +- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035)) +- Deprecated Trainer attribute `accelerator_backend` in favor of `accelerator` ([#6034](https://github.com/PyTorchLightning/pytorch-lightning/pull/6034)) + +### Removed + +- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321)) +- Removed deprecated `Fbeta`, `f1_score` and `fbeta_score` metrics ([#5322](https://github.com/PyTorchLightning/pytorch-lightning/pull/5322)) +- Removed deprecated `TrainResult` ([#5323](https://github.com/PyTorchLightning/pytorch-lightning/pull/5323)) +- Removed deprecated `EvalResult` ([#5633](https://github.com/PyTorchLightning/pytorch-lightning/pull/5633)) +- Removed `LoggerStages` ([#5673](https://github.com/PyTorchLightning/pytorch-lightning/pull/5673)) + +### Fixed + +- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297)) +- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861)) +- Fixed `DDPHPCAccelerator` hangs in DDP construction by calling `init_device` ([#5157](https://github.com/PyTorchLightning/pytorch-lightning/pull/5157)) +- Fixed `num_workers` for Windows example ([#5375](https://github.com/PyTorchLightning/pytorch-lightning/pull/5375)) +- Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619)) +- Fixed support custom DataLoader with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745)) +- Fixed repeated `.fit()` calls ignore max_steps iteration bound ([#5936](https://github.com/PyTorchLightning/pytorch-lightning/pull/5936)) +- Fixed throwing `MisconfigurationError` on unknown mode ([#5255](https://github.com/PyTorchLightning/pytorch-lightning/pull/5255)) +- Resolve bug with Finetuning ([#5744](https://github.com/PyTorchLightning/pytorch-lightning/pull/5744)) +- Fixed `ModelCheckpoint` race condition in file existence check ([#5155](https://github.com/PyTorchLightning/pytorch-lightning/pull/5155)) +- Fixed some compatibility with PyTorch 1.8 ([#5864](https://github.com/PyTorchLightning/pytorch-lightning/pull/5864)) +- Fixed forward cache ([#5895](https://github.com/PyTorchLightning/pytorch-lightning/pull/5895)) +- Fixed recursive detach of tensors to CPU ([#6007](https://github.com/PyTorchLightning/pytorch-lightning/pull/6007)) +- Fixed passing wrong strings for scheduler interval doesn't throw an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923)) +- Fixed wrong `requires_grad` state after `return None` with multiple optimizers ([#5738](https://github.com/PyTorchLightning/pytorch-lightning/pull/5638)) +- Fixed add `on_epoch_end` hook at the end of `validation`, `test` epoch ([#5986](https://github.com/PyTorchLightning/pytorch-lightning/pull/5986)) +- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) +- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009)) +- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) +- Fixed `hparams.yaml` saved twice when using `TensorBoardLogger` ([#5953](https://github.com/PyTorchLightning/pytorch-lightning/pull/5953)) +- Fixed basic examples ([#5912](https://github.com/PyTorchLightning/pytorch-lightning/pull/5912), + [#5985](https://github.com/PyTorchLightning/pytorch-lightning/pull/5985)) +- Fixed `fairscale` compatible with PT 1.8 ([#5996](https://github.com/PyTorchLightning/pytorch-lightning/pull/5996)) +- Ensured `process_dataloader` is called when `tpu_cores > 1` to use Parallel DataLoader ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) +- Attempted SLURM auto resume call when non-shell call fails ([#6002](https://github.com/PyTorchLightning/pytorch-lightning/pull/6002)) +- Fixed wrapping optimizers upon assignment ([#6006](https://github.com/PyTorchLightning/pytorch-lightning/pull/6006)) +- Fixed allowing hashing of metrics with lists in their state ([#5939](https://github.com/PyTorchLightning/pytorch-lightning/pull/5939)) + + +## [1.1.8] - 2021-02-08 + +### Fixed + +- Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208)) +- Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775)) + + +## [1.1.7] - 2021-02-03 + +### Fixed + +- Fixed `TensorBoardLogger` not closing `SummaryWriter` on `finalize` ([#5696](https://github.com/PyTorchLightning/pytorch-lightning/pull/5696)) +- Fixed filtering of pytorch "unsqueeze" warning when using DP ([#5622](https://github.com/PyTorchLightning/pytorch-lightning/pull/5622)) +- Fixed `num_classes` argument in F1 metric ([#5663](https://github.com/PyTorchLightning/pytorch-lightning/pull/5663)) +- Fixed `log_dir` property ([#5537](https://github.com/PyTorchLightning/pytorch-lightning/pull/5537)) +- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144)) +- Remove unnecessary intermediate layers in Dockerfiles ([#5697](https://github.com/PyTorchLightning/pytorch-lightning/pull/5697)) +- Fixed auto learning rate ordering ([#5638](https://github.com/PyTorchLightning/pytorch-lightning/pull/5638)) + + +## [1.1.6] - 2021-01-26 + +### Changed + +- Increased TPU check timeout from 20s to 100s ([#5598](https://github.com/PyTorchLightning/pytorch-lightning/pull/5598)) +- Ignored `step` param in Neptune logger's log_metric method ([#5510](https://github.com/PyTorchLightning/pytorch-lightning/pull/5510)) +- Pass batch outputs to `on_train_batch_end` instead of `epoch_end` outputs ([#4369](https://github.com/PyTorchLightning/pytorch-lightning/pull/4369)) + +### Fixed + +- Fixed `toggle_optimizer` to reset `requires_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574)) +- Fixed FileNotFoundError for best checkpoint when using DDP with Hydra ([#5629](https://github.com/PyTorchLightning/pytorch-lightning/pull/5629)) +- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620)) +- Fixed `Metric`'s `state_dict` not included when child modules ([#5614](https://github.com/PyTorchLightning/pytorch-lightning/pull/5614)) +- Fixed Neptune logger creating multiple experiments when GPUs > 1 ([#3256](https://github.com/PyTorchLightning/pytorch-lightning/pull/3256)) +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509)) +- Fixed tensor printing in `trainer.test()` ([#5138](https://github.com/PyTorchLightning/pytorch-lightning/pull/5138)) +- Fixed not using dataloader when `hparams` present ([#4559](https://github.com/PyTorchLightning/pytorch-lightning/pull/4559)) + + +## [1.1.5] - 2021-01-19 + +### Fixed + +- Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) +- Fixed logging `on_train_batch_end` in a callback with multiple optimizers ([#5521](https://github.com/PyTorchLightning/pytorch-lightning/pull/5521)) +- Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519)) +- Fixed `val_check_interval` with `fast_dev_run` ([#5540](https://github.com/PyTorchLightning/pytorch-lightning/pull/5540)) + + +## [1.1.4] - 2021-01-12 + +### Added + +- Add automatic optimization property setter to lightning module ([#5169](https://github.com/PyTorchLightning/pytorch-lightning/pull/5169)) + +### Changed + +- Changed deprecated `enable_pl_optimizer=True` ([#5244](https://github.com/PyTorchLightning/pytorch-lightning/pull/5244)) + +### Fixed + +- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195)) +- Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417)) +- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406)) +- Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743)) +- Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505)) + +## [1.1.3] - 2021-01-05 + +### Added + +- Added a check for optimizer attached to `lr_scheduler` ([#5338](https://github.com/PyTorchLightning/pytorch-lightning/pull/5338)) +- Added support for passing non-existing filepaths to `resume_from_checkpoint` ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402)) + +### Changed + +- Skip restore from `resume_from_checkpoint` while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161)) +- Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333)) +- Disabled checkpointing, earlystopping and logging with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277)) +- Distributed group defaults to `WORLD` if `None` ([#5125](https://github.com/PyTorchLightning/pytorch-lightning/pull/5125)) + +### Fixed + +- Fixed `trainer.test` returning non-test metrics ([#5214](https://github.com/PyTorchLightning/pytorch-lightning/pull/5214)) +- Fixed metric state reset ([#5273](https://github.com/PyTorchLightning/pytorch-lightning/pull/5273)) +- Fixed `--num-nodes` on `DDPSequentialPlugin` ([#5327](https://github.com/PyTorchLightning/pytorch-lightning/pull/5327)) +- Fixed invalid value for `weights_summary` ([#5296](https://github.com/PyTorchLightning/pytorch-lightning/pull/5296)) +- Fixed `Trainer.test` not using the latest `best_model_path` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161)) +- Fixed existence check for hparams not using underlying filesystem ([#5250](https://github.com/PyTorchLightning/pytorch-lightning/pull/5250)) +- Fixed `LightningOptimizer` AMP bug ([#5191](https://github.com/PyTorchLightning/pytorch-lightning/pull/5191)) +- Fixed casted key to string in `_flatten_dict` ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354)) + + +## [1.1.2] - 2020-12-23 + +### Added + +- Support number for logging with `sync_dist=True` ([#5080](https://github.com/PyTorchLightning/pytorch-lightning/pull/5080)) +- Added offset logging step when resuming for Wandb logger ([#5050](https://github.com/PyTorchLightning/pytorch-lightning/pull/5050)) + +### Removed + +- `enable_pl_optimizer=False` by default to temporarily fix AMP issues ([#5163](https://github.com/PyTorchLightning/pytorch-lightning/pull/5163)) + +### Fixed + +- Metric reduction with Logging ([#5150](https://github.com/PyTorchLightning/pytorch-lightning/pull/5150)) +- Remove nan loss in manual optimization ([#5121](https://github.com/PyTorchLightning/pytorch-lightning/pull/5121)) +- Un-balanced logging properly supported ([#5119](https://github.com/PyTorchLightning/pytorch-lightning/pull/5119)) +- Fix hanging in DDP HPC accelerators ([#5157](https://github.com/PyTorchLightning/pytorch-lightning/pull/5157)) +- Fix saved filename in `ModelCheckpoint` if it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861)) +- Fix reset `TensorRunningAccum` ([#5106](https://github.com/PyTorchLightning/pytorch-lightning/pull/5106)) +- Updated `DALIClassificationLoader` to not use deprecated arguments ([#4925](https://github.com/PyTorchLightning/pytorch-lightning/pull/4925)) +- Corrected call to `torch.no_grad` ([#5124](https://github.com/PyTorchLightning/pytorch-lightning/pull/5124)) + + +## [1.1.1] - 2020-12-15 + +### Added + +- Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning ([#4818](https://github.com/PyTorchLightning/pytorch-lightning/pull/4818)) + +### Changed + +- Simplify accelerator steps ([#5015](https://github.com/PyTorchLightning/pytorch-lightning/pull/5015)) +- Refactor load in checkpoint connector ([#4593](https://github.com/PyTorchLightning/pytorch-lightning/pull/4593)) +- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861)) + +### Removed +- Drop duplicate metrics ([#5014](https://github.com/PyTorchLightning/pytorch-lightning/pull/5014)) +- Remove beta arg from F1 class and functional ([#5076](https://github.com/PyTorchLightning/pytorch-lightning/pull/5076)) + +### Fixed + +- Fixed trainer by default `None` in `DDPAccelerator` ([#4915](https://github.com/PyTorchLightning/pytorch-lightning/pull/4915)) +- Fixed `LightningOptimizer` to expose optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095)) +- Do not warn when the `name` key is used in the `lr_scheduler` dict ([#5057](https://github.com/PyTorchLightning/pytorch-lightning/pull/5057)) +- Check if optimizer supports closure ([#4981](https://github.com/PyTorchLightning/pytorch-lightning/pull/4981)) +- Extend LightningOptimizer to exposure underlying Optimizer attributes + update doc ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095)) +- Add deprecated metric utility functions back to functional ( + [#5067](https://github.com/PyTorchLightning/pytorch-lightning/pull/5067), + [#5068](https://github.com/PyTorchLightning/pytorch-lightning/pull/5068)) +- Allow any input in `to_onnx` and `to_torchscript` ([#4378](https://github.com/PyTorchLightning/pytorch-lightning/pull/4378)) +- Do not warn when the name key is used in the `lr_scheduler` dict ([#5057](https://github.com/PyTorchLightning/pytorch-lightning/pull/5057)) +- Fixed `DDPHPCAccelerator` hangs in DDP construction by calling `init_device` ([#5157](https://github.com/PyTorchLightning/pytorch-lightning/pull/5157)) + + +## [1.1.0] - 2020-12-09 + +### Added + +- Added "monitor" key to saved `ModelCheckpoints` ([#4383](https://github.com/PyTorchLightning/pytorch-lightning/pull/4383)) +- Added `ConfusionMatrix` class interface ([#4348](https://github.com/PyTorchLightning/pytorch-lightning/pull/4348)) +- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236)) +- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) +- Added optimizer hooks in callbacks ([#4379](https://github.com/PyTorchLightning/pytorch-lightning/pull/4379)) +- Added option to log momentum ([#4384](https://github.com/PyTorchLightning/pytorch-lightning/pull/4384)) +- Added `current_score` to `ModelCheckpoint.on_save_checkpoint` ([#4721](https://github.com/PyTorchLightning/pytorch-lightning/pull/4721)) +- Added logging using `self.log` in train and evaluation for epoch end hooks ( + [#4552](https://github.com/PyTorchLightning/pytorch-lightning/pull/4552), + [#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495), + [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439), + [#4684](https://github.com/PyTorchLightning/pytorch-lightning/pull/4684), + [#4913](https://github.com/PyTorchLightning/pytorch-lightning/pull/4913)) +- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675)) +- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647)) +- Added `prefix` argument in loggers ([#4557](https://github.com/PyTorchLightning/pytorch-lightning/pull/4557)) +- Added printing of total num of params, trainable and non-trainable params in ModelSummary ([#4521](https://github.com/PyTorchLightning/pytorch-lightning/pull/4521)) +- Added `PrecisionRecallCurve, ROC, AveragePrecision` class metric ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) +- Added custom `Apex` and `NativeAMP` as `Precision plugins` ([#4355](https://github.com/PyTorchLightning/pytorch-lightning/pull/4355)) +- Added `DALI MNIST` example ([#3721](https://github.com/PyTorchLightning/pytorch-lightning/pull/3721)) +- Added `sharded plugin` for DDP for multi-gpu training memory optimizations ( + [#4639](https://github.com/PyTorchLightning/pytorch-lightning/pull/4639), + [#4686](https://github.com/PyTorchLightning/pytorch-lightning/pull/4686), + [#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675), + [#4737](https://github.com/PyTorchLightning/pytorch-lightning/pull/4737), + [#4773](https://github.com/PyTorchLightning/pytorch-lightning/pull/4773)) +- Added `experiment_id` to the NeptuneLogger ([#3462](https://github.com/PyTorchLightning/pytorch-lightning/pull/3462)) +- Added `Pytorch Geometric` integration example with Lightning ([#4568](https://github.com/PyTorchLightning/pytorch-lightning/pull/4568)) +- Added `all_gather` method to `LightningModule` which allows gradient based tensor synchronizations for use-cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012)) +- Enabled `self.log` in most functions ([#4969](https://github.com/PyTorchLightning/pytorch-lightning/pull/4969)) +- Added changeable extension variable for `ModelCheckpoint` ([#4977](https://github.com/PyTorchLightning/pytorch-lightning/pull/4977)) + + +### Changed + +- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) +- `WandbLogger` does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) +- Changed `automatic_optimization` to be a model attribute ([#4602](https://github.com/PyTorchLightning/pytorch-lightning/pull/4602)) +- Changed `Simple Profiler` report to order by percentage time spent + num calls ([#4880](https://github.com/PyTorchLightning/pytorch-lightning/pull/4880)) +- Simplify optimization Logic ([#4984](https://github.com/PyTorchLightning/pytorch-lightning/pull/4984)) +- Classification metrics overhaul ([#4837](https://github.com/PyTorchLightning/pytorch-lightning/pull/4837)) +- Updated `fast_dev_run` to accept integer representing num_batches ([#4629](https://github.com/PyTorchLightning/pytorch-lightning/pull/4629)) +- Refactored optimizer ([#4658](https://github.com/PyTorchLightning/pytorch-lightning/pull/4658)) + + +### Deprecated + +- Deprecated `prefix` argument in `ModelCheckpoint` ([#4765](https://github.com/PyTorchLightning/pytorch-lightning/pull/4765)) +- Deprecated the old way of assigning hyper-parameters through `self.hparams = ...` ([#4813](https://github.com/PyTorchLightning/pytorch-lightning/pull/4813)) +- Deprecated `mode='auto'` from `ModelCheckpoint` and `EarlyStopping` ([#4695](https://github.com/PyTorchLightning/pytorch-lightning/pull/4695)) + +### Removed + +- Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004)) +- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) + +### Fixed + +- Added feature to move tensors to CPU before saving ([#4309](https://github.com/PyTorchLightning/pytorch-lightning/pull/4309)) +- Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138)) +- Auto convert tensors to contiguous format when `gather_all` ([#4907](https://github.com/PyTorchLightning/pytorch-lightning/pull/4907)) +- Fixed `PYTHONPATH` for ddp test model ([#4528](https://github.com/PyTorchLightning/pytorch-lightning/pull/4528)) +- Fixed allowing logger to support indexing ([#4595](https://github.com/PyTorchLightning/pytorch-lightning/pull/4595)) +- Fixed DDP and manual_optimization ([#4976](https://github.com/PyTorchLightning/pytorch-lightning/pull/4976)) + + +## [1.0.8] - 2020-11-24 + +### Added + +- Added casting to python types for numpy scalars when logging `hparams` ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647)) +- Added warning when progress bar refresh rate is less than 20 on Google Colab to prevent crashing ([#4654](https://github.com/PyTorchLightning/pytorch-lightning/pull/4654)) +- Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656)) + +### Changed + +- Consistently use `step=trainer.global_step` in `LearningRateMonitor` independently of `logging_interval` ([#4376](https://github.com/PyTorchLightning/pytorch-lightning/pull/4376)) +- Metric states are no longer as default added to `state_dict` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685)) +- Renamed class metric `Fbeta` >> `FBeta` ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656)) +- Model summary: add 1 decimal place ([#4745](https://github.com/PyTorchLightning/pytorch-lightning/pull/4745)) +- Do not override `PYTHONWARNINGS` ([#4700](https://github.com/PyTorchLightning/pytorch-lightning/pull/4700)) +- Changed `init_ddp_connection` moved from `DDP` to `DDPPlugin` ([#4407](https://github.com/PyTorchLightning/pytorch-lightning/pull/4407)) + + +### Fixed + +- Fixed checkpoint `hparams` dict casting when `omegaconf` is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770)) +- Fixed incomplete progress bars when total batches not divisible by refresh rate ([#4577](https://github.com/PyTorchLightning/pytorch-lightning/pull/4577)) +- Updated SSIM metric ([#4566](https://github.com/PyTorchLightning/pytorch-lightning/pull/4566)) +- Fixed batch_arg_name - add `batch_arg_name` to all calls to `_adjust_batch_size`bug ([#4812](https://github.com/PyTorchLightning/pytorch-lightning/pull/4812)) +- Fixed `torchtext` data to GPU ([#4785](https://github.com/PyTorchLightning/pytorch-lightning/pull/4785)) +- Fixed a crash bug in MLFlow logger ([#4716](https://github.com/PyTorchLightning/pytorch-lightning/pull/4716)) + +## [1.0.7] - 2020-11-17 + +### Added + +- Added lambda closure to `manual_optimizer_step` ([#4618](https://github.com/PyTorchLightning/pytorch-lightning/pull/4618)) + +### Changed + +- Change Metrics `persistent` default mode to `False` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685)) +- LoggerConnector log_metrics will use `total_batch_idx` instead of `global_step` when logging on `training step` ([#4738](https://github.com/PyTorchLightning/pytorch-lightning/pull/4738)) + + +### Fixed + +- Prevent crash if `sync_dist=True` on CPU ([#4626](https://github.com/PyTorchLightning/pytorch-lightning/pull/4626)) +- Fixed average pbar Metrics ([#4534](https://github.com/PyTorchLightning/pytorch-lightning/pull/4534)) +- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608)) +- Allowing decorate model init with saving `hparams` inside ([#4662](https://github.com/PyTorchLightning/pytorch-lightning/pull/4662)) +- Fixed `split_idx` set by `LoggerConnector` in `on_trainer_init` to `Trainer` ([#4697](https://github.com/PyTorchLightning/pytorch-lightning/pull/4697)) + + +## [1.0.6] - 2020-11-11 + +### Added + +- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) +- Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485)) +- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482)) +- Added congratulations at the end of our notebooks ([#4555](https://github.com/PyTorchLightning/pytorch-lightning/pull/4555)) +- Added parameters `move_metrics_to_cpu` in Trainer to disable gpu leak ([#4592](https://github.com/PyTorchLightning/pytorch-lightning/pull/4592)) + + +### Changed + +- Changed `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458)) +- Unify SLURM/TorchElastic under backend plugin ([#4578](https://github.com/PyTorchLightning/pytorch-lightning/pull/4578), + [#4580](https://github.com/PyTorchLightning/pytorch-lightning/pull/4580), + [#4581](https://github.com/PyTorchLightning/pytorch-lightning/pull/4581), + [#4582](https://github.com/PyTorchLightning/pytorch-lightning/pull/4582), + [#4583](https://github.com/PyTorchLightning/pytorch-lightning/pull/4583)) + +### Fixed + +- Fixed feature-lack in `hpc_load` ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526)) +- Fixed metrics states being overridden in DDP mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482)) +- Fixed `lightning_getattr`, `lightning_hasattr` not finding the correct attributes in datamodule ([#4347](https://github.com/PyTorchLightning/pytorch-lightning/pull/4347)) +- Fixed automatic optimization AMP by `manual_optimization_step` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485)) +- Replace `MisconfigurationException` with warning in `ModelCheckpoint` Callback ([#4560](https://github.com/PyTorchLightning/pytorch-lightning/pull/4560)) +- Fixed logged keys in mlflow logger ([#4412](https://github.com/PyTorchLightning/pytorch-lightning/pull/4412)) +- Fixed `is_picklable` by catching `AttributeError` ([#4508](https://github.com/PyTorchLightning/pytorch-lightning/pull/4508)) +- Fixed multi test dataloaders dict `AttributeError` error ([#4480](https://github.com/PyTorchLightning/pytorch-lightning/pull/4480)) +- Fixed show progress bar only for `progress_rank 0` on `DDP_SLURM` ([#4437](https://github.com/PyTorchLightning/pytorch-lightning/pull/4437)) + +## [1.0.5] - 2020-11-03 + +### Added + +- Added PyTorch 1.7 Stable support ([#3821](https://github.com/PyTorchLightning/pytorch-lightning/pull/3821)) +- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340)) + +### Changed + +- W&B log in sync with `Trainer` step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405)) +- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) +- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) +- Changed type checker with explicit cast of `ref_model` object ([#4457](https://github.com/PyTorchLightning/pytorch-lightning/pull/4457)) +- Changed `distributed_backend` -> `accelerator` ([#4429](https://github.com/PyTorchLightning/pytorch-lightning/pull/4429)) + +### Deprecated + +- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336)) + +### Fixed + +- Disable saving checkpoints if not trained ([#4372](https://github.com/PyTorchLightning/pytorch-lightning/pull/4372)) +- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209)) +- Disabled training when `limit_train_batches=0` ([#4371](https://github.com/PyTorchLightning/pytorch-lightning/pull/4371)) +- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313)) +- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) +- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428)) +- Fixed TorchScript trace method's data to device and docstring ([#4360](https://github.com/PyTorchLightning/pytorch-lightning/pull/4360)) +- Fixed CSV logger warning ([#4419](https://github.com/PyTorchLightning/pytorch-lightning/pull/4419)) +- Fixed skip DDP parameter sync ([#4301](https://github.com/PyTorchLightning/pytorch-lightning/pull/4301)) +- Fixed `WandbLogger` _sanitize_callable function ([#4422](https://github.com/PyTorchLightning/pytorch-lightning/pull/4422)) +- Fixed `AMP Native` `_unscale` gradient ([#4441](https://github.com/PyTorchLightning/pytorch-lightning/pull/4441)) + + +## [1.0.4] - 2020-10-27 + +### Added + +- Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213)) +- Added plugins docs and DDPPlugin to customize ddp across all accelerators ([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285)) +- Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586)) +- Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162)) +- Added autogenerated helptext to `Trainer.add_argparse_args` ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344)) +- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) +- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) +- Added `optimizer_closure` to `optimizer.step` when supported ([#4190](https://github.com/PyTorchLightning/pytorch-lightning/pull/4190)) +- Added unification of regression metrics ([#4166](https://github.com/PyTorchLightning/pytorch-lightning/pull/4166)) +- Added checkpoint load from Bytes ([#4314](https://github.com/PyTorchLightning/pytorch-lightning/pull/4314)) + +### Changed + +- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587)) +- Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130)) +- Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273)) +- Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320)) + +### Deprecated + +- Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213)) +- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237)) +- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) + +### Fixed + +- Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297)) +- Fixed synchronization of best model path in `ddp_accelerator` ([#4323](https://github.com/PyTorchLightning/pytorch-lightning/pull/4323)) +- Fixed `WandbLogger` not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341)) +- Fixed `FBeta` computation ([#4183](https://github.com/PyTorchLightning/pytorch-lightning/pull/4183)) +- Fixed `accumulation across batches` has completed `before breaking training loop` ([#4278](https://github.com/PyTorchLightning/pytorch-lightning/pull/4278)) +- Fixed `ModelCheckpoint` don't increase current_epoch and global_step when not training ([#4291](https://github.com/PyTorchLightning/pytorch-lightning/pull/4291)) +- Fixed `COMET_EXPERIMENT_KEY` environment variable usage in comet logger ([#4230](https://github.com/PyTorchLightning/pytorch-lightning/pull/4230)) + +## [1.0.3] - 2020-10-20 + +### Added + +- Added persistent flag to `Metric.add_state` ([#4195](https://github.com/PyTorchLightning/pytorch-lightning/pull/4195)) + +### Changed + +- Used `checkpoint_connector.hpc_save` in SLURM ([#4217](https://github.com/PyTorchLightning/pytorch-lightning/pull/4217)) +- Moved base req. to root ([#4219](https://github.com/PyTorchLightning/pytorch-lightning/pull/4219)) + +### Fixed + +- Fixed `hparams` assign in init ([#4189](https://github.com/PyTorchLightning/pytorch-lightning/pull/4189)) +- Fixed overwrite check for model hooks ([#4010](https://github.com/PyTorchLightning/pytorch-lightning/pull/4010)) + + +## [1.0.2] - 2020-10-15 + +### Added + +- Added trace functionality to the function `to_torchscript` ([#4142](https://github.com/PyTorchLightning/pytorch-lightning/pull/4142)) + +### Changed + +- Called `on_load_checkpoint` before loading `state_dict` ([#4057](https://github.com/PyTorchLightning/pytorch-lightning/pull/4057)) + +### Removed + +- Removed duplicate metric vs step log for train loop ([#4173](https://github.com/PyTorchLightning/pytorch-lightning/pull/4173)) + +### Fixed + +- Fixed the `self.log` problem in `validation_step()` ([#4169](https://github.com/PyTorchLightning/pytorch-lightning/pull/4169)) +- Fixed `hparams` saving - save the state when `save_hyperparameters()` is called [in `__init__`] ([#4163](https://github.com/PyTorchLightning/pytorch-lightning/pull/4163)) +- Fixed runtime failure while exporting `hparams` to yaml ([#4158](https://github.com/PyTorchLightning/pytorch-lightning/pull/4158)) + + +## [1.0.1] - 2020-10-14 + +### Added + +- Added getstate/setstate method for torch.save serialization ([#4127](https://github.com/PyTorchLightning/pytorch-lightning/pull/4127)) + + +## [1.0.0] - 2020-10-13 + +### Added + +- Added Explained Variance Metric + metric fix ([#4013](https://github.com/PyTorchLightning/pytorch-lightning/pull/4013)) +- Added Metric <-> Lightning Module integration tests ([#4008](https://github.com/PyTorchLightning/pytorch-lightning/pull/4008)) +- Added parsing OS env vars in `Trainer` ([#4022](https://github.com/PyTorchLightning/pytorch-lightning/pull/4022)) +- Added classification metrics ([#4043](https://github.com/PyTorchLightning/pytorch-lightning/pull/4043)) +- Updated explained variance metric ([#4024](https://github.com/PyTorchLightning/pytorch-lightning/pull/4024)) +- Enabled plugins ([#4041](https://github.com/PyTorchLightning/pytorch-lightning/pull/4041)) +- Enabled custom clusters ([#4048](https://github.com/PyTorchLightning/pytorch-lightning/pull/4048)) +- Enabled passing in custom accelerators ([#4050](https://github.com/PyTorchLightning/pytorch-lightning/pull/4050)) +- Added `LightningModule.toggle_optimizer` ([#4058](https://github.com/PyTorchLightning/pytorch-lightning/pull/4058)) +- Added `LightningModule.manual_backward` ([#4063](https://github.com/PyTorchLightning/pytorch-lightning/pull/4063)) +- Added `output` argument to `*_batch_end` hooks ([#3965](https://github.com/PyTorchLightning/pytorch-lightning/pull/3965), + [#3966](https://github.com/PyTorchLightning/pytorch-lightning/pull/3966)) +- Added `output` argument to `*_epoch_end` hooks ([#3967](https://github.com/PyTorchLightning/pytorch-lightning/pull/3967)) + +### Changed + +- Integrated metrics API with self.log ([#3961](https://github.com/PyTorchLightning/pytorch-lightning/pull/3961)) +- Decoupled Apex ([#4052](https://github.com/PyTorchLightning/pytorch-lightning/pull/4052), + [#4054](https://github.com/PyTorchLightning/pytorch-lightning/pull/4054), + [#4055](https://github.com/PyTorchLightning/pytorch-lightning/pull/4055), + [#4056](https://github.com/PyTorchLightning/pytorch-lightning/pull/4056), + [#4058](https://github.com/PyTorchLightning/pytorch-lightning/pull/4058), + [#4060](https://github.com/PyTorchLightning/pytorch-lightning/pull/4060), + [#4061](https://github.com/PyTorchLightning/pytorch-lightning/pull/4061), + [#4062](https://github.com/PyTorchLightning/pytorch-lightning/pull/4062), + [#4063](https://github.com/PyTorchLightning/pytorch-lightning/pull/4063), + [#4064](https://github.com/PyTorchLightning/pytorch-lightning/pull/4064), + [#4065](https://github.com/PyTorchLightning/pytorch-lightning/pull/4065)) +- Renamed all backends to `Accelerator` ([#4066](https://github.com/PyTorchLightning/pytorch-lightning/pull/4066)) +- Enabled manual returns ([#4089](https://github.com/PyTorchLightning/pytorch-lightning/pull/4089)) + +### Removed + +- Removed support for EvalResult and TrainResult ([#3968](https://github.com/PyTorchLightning/pytorch-lightning/pull/3968)) +- Removed deprecated trainer flags: `overfit_pct`, `log_save_interval`, `row_log_interval` ([#3969](https://github.com/PyTorchLightning/pytorch-lightning/pull/3969)) +- Removed deprecated early_stop_callback ([#3982](https://github.com/PyTorchLightning/pytorch-lightning/pull/3982)) +- Removed deprecated model hooks ([#3980](https://github.com/PyTorchLightning/pytorch-lightning/pull/3980)) +- Removed deprecated callbacks ([#3979](https://github.com/PyTorchLightning/pytorch-lightning/pull/3979)) +- Removed `trainer` argument in `LightningModule.backward` [#4056](https://github.com/PyTorchLightning/pytorch-lightning/pull/4056)) + +### Fixed + +- Fixed `current_epoch` property update to reflect true epoch number inside `LightningDataModule`, when `reload_dataloaders_every_epoch=True`. ([#3974](https://github.com/PyTorchLightning/pytorch-lightning/pull/3974)) +- Fixed to print scaler value in progress bar ([#4053](https://github.com/PyTorchLightning/pytorch-lightning/pull/4053)) +- Fixed mismatch between docstring and code regarding when `on_load_checkpoint` hook is called ([#3996](https://github.com/PyTorchLightning/pytorch-lightning/pull/3996)) + + +## [0.10.0] - 2020-10-07 + +### Added + +- Added new Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868), [#3921](https://github.com/PyTorchLightning/pytorch-lightning/pull/3921)) +- Enable PyTorch 1.7 compatibility ([#3541](https://github.com/PyTorchLightning/pytorch-lightning/pull/3541)) +- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258)) +- Added warning when dropping unpicklable `hparams` ([#2874](https://github.com/PyTorchLightning/pytorch-lightning/pull/2874)) +- Added EMB similarity ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349)) +- Added `ModelCheckpoint.to_yaml` method ([#3048](https://github.com/PyTorchLightning/pytorch-lightning/pull/3048)) +- Allow `ModelCheckpoint` monitor to be `None`, meaning it will always save ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) +- Disabled optimizers setup during testing ([#3059](https://github.com/PyTorchLightning/pytorch-lightning/pull/3059)) +- Added support for datamodules to save and load checkpoints when training ([#3563](https://github.com/PyTorchLightning/pytorch-lightning/pull/3563)) +- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425)) +- Added gradient clip test for native AMP ([#3754](https://github.com/PyTorchLightning/pytorch-lightning/pull/3754)) +- Added dist lib to enable syncing anything across devices ([#3762](https://github.com/PyTorchLightning/pytorch-lightning/pull/3762)) +- Added `broadcast` to `TPUBackend` ([#3814](https://github.com/PyTorchLightning/pytorch-lightning/pull/3814)) +- Added `XLADeviceUtils` class to check XLA device type ([#3274](https://github.com/PyTorchLightning/pytorch-lightning/pull/3274)) + +### Changed + +- Refactored accelerator backends: + * moved TPU `xxx_step` to backend ([#3118](https://github.com/PyTorchLightning/pytorch-lightning/pull/3118)) + * refactored DDP backend `forward` ([#3119](https://github.com/PyTorchLightning/pytorch-lightning/pull/3119)) + * refactored GPU backend `__step` ([#3120](https://github.com/PyTorchLightning/pytorch-lightning/pull/3120)) + * refactored Horovod backend ([#3121](https://github.com/PyTorchLightning/pytorch-lightning/pull/3121), + [#3122](https://github.com/PyTorchLightning/pytorch-lightning/pull/3122)) + * remove obscure forward call in eval + CPU backend `___step` ([#3123](https://github.com/PyTorchLightning/pytorch-lightning/pull/3123)) + * reduced all simplified forward ([#3126](https://github.com/PyTorchLightning/pytorch-lightning/pull/3126)) + * added hook base method ([#3127](https://github.com/PyTorchLightning/pytorch-lightning/pull/3127)) + * refactor eval loop to use hooks - use `test_mode` for if so we can split later ([#3129](https://github.com/PyTorchLightning/pytorch-lightning/pull/3129)) + * moved `___step_end` hooks ([#3130](https://github.com/PyTorchLightning/pytorch-lightning/pull/3130)) + * training forward refactor ([#3134](https://github.com/PyTorchLightning/pytorch-lightning/pull/3134)) + * training AMP scaling refactor ([#3135](https://github.com/PyTorchLightning/pytorch-lightning/pull/3135)) + * eval step scaling factor ([#3136](https://github.com/PyTorchLightning/pytorch-lightning/pull/3136)) + * add eval loop object to streamline eval loop ([#3138](https://github.com/PyTorchLightning/pytorch-lightning/pull/3138)) + * refactored dataloader process hook ([#3139](https://github.com/PyTorchLightning/pytorch-lightning/pull/3139)) + * refactored inner eval loop ([#3141](https://github.com/PyTorchLightning/pytorch-lightning/pull/3141)) + * final inner eval loop hooks ([#3154](https://github.com/PyTorchLightning/pytorch-lightning/pull/3154)) + * clean up hooks in `run_evaluation` ([#3156](https://github.com/PyTorchLightning/pytorch-lightning/pull/3156)) + * clean up data reset ([#3161](https://github.com/PyTorchLightning/pytorch-lightning/pull/3161)) + * expand eval loop out ([#3165](https://github.com/PyTorchLightning/pytorch-lightning/pull/3165)) + * moved hooks around in eval loop ([#3195](https://github.com/PyTorchLightning/pytorch-lightning/pull/3195)) + * remove `_evaluate` fx ([#3197](https://github.com/PyTorchLightning/pytorch-lightning/pull/3197)) + * `Trainer.fit` hook clean up ([#3198](https://github.com/PyTorchLightning/pytorch-lightning/pull/3198)) + * DDPs train hooks ([#3203](https://github.com/PyTorchLightning/pytorch-lightning/pull/3203)) + * refactor DDP backend ([#3204](https://github.com/PyTorchLightning/pytorch-lightning/pull/3204), + [#3207](https://github.com/PyTorchLightning/pytorch-lightning/pull/3207), + [#3208](https://github.com/PyTorchLightning/pytorch-lightning/pull/3208), + [#3209](https://github.com/PyTorchLightning/pytorch-lightning/pull/3209), + [#3210](https://github.com/PyTorchLightning/pytorch-lightning/pull/3210)) + * reduced accelerator selection ([#3211](https://github.com/PyTorchLightning/pytorch-lightning/pull/3211)) + * group prepare data hook ([#3212](https://github.com/PyTorchLightning/pytorch-lightning/pull/3212)) + * added data connector ([#3285](https://github.com/PyTorchLightning/pytorch-lightning/pull/3285)) + * modular is_overridden ([#3290](https://github.com/PyTorchLightning/pytorch-lightning/pull/3290)) + * adding `Trainer.tune()` ([#3293](https://github.com/PyTorchLightning/pytorch-lightning/pull/3293)) + * move `run_pretrain_routine` -> `setup_training` ([#3294](https://github.com/PyTorchLightning/pytorch-lightning/pull/3294)) + * move train outside of setup training ([#3297](https://github.com/PyTorchLightning/pytorch-lightning/pull/3297)) + * move `prepare_data` to data connector ([#3307](https://github.com/PyTorchLightning/pytorch-lightning/pull/3307)) + * moved accelerator router ([#3309](https://github.com/PyTorchLightning/pytorch-lightning/pull/3309)) + * train loop refactor - moving train loop to own object ([#3310](https://github.com/PyTorchLightning/pytorch-lightning/pull/3310), + [#3312](https://github.com/PyTorchLightning/pytorch-lightning/pull/3312), + [#3313](https://github.com/PyTorchLightning/pytorch-lightning/pull/3313), + [#3314](https://github.com/PyTorchLightning/pytorch-lightning/pull/3314)) + * duplicate data interface definition up into DataHooks class ([#3344](https://github.com/PyTorchLightning/pytorch-lightning/pull/3344)) + * inner train loop ([#3359](https://github.com/PyTorchLightning/pytorch-lightning/pull/3359), + [#3361](https://github.com/PyTorchLightning/pytorch-lightning/pull/3361), + [#3362](https://github.com/PyTorchLightning/pytorch-lightning/pull/3362), + [#3363](https://github.com/PyTorchLightning/pytorch-lightning/pull/3363), + [#3365](https://github.com/PyTorchLightning/pytorch-lightning/pull/3365), + [#3366](https://github.com/PyTorchLightning/pytorch-lightning/pull/3366), + [#3367](https://github.com/PyTorchLightning/pytorch-lightning/pull/3367), + [#3368](https://github.com/PyTorchLightning/pytorch-lightning/pull/3368), + [#3369](https://github.com/PyTorchLightning/pytorch-lightning/pull/3369), + [#3370](https://github.com/PyTorchLightning/pytorch-lightning/pull/3370), + [#3371](https://github.com/PyTorchLightning/pytorch-lightning/pull/3371), + [#3372](https://github.com/PyTorchLightning/pytorch-lightning/pull/3372), + [#3373](https://github.com/PyTorchLightning/pytorch-lightning/pull/3373), + [#3374](https://github.com/PyTorchLightning/pytorch-lightning/pull/3374), + [#3375](https://github.com/PyTorchLightning/pytorch-lightning/pull/3375), + [#3376](https://github.com/PyTorchLightning/pytorch-lightning/pull/3376), + [#3385](https://github.com/PyTorchLightning/pytorch-lightning/pull/3385), + [#3388](https://github.com/PyTorchLightning/pytorch-lightning/pull/3388), + [#3397](https://github.com/PyTorchLightning/pytorch-lightning/pull/3397)) + * all logging related calls in a connector ([#3395](https://github.com/PyTorchLightning/pytorch-lightning/pull/3395)) + * device parser ([#3400](https://github.com/PyTorchLightning/pytorch-lightning/pull/3400), + [#3405](https://github.com/PyTorchLightning/pytorch-lightning/pull/3405)) + * added model connector ([#3407](https://github.com/PyTorchLightning/pytorch-lightning/pull/3407)) + * moved eval loop logging to loggers ([#3408](https://github.com/PyTorchLightning/pytorch-lightning/pull/3408)) + * moved eval loop (#3412[#3408](https://github.com/PyTorchLightning/pytorch-lightning/pull/3408)) + * trainer/separate argparse ([#3421](https://github.com/PyTorchLightning/pytorch-lightning/pull/3421), + [#3428](https://github.com/PyTorchLightning/pytorch-lightning/pull/3428), + [#3432](https://github.com/PyTorchLightning/pytorch-lightning/pull/3432)) + * move `lr_finder` ([#3434](https://github.com/PyTorchLightning/pytorch-lightning/pull/3434)) + * organize args (#[#3435](https://github.com/PyTorchLightning/pytorch-lightning/pull/3435), + [#3442](https://github.com/PyTorchLightning/pytorch-lightning/pull/3442), + [#3447](https://github.com/PyTorchLightning/pytorch-lightning/pull/3447), + [#3448](https://github.com/PyTorchLightning/pytorch-lightning/pull/3448), + [#3449](https://github.com/PyTorchLightning/pytorch-lightning/pull/3449), + [#3456](https://github.com/PyTorchLightning/pytorch-lightning/pull/3456)) + * move specific accelerator code ([#3457](https://github.com/PyTorchLightning/pytorch-lightning/pull/3457)) + * group connectors ([#3472](https://github.com/PyTorchLightning/pytorch-lightning/pull/3472)) + * accelerator connector methods x/n ([#3469](https://github.com/PyTorchLightning/pytorch-lightning/pull/3469), + [#3470](https://github.com/PyTorchLightning/pytorch-lightning/pull/3470), + [#3474](https://github.com/PyTorchLightning/pytorch-lightning/pull/3474)) + * merge backends x/n ([#3476](https://github.com/PyTorchLightning/pytorch-lightning/pull/3476), + [#3477](https://github.com/PyTorchLightning/pytorch-lightning/pull/3477), + [#3478](https://github.com/PyTorchLightning/pytorch-lightning/pull/3478), + [#3480](https://github.com/PyTorchLightning/pytorch-lightning/pull/3480), + [#3482](https://github.com/PyTorchLightning/pytorch-lightning/pull/3482)) + * apex plugin ([#3502](https://github.com/PyTorchLightning/pytorch-lightning/pull/3502)) + * precision plugins ([#3504](https://github.com/PyTorchLightning/pytorch-lightning/pull/3504)) + * Result - make monitor default to `checkpoint_on` to simplify ([#3571](https://github.com/PyTorchLightning/pytorch-lightning/pull/3571)) + * reference to the Trainer on the `LightningDataModule` ([#3684](https://github.com/PyTorchLightning/pytorch-lightning/pull/3684)) + * add `.log` to lightning module ([#3686](https://github.com/PyTorchLightning/pytorch-lightning/pull/3686), + [#3699](https://github.com/PyTorchLightning/pytorch-lightning/pull/3699), + [#3701](https://github.com/PyTorchLightning/pytorch-lightning/pull/3701), + [#3704](https://github.com/PyTorchLightning/pytorch-lightning/pull/3704), + [#3715](https://github.com/PyTorchLightning/pytorch-lightning/pull/3715)) + * enable tracking original metric when step and epoch are both true ([#3685](https://github.com/PyTorchLightning/pytorch-lightning/pull/3685)) + * deprecated results obj, added support for simpler comms ([#3681](https://github.com/PyTorchLightning/pytorch-lightning/pull/3681)) + * move backends back to individual files ([#3712](https://github.com/PyTorchLightning/pytorch-lightning/pull/3712)) + * fixes logging for eval steps ([#3763](https://github.com/PyTorchLightning/pytorch-lightning/pull/3763)) + * decoupled DDP, DDP spawn ([#3733](https://github.com/PyTorchLightning/pytorch-lightning/pull/3733), + [#3766](https://github.com/PyTorchLightning/pytorch-lightning/pull/3766), + [#3767](https://github.com/PyTorchLightning/pytorch-lightning/pull/3767), + [#3774](https://github.com/PyTorchLightning/pytorch-lightning/pull/3774), + [#3802](https://github.com/PyTorchLightning/pytorch-lightning/pull/3802), + [#3806](https://github.com/PyTorchLightning/pytorch-lightning/pull/3806)) + * remove weight loading hack for ddp_cpu ([#3808](https://github.com/PyTorchLightning/pytorch-lightning/pull/3808)) + * separate `torchelastic` from DDP ([#3810](https://github.com/PyTorchLightning/pytorch-lightning/pull/3810)) + * separate SLURM from DDP ([#3809](https://github.com/PyTorchLightning/pytorch-lightning/pull/3809)) + * decoupled DDP2 ([#3816](https://github.com/PyTorchLightning/pytorch-lightning/pull/3816)) + * bug fix with logging val epoch end + monitor ([#3812](https://github.com/PyTorchLightning/pytorch-lightning/pull/3812)) + * decoupled DDP, DDP spawn ([#3733](https://github.com/PyTorchLightning/pytorch-lightning/pull/3733), + [#3817](https://github.com/PyTorchLightning/pytorch-lightning/pull/3817), + [#3819](https://github.com/PyTorchLightning/pytorch-lightning/pull/3819), + [#3927](https://github.com/PyTorchLightning/pytorch-lightning/pull/3927)) + * callback system and init DDP ([#3836](https://github.com/PyTorchLightning/pytorch-lightning/pull/3836)) + * adding compute environments ([#3837](https://github.com/PyTorchLightning/pytorch-lightning/pull/3837), [#3842](https://github.com/PyTorchLightning/pytorch-lightning/pull/3842)) + * epoch can now log independently ([#3843](https://github.com/PyTorchLightning/pytorch-lightning/pull/3843)) + * test selecting the correct backend. temp backends while slurm and TorchElastic are decoupled ([#3848](https://github.com/PyTorchLightning/pytorch-lightning/pull/3848)) + * fixed `init_slurm_connection` causing hostname errors ([#3856](https://github.com/PyTorchLightning/pytorch-lightning/pull/3856)) + * moves init apex from LM to apex connector ([#3923](https://github.com/PyTorchLightning/pytorch-lightning/pull/3923)) + * moves sync bn to each backend ([#3925](https://github.com/PyTorchLightning/pytorch-lightning/pull/3925)) + * moves configure ddp to each backend ([#3924](https://github.com/PyTorchLightning/pytorch-lightning/pull/3924)) +- Deprecation warning ([#3844](https://github.com/PyTorchLightning/pytorch-lightning/pull/3844)) +- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) +- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320)) + * Swaped `torch.load` for `fsspec` load in DDP spawn backend ([#3787](https://github.com/PyTorchLightning/pytorch-lightning/pull/3787)) + * Swaped `torch.load` for `fsspec` load in cloud_io loading ([#3692](https://github.com/PyTorchLightning/pytorch-lightning/pull/3692)) + * Added support for `to_disk()` to use remote filepaths with `fsspec` ([#3930](https://github.com/PyTorchLightning/pytorch-lightning/pull/3930)) + * Updated model_checkpoint's to_yaml to use `fsspec` open ([#3801](https://github.com/PyTorchLightning/pytorch-lightning/pull/3801)) + * Fixed `fsspec` is inconsistant when doing `fs.ls` ([#3805](https://github.com/PyTorchLightning/pytorch-lightning/pull/3805)) +- Refactor `GPUStatsMonitor` to improve training speed ([#3257](https://github.com/PyTorchLightning/pytorch-lightning/pull/3257)) +- Changed IoU score behavior for classes absent in target and pred ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098)) +- Changed IoU `remove_bg` bool to `ignore_index` optional int ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098)) +- Changed defaults of `save_top_k` and `save_last` to `None` in ModelCheckpoint ([#3680](https://github.com/PyTorchLightning/pytorch-lightning/pull/3680)) +- `row_log_interval` and `log_save_interval` are now based on training loop's `global_step` instead of epoch-internal batch index ([#3667](https://github.com/PyTorchLightning/pytorch-lightning/pull/3667)) +- Silenced some warnings. verified ddp refactors ([#3483](https://github.com/PyTorchLightning/pytorch-lightning/pull/3483)) +- Cleaning up stale logger tests ([#3490](https://github.com/PyTorchLightning/pytorch-lightning/pull/3490)) +- Allow `ModelCheckpoint` monitor to be `None` ([#3633](https://github.com/PyTorchLightning/pytorch-lightning/pull/3633)) +- Enable `None` model checkpoint default ([#3669](https://github.com/PyTorchLightning/pytorch-lightning/pull/3669)) +- Skipped `best_model_path` if `checkpoint_callback` is `None` ([#2962](https://github.com/PyTorchLightning/pytorch-lightning/pull/2962)) +- Used `raise .. from ..` to explicitly chain exceptions ([#3750](https://github.com/PyTorchLightning/pytorch-lightning/pull/3750)) +- Mocking loggers ([#3596](https://github.com/PyTorchLightning/pytorch-lightning/pull/3596), + [#3617](https://github.com/PyTorchLightning/pytorch-lightning/pull/3617), + [#3851](https://github.com/PyTorchLightning/pytorch-lightning/pull/3851), + [#3859](https://github.com/PyTorchLightning/pytorch-lightning/pull/3859), + [#3884](https://github.com/PyTorchLightning/pytorch-lightning/pull/3884), + [#3853](https://github.com/PyTorchLightning/pytorch-lightning/pull/3853), + [#3910](https://github.com/PyTorchLightning/pytorch-lightning/pull/3910), + [#3889](https://github.com/PyTorchLightning/pytorch-lightning/pull/3889), + [#3926](https://github.com/PyTorchLightning/pytorch-lightning/pull/3926)) +- Write predictions in LightningModule instead of EvalResult [#3882](https://github.com/PyTorchLightning/pytorch-lightning/pull/3882) + +### Deprecated + +- Deprecated `TrainResult` and `EvalResult`, use `self.log` and `self.write` from the `LightningModule` to log metrics and write predictions. `training_step` can now only return a scalar (for the loss) or a dictionary with anything you want. ([#3681](https://github.com/PyTorchLightning/pytorch-lightning/pull/3681)) +- Deprecate `early_stop_callback` Trainer argument ([#3845](https://github.com/PyTorchLightning/pytorch-lightning/pull/3845)) +- Rename Trainer arguments `row_log_interval` >> `log_every_n_steps` and `log_save_interval` >> `flush_logs_every_n_steps` ([#3748](https://github.com/PyTorchLightning/pytorch-lightning/pull/3748)) + +### Removed + +- Removed experimental Metric API ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868), + [#3943](https://github.com/PyTorchLightning/pytorch-lightning/pull/3943), + [#3949](https://github.com/PyTorchLightning/pytorch-lightning/pull/3949), + [#3946](https://github.com/PyTorchLightning/pytorch-lightning/pull/3946)), listed changes before final removal: + * Added `EmbeddingSimilarity` metric ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349), [#3358](https://github.com/PyTorchLightning/pytorch-lightning/pull/3358)) + * Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528)) + * Added error when AUROC metric is used for multiclass problems ([#3350](https://github.com/PyTorchLightning/pytorch-lightning/pull/3350)) + * Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735)) + * Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764)) + * Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517)) + * Fixed Metric aggregation ([#3321](https://github.com/PyTorchLightning/pytorch-lightning/pull/3321)) + * Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188)) + * Renamed `reduction` to `class_reduction` in classification metrics ([#3322](https://github.com/PyTorchLightning/pytorch-lightning/pull/3322)) + * Changed `class_reduction` similar to sklearn for classification metrics ([#3322](https://github.com/PyTorchLightning/pytorch-lightning/pull/3322)) + * Renaming of precision recall metric ([#3308](https://github.com/PyTorchLightning/pytorch-lightning/pull/3308)) + +### Fixed + +- Fixed `on_train_batch_start` hook to end epoch early ([#3700](https://github.com/PyTorchLightning/pytorch-lightning/pull/3700)) +- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917)) +- Fixed ONNX model save on GPU ([#3145](https://github.com/PyTorchLightning/pytorch-lightning/pull/3145)) +- Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008)) +- Fixed auto-scale batch size not dumping `auto_lr_find` parameter ([#3151](https://github.com/PyTorchLightning/pytorch-lightning/pull/3151)) +- Fixed `batch_outputs` with optimizer frequencies ([#3229](https://github.com/PyTorchLightning/pytorch-lightning/pull/3229)) +- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266)) +- Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404)) +- Fixed batch size auto scaling exceeding the size of the dataset ([#3271](https://github.com/PyTorchLightning/pytorch-lightning/pull/3271)) +- Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394)) +- Fixed `overfit_batches` which now correctly disables shuffling for the training loader. ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501)) +- Fixed gradient norm tracking for `row_log_interval > 1` ([#3489](https://github.com/PyTorchLightning/pytorch-lightning/pull/3489)) +- Fixed `ModelCheckpoint` name formatting ([3164](https://github.com/PyTorchLightning/pytorch-lightning/pull/3163)) +- Fixed auto-scale batch size ([#3151](https://github.com/PyTorchLightning/pytorch-lightning/pull/3151)) +- Fixed example implementation of AutoEncoder ([#3190](https://github.com/PyTorchLightning/pytorch-lightning/pull/3190)) +- Fixed invalid paths when remote logging with TensorBoard ([#3236](https://github.com/PyTorchLightning/pytorch-lightning/pull/3236)) +- Fixed change `t()` to `transpose()` as XLA devices do not support `.t()` on 1-dim tensor ([#3252](https://github.com/PyTorchLightning/pytorch-lightning/pull/3252)) +- Fixed (weights only) checkpoints loading without PL ([#3287](https://github.com/PyTorchLightning/pytorch-lightning/pull/3287)) +- Fixed `gather_all_tensors` cross GPUs in DDP ([#3319](https://github.com/PyTorchLightning/pytorch-lightning/pull/3319)) +- Fixed CometML save dir ([#3419](https://github.com/PyTorchLightning/pytorch-lightning/pull/3419)) +- Fixed forward key metrics ([#3467](https://github.com/PyTorchLightning/pytorch-lightning/pull/3467)) +- Fixed normalize mode at confusion matrix (replace NaNs with zeros) ([#3465](https://github.com/PyTorchLightning/pytorch-lightning/pull/3465)) +- Fixed global step increment in training loop when `training_epoch_end` hook is used ([#3673](https://github.com/PyTorchLightning/pytorch-lightning/pull/3673)) +- Fixed dataloader shuffling not getting turned off with `overfit_batches > 0` and `distributed_backend = "ddp"` ([#3534](https://github.com/PyTorchLightning/pytorch-lightning/pull/3534)) +- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335)) +- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) +- Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751)) +- Fixed Tuner dump: add `current_epoch` to dumped_params ([#3261](https://github.com/PyTorchLightning/pytorch-lightning/pull/3261)) +- Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785)) +- Fixed learning rate scheduler for optimizers with internal state ([#3897](https://github.com/PyTorchLightning/pytorch-lightning/pull/3897)) +- Fixed `tbptt_reduce_fx` when non-floating tensors are logged ([#3796](https://github.com/PyTorchLightning/pytorch-lightning/pull/3796)) +- Fixed model checkpoint frequency ([#3852](https://github.com/PyTorchLightning/pytorch-lightning/pull/3852)) +- Fixed logging non-tensor scalar with result breaks subsequent epoch aggregation ([#3855](https://github.com/PyTorchLightning/pytorch-lightning/pull/3855)) +- Fixed `TrainerEvaluationLoopMixin` activates `model.train()` at the end ([#3858](https://github.com/PyTorchLightning/pytorch-lightning/pull/3858)) +- Fixed `overfit_batches` when using with multiple val/test_dataloaders ([#3857](https://github.com/PyTorchLightning/pytorch-lightning/pull/3857)) +- Fixed enables `training_step` to return `None` ([#3862](https://github.com/PyTorchLightning/pytorch-lightning/pull/3862)) +- Fixed init nan for checkpointing ([#3863](https://github.com/PyTorchLightning/pytorch-lightning/pull/3863)) +- Fixed for `load_from_checkpoint` ([#2776](https://github.com/PyTorchLightning/pytorch-lightning/pull/2776)) +- Fixes incorrect `batch_sizes` when Dataloader returns a dict with multiple tensors ([#3668](https://github.com/PyTorchLightning/pytorch-lightning/pull/3668)) +- Fixed unexpected signature for `validation_step` ([#3947](https://github.com/PyTorchLightning/pytorch-lightning/pull/3947)) + +## [0.9.0] - 2020-08-20 + +### Added + +- Added SyncBN for DDP ([#2801](https://github.com/PyTorchLightning/pytorch-lightning/pull/2801), + [#2838](https://github.com/PyTorchLightning/pytorch-lightning/pull/2838)) +- Added basic `CSVLogger` ([#2721](https://github.com/PyTorchLightning/pytorch-lightning/pull/2721)) +- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671)) +- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535)) +- Added support to export a model to ONNX format ([#2596](https://github.com/PyTorchLightning/pytorch-lightning/pull/2596)) +- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246)) +- Added struct. output: + * tests for val loop flow ([#2605](https://github.com/PyTorchLightning/pytorch-lightning/pull/2605)) + * `EvalResult` support for train and val. loop ([#2615](https://github.com/PyTorchLightning/pytorch-lightning/pull/2615), + [#2651](https://github.com/PyTorchLightning/pytorch-lightning/pull/2651)) + * weighted average in results obj ([#2930](https://github.com/PyTorchLightning/pytorch-lightning/pull/2930)) + * fix result obj DP auto reduce ([#3013](https://github.com/PyTorchLightning/pytorch-lightning/pull/3013)) +- Added class `LightningDataModule` ([#2668](https://github.com/PyTorchLightning/pytorch-lightning/pull/2668)) +- Added support for PyTorch 1.6 ([#2745](https://github.com/PyTorchLightning/pytorch-lightning/pull/2745)) +- Added call DataModule hooks implicitly in trainer ([#2755](https://github.com/PyTorchLightning/pytorch-lightning/pull/2755)) +- Added support for Mean in DDP Sync ([#2568](https://github.com/PyTorchLightning/pytorch-lightning/pull/2568)) +- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562)) +- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2840](https://github.com/PyTorchLightning/pytorch-lightning/pull/2840)) +- Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935)) +- Added support to Tensorboard logger for OmegaConf `hparams` ([#2846](https://github.com/PyTorchLightning/pytorch-lightning/pull/2846)) +- Added tracking of basic states in `Trainer` ([#2541](https://github.com/PyTorchLightning/pytorch-lightning/pull/2541)) +- Tracks all outputs including TBPTT and multiple optimizers ([#2890](https://github.com/PyTorchLightning/pytorch-lightning/pull/2890)) +- Added GPU Usage Logger ([#2932](https://github.com/PyTorchLightning/pytorch-lightning/pull/2932)) +- Added `strict=False` for `load_from_checkpoint` ([#2819](https://github.com/PyTorchLightning/pytorch-lightning/pull/2819)) +- Added saving test predictions on multiple GPUs ([#2926](https://github.com/PyTorchLightning/pytorch-lightning/pull/2926)) +- Auto log the computational graph for loggers that support this ([#3003](https://github.com/PyTorchLightning/pytorch-lightning/pull/3003)) +- Added warning when changing monitor and using results obj ([#3014](https://github.com/PyTorchLightning/pytorch-lightning/pull/3014)) +- Added a hook `transfer_batch_to_device` to the `LightningDataModule` ([#3038](https://github.com/PyTorchLightning/pytorch-lightning/pull/3038)) + +### Changed + +- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) +- Enabling val/test loop disabling ([#2692](https://github.com/PyTorchLightning/pytorch-lightning/pull/2692)) +- Refactored into `accelerator` module: + * GPU training ([#2704](https://github.com/PyTorchLightning/pytorch-lightning/pull/2704)) + * TPU training ([#2708](https://github.com/PyTorchLightning/pytorch-lightning/pull/2708)) + * DDP(2) backend ([#2796](https://github.com/PyTorchLightning/pytorch-lightning/pull/2796)) + * Retrieve last logged val from result by key ([#3049](https://github.com/PyTorchLightning/pytorch-lightning/pull/3049)) +- Using `.comet.config` file for `CometLogger` ([#1913](https://github.com/PyTorchLightning/pytorch-lightning/pull/1913)) +- Updated hooks arguments - breaking for `setup` and `teardown` ([#2850](https://github.com/PyTorchLightning/pytorch-lightning/pull/2850)) +- Using `gfile` to support remote directories ([#2164](https://github.com/PyTorchLightning/pytorch-lightning/pull/2164)) +- Moved optimizer creation after device placement for DDP backends ([#2904](https://github.com/PyTorchLightning/pytorch-lighting/pull/2904)) +- Support `**DictConfig` for `hparam` serialization ([#2519](https://github.com/PyTorchLightning/pytorch-lightning/pull/2519)) +- Removed callback metrics from test results obj ([#2994](https://github.com/PyTorchLightning/pytorch-lightning/pull/2994)) +- Re-enabled naming metrics in ckpt name ([#3060](https://github.com/PyTorchLightning/pytorch-lightning/pull/3060)) +- Changed progress bar epoch counting to start from 0 ([#3061](https://github.com/PyTorchLightning/pytorch-lightning/pull/3061)) + +### Deprecated + +- Deprecated Trainer attribute `ckpt_path`, which will now be set by `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) + +### Removed + +- Removed deprecated: ([#2760](https://github.com/PyTorchLightning/pytorch-lightning/pull/2760)) + * core decorator `data_loader` + * Module hook `on_sanity_check_start` and loading `load_from_metrics` + * package `pytorch_lightning.logging` + * Trainer arguments: `show_progress_bar`, `num_tpu_cores`, `use_amp`, `print_nan_grads` + * LR Finder argument `num_accumulation_steps` + +### Fixed + +- Fixed `accumulate_grad_batches` for last batch ([#2853](https://github.com/PyTorchLightning/pytorch-lightning/pull/2853)) +- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624)) +- Fixed local rank zero casting ([#2640](https://github.com/PyTorchLightning/pytorch-lightning/pull/2640)) +- Fixed single scalar return from training ([#2587](https://github.com/PyTorchLightning/pytorch-lightning/pull/2587)) +- Fixed Horovod backend to scale LR schedlers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626)) +- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657)) +- Fixed `fast_dev_run` to run for all dataloaders ([#2581](https://github.com/PyTorchLightning/pytorch-lightning/pull/2581)) +- Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) +- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) +- Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632)) +- Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723)) +- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689)) +- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789)) +- Fixed logging interval ([#2694](https://github.com/PyTorchLightning/pytorch-lightning/pull/2694)) +- Fixed loss value in the progress bar is wrong when `accumulate_grad_batches > 1` ([#2738](https://github.com/PyTorchLightning/pytorch-lightning/pull/2738)) +- Fixed correct CWD for ddp sub-processes when using Hydra ([#2719](https://github.com/PyTorchLightning/pytorch-lightning/pull/2719)) +- Fixed selecting GPUs using `CUDA_VISIBLE_DEVICES` ([#2739](https://github.com/PyTorchLightning/pytorch-lightning/pull/2739), + [#2796](https://github.com/PyTorchLightning/pytorch-lightning/pull/2796)) +- Fixed false `num_classes` warning in metrics ([#2781](https://github.com/PyTorchLightning/pytorch-lightning/pull/2781)) +- Fixed shell injection vulnerability in subprocess call ([#2786](https://github.com/PyTorchLightning/pytorch-lightning/pull/2786)) +- Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821)) +- Fixed `ModelCheckpoint` not saving the latest information when `save_last=True` ([#2881](https://github.com/PyTorchLightning/pytorch-lightning/pull/2881)) +- Fixed ImageNet example: learning rate scheduler, number of workers and batch size when using DDP ([#2889](https://github.com/PyTorchLightning/pytorch-lightning/pull/2889)) +- Fixed apex gradient clipping ([#2829](https://github.com/PyTorchLightning/pytorch-lightning/pull/2829)) +- Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828)) +- Fixed a model loading issue with inheritance and variable positional arguments ([#2911](https://github.com/PyTorchLightning/pytorch-lightning/pull/2911)) +- Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910)) +- Fixed checkpointing to remote file paths ([#2925](https://github.com/PyTorchLightning/pytorch-lightning/pull/2925)) +- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986)) +- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997)) +- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020)) +- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043)) +- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045)) +- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042)) + +## [0.8.5] - 2020-07-09 + +### Added + +- Added a PSNR metric: peak signal-to-noise ratio ([#2483](https://github.com/PyTorchLightning/pytorch-lightning/pull/2483)) +- Added functional regression metrics ([#2492](https://github.com/PyTorchLightning/pytorch-lightning/pull/2492)) + +### Removed + +- Removed auto val reduce ([#2462](https://github.com/PyTorchLightning/pytorch-lightning/pull/2462)) + +### Fixed + +- Flattening Wandb Hyperparameters ([#2459](https://github.com/PyTorchLightning/pytorch-lightning/pull/2459)) +- Fixed using the same DDP python interpreter and actually running ([#2482](https://github.com/PyTorchLightning/pytorch-lightning/pull/2482)) +- Fixed model summary input type conversion for models that have input dtype different from model parameters ([#2510](https://github.com/PyTorchLightning/pytorch-lightning/pull/2510)) +- Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518)) +- Fixed a problem with `MLflowLogger` creating multiple run folders ([#2502](https://github.com/PyTorchLightning/pytorch-lightning/pull/2502)) +- Fixed global_step increment ([#2455](https://github.com/PyTorchLightning/pytorch-lightning/pull/2455)) +- Fixed TPU hanging example ([#2488](https://github.com/PyTorchLightning/pytorch-lightning/pull/2488)) +- Fixed `argparse` default value bug ([#2526](https://github.com/PyTorchLightning/pytorch-lightning/pull/2526)) +- Fixed Dice and IoU to avoid NaN by adding small eps ([#2545](https://github.com/PyTorchLightning/pytorch-lightning/pull/2545)) +- Fixed accumulate gradients schedule at epoch 0 (continued) ([#2513](https://github.com/PyTorchLightning/pytorch-lightning/pull/2513)) +- Fixed Trainer `.fit()` returning last not best weights in "ddp_spawn" ([#2565](https://github.com/PyTorchLightning/pytorch-lightning/pull/2565)) +- Fixed passing (do not pass) TPU weights back on test ([#2566](https://github.com/PyTorchLightning/pytorch-lightning/pull/2566)) +- Fixed DDP tests and `.test()` ([#2512](https://github.com/PyTorchLightning/pytorch-lightning/pull/2512), + [#2570](https://github.com/PyTorchLightning/pytorch-lightning/pull/2570)) + +## [0.8.4] - 2020-07-01 + +### Added + +- Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434)) +- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437)) + +### Changed + +- Enabled no returns from eval ([#2446](https://github.com/PyTorchLightning/pytorch-lightning/pull/2446)) + +### Fixed + +- Fixes train outputs ([#2428](https://github.com/PyTorchLightning/pytorch-lightning/pull/2428)) +- Fixes Conda dependencies ([#2412](https://github.com/PyTorchLightning/pytorch-lightning/pull/2412)) +- Fixed Apex scaling with decoupled backward ([#2433](https://github.com/PyTorchLightning/pytorch-lightning/pull/2433)) +- Fixed crashing or wrong displaying progressbar because of missing ipywidgets ([#2417](https://github.com/PyTorchLightning/pytorch-lightning/pull/2417)) +- Fixed TPU saving dir ([fc26078e](https://github.com/PyTorchLightning/pytorch-lightning/commit/fc26078e395f8a001f4c6dd7b3fe7ca202f914a3), [04e68f02](https://github.com/PyTorchLightning/pytorch-lightning/commit/04e68f022fc03dd5f1555ee86dea997d42a448ad)) +- Fixed logging on rank 0 only ([#2425](https://github.com/PyTorchLightning/pytorch-lightning/pull/2425)) + + +## [0.8.3] - 2020-06-29 + +### Fixed + +- Fixed AMP wrong call ([593837e](https://github.com/PyTorchLightning/pytorch-lightning/commit/593837e1da24ff6c942b24ed803fc1496a304609)) +- Fixed batch typo ([92d1e75](https://github.com/PyTorchLightning/pytorch-lightning/commit/92d1e75b2638a493d9d21ed5fe00a22093888285)) + +## [0.8.2] - 2020-06-28 + +### Added + +- Added TorchText support for moving data to GPU ([#2379](https://github.com/PyTorchLightning/pytorch-lightning/pull/2379)) + +### Changed + +- Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289)) +- Refactor Model `backward` ([#2276](https://github.com/PyTorchLightning/pytorch-lightning/pull/2276)) +- Refactored `training_batch` + tests to verify correctness ([#2327](https://github.com/PyTorchLightning/pytorch-lightning/pull/2327), + [#2328](https://github.com/PyTorchLightning/pytorch-lightning/pull/2328)) +- Refactored training loop ([#2336](https://github.com/PyTorchLightning/pytorch-lightning/pull/2336)) +- Made optimization steps for hooks ([#2363](https://github.com/PyTorchLightning/pytorch-lightning/pull/2363)) +- Changed default apex level to 'O2' ([#2362](https://github.com/PyTorchLightning/pytorch-lightning/pull/2362)) + +### Removed + +- Moved `TrainsLogger` to Bolts ([#2384](https://github.com/PyTorchLightning/pytorch-lightning/pull/2384)) + +### Fixed + +- Fixed parsing TPU arguments and TPU tests ([#2094](https://github.com/PyTorchLightning/pytorch-lightning/pull/2094)) +- Fixed number batches in case of multiple dataloaders and `limit_{*}_batches` ([#1920](https://github.com/PyTorchLightning/pytorch-lightning/pull/1920), + [#2226](https://github.com/PyTorchLightning/pytorch-lightning/pull/2226)) +- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298)) +- Fix for `load_from_checkpoint()` not working with absolute path on Windows ([#2294](https://github.com/PyTorchLightning/pytorch-lightning/pull/2294)) +- Fixed an issue how _has_len handles `NotImplementedError` e.g. raised by `torchtext.data.Iterator` ([#2293](https://github.com/PyTorchLightning/pytorch-lightning/pull/2293)), ([#2307](https://github.com/PyTorchLightning/pytorch-lightning/pull/2307)) +- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) +- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304)) +- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) +- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335)) +- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387)) +- Fixed sum(0) for `trainer.num_val_batches` ([#2268](https://github.com/PyTorchLightning/pytorch-lightning/pull/2268)) +- Fixed checking if the parameters are a `DictConfig` Object ([#2216](https://github.com/PyTorchLightning/pytorch-lightning/pull/2216)) +- Fixed SLURM weights saving ([#2341](https://github.com/PyTorchLightning/pytorch-lightning/pull/2341)) +- Fixed swaps LR scheduler order ([#2356](https://github.com/PyTorchLightning/pytorch-lightning/pull/2356)) +- Fixed adding tensorboard `hparams` logging test ([#2342](https://github.com/PyTorchLightning/pytorch-lightning/pull/2342)) +- Fixed use model ref for tear down ([#2360](https://github.com/PyTorchLightning/pytorch-lightning/pull/2360)) +- Fixed logger crash on DDP ([#2388](https://github.com/PyTorchLightning/pytorch-lightning/pull/2388)) +- Fixed several issues with early stopping and checkpoint callbacks ([#1504](https://github.com/PyTorchLightning/pytorch-lightning/pull/1504), + [#2391](https://github.com/PyTorchLightning/pytorch-lightning/pull/2391)) +- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405)) +- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403)) +- Fixed Windows compatibility issue ([#2358](https://github.com/PyTorchLightning/pytorch-lightning/pull/2358)) + +## [0.8.1] - 2020-06-19 + +### Fixed + +- Fixed the `load_from_checkpoint` path detected as URL bug ([#2244](https://github.com/PyTorchLightning/pytorch-lightning/pull/2244)) +- Fixed hooks - added barrier ([#2245](https://github.com/PyTorchLightning/pytorch-lightning/pull/2245), + [#2257](https://github.com/PyTorchLightning/pytorch-lightning/pull/2257), + [#2260](https://github.com/PyTorchLightning/pytorch-lightning/pull/220)) +- Fixed `hparams` - remove frame inspection on `self.hparams` ([#2253](https://github.com/PyTorchLightning/pytorch-lightning/pull/2253)) +- Fixed setup and on fit calls ([#2252](https://github.com/PyTorchLightning/pytorch-lightning/pull/2252)) +- Fixed GPU template ([#2255](https://github.com/PyTorchLightning/pytorch-lightning/pull/2255)) + +## [0.8.0] - 2020-06-18 + +### Added + +- Added `overfit_batches`, `limit_{val|test}_batches` flags (overfit now uses training set for all three) ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213)) +- Added metrics + * Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), + [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) + * Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327)) + * Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488), + [#2062](https://github.com/PyTorchLightning/pytorch-lightning/pull/2062)) + * docs for all Metrics ([#2184](https://github.com/PyTorchLightning/pytorch-lightning/pull/2184), + [#2209](https://github.com/PyTorchLightning/pytorch-lightning/pull/2209)) + * Regression metrics ([#2221](https://github.com/PyTorchLightning/pytorch-lightning/pull/2221)) +- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)) +- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907)) +- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` ([#1908](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)) +- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458)) +- Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) +- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033)) +- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)) +- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610)) +- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115)) +- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/pull/1667)) +- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134)) +- Added a decorator `auto_move_data` that moves data to the correct device when using the LightningModule for inference ([#1905](https://github.com/PyTorchLightning/pytorch-lightning/pull/1905)) +- Added `ckpt_path` option to `LightningModule.test(...)` to load particular checkpoint ([#2190](https://github.com/PyTorchLightning/pytorch-lightning/pull/2190)) +- Added `setup` and `teardown` hooks for model ([#2229](https://github.com/PyTorchLightning/pytorch-lightning/pull/2229)) + +### Changed + +- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) +- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862)) +- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896)) +- Renamed `ModelCheckpoint`'s attributes `best` to `best_model_score` and `kth_best_model` to `kth_best_model_path` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) +- Re-Enable Logger's `ImportError`s ([#1938](https://github.com/PyTorchLightning/pytorch-lightning/pull/1938)) +- Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029)) +- Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020)) +- Enabled `prepare_data` from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166)) +- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126)) +- Changed epoch indexing from 1 instead of 0 ([#2206](https://github.com/PyTorchLightning/pytorch-lightning/pull/2206)) + +### Deprecated + +- Deprecated flags: ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213)) + * `overfit_pct` in favour of `overfit_batches` + * `val_percent_check` in favour of `limit_val_batches` + * `test_percent_check` in favour of `limit_test_batches` +- Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) +- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917)) +- Deprecated Trainer `proc_rank` in favour of `global_rank` ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166), + [#2269](https://github.com/PyTorchLightning/pytorch-lightning/pull/2269)) + +### Removed + +- Removed unintended Trainer argument `progress_bar_callback`, the callback should be passed in by `Trainer(callbacks=[...])` instead ([#1855](https://github.com/PyTorchLightning/pytorch-lightning/pull/1855)) +- Removed obsolete `self._device` in Trainer ([#1849](https://github.com/PyTorchLightning/pytorch-lightning/pull/1849)) +- Removed deprecated API ([#2073](https://github.com/PyTorchLightning/pytorch-lightning/pull/2073)) + * Packages: `pytorch_lightning.pt_overrides`, `pytorch_lightning.root_module` + * Modules: `pytorch_lightning.logging.comet_logger`, `pytorch_lightning.logging.mlflow_logger`, `pytorch_lightning.logging.test_tube_logger`, `pytorch_lightning.overrides.override_data_parallel`, `pytorch_lightning.core.model_saving`, `pytorch_lightning.core.root_module` + * Trainer arguments: `add_row_log_interval`, `default_save_path`, `gradient_clip`, `nb_gpu_nodes`, `max_nb_epochs`, `min_nb_epochs`, `nb_sanity_val_steps` + * Trainer attributes: `nb_gpu_nodes`, `num_gpu_nodes`, `gradient_clip`, `max_nb_epochs`, `min_nb_epochs`, `nb_sanity_val_steps`, `default_save_path`, `tng_tqdm_dic` + +### Fixed + +- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631)) +- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873)) +- Fixed multiple calls of `EarlyStopping` callback ([#1863](https://github.com/PyTorchLightning/pytorch-lightning/pull/1863)) +- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932)) +- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933)) +- Fixed root node resolution for SLURM cluster with dash in host name ([#1954](https://github.com/PyTorchLightning/pytorch-lightning/pull/1954)) +- Fixed `LearningRateLogger` in multi-scheduler setting ([#1944](https://github.com/PyTorchLightning/pytorch-lightning/pull/1944)) +- Fixed test configuration check and testing ([#1804](https://github.com/PyTorchLightning/pytorch-lightning/pull/1804)) +- Fixed an issue with Trainer constructor silently ignoring unknown/misspelled arguments ([#1820](https://github.com/PyTorchLightning/pytorch-lightning/pull/1820)) +- Fixed `save_weights_only` in ModelCheckpoint ([#1780](https://github.com/PyTorchLightning/pytorch-lightning/pull/1780)) +- Allow use of same `WandbLogger` instance for multiple training loops ([#2055](https://github.com/PyTorchLightning/pytorch-lightning/pull/2055)) +- Fixed an issue with `_auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048)) +- Fixed mistake in parameters' grad norm tracking ([#2012](https://github.com/PyTorchLightning/pytorch-lightning/pull/2012)) +- Fixed CPU and hanging GPU crash ([#2118](https://github.com/PyTorchLightning/pytorch-lightning/pull/2118)) +- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773)) +- Fixed Tpu logging ([#2230](https://github.com/PyTorchLightning/pytorch-lightning/pull/2230)) +- Fixed Pid port + duplicate `rank_zero` logging ([#2140](https://github.com/PyTorchLightning/pytorch-lightning/pull/2140), + [#2231](https://github.com/PyTorchLightning/pytorch-lightning/pull/2231)) + +## [0.7.6] - 2020-05-16 + +### Added + +- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498)) +- Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564)) +- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). +- Added auto scaling of batch size ([#1638](https://github.com/PyTorchLightning/pytorch-lightning/pull/1638)) +- The progress bar metrics now also get updated in `training_epoch_end` ([#1724](https://github.com/PyTorchLightning/pytorch-lightning/pull/1724)) +- Enable `NeptuneLogger` to work with `distributed_backend=ddp` ([#1753](https://github.com/PyTorchLightning/pytorch-lightning/pull/1753)) +- Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572)) +- Added override for hparams in `load_from_ckpt` ([#1797](https://github.com/PyTorchLightning/pytorch-lightning/pull/1797)) +- Added support multi-node distributed execution under `torchelastic` ([#1811](https://github.com/PyTorchLightning/pytorch-lightning/pull/1811), + [#1818](https://github.com/PyTorchLightning/pytorch-lightning/pull/1818)) +- Added using `store_true` for bool args ([#1822](https://github.com/PyTorchLightning/pytorch-lightning/pull/1822), + [#1842](https://github.com/PyTorchLightning/pytorch-lightning/pull/1842)) +- Added dummy logger for internally disabling logging for some features ([#1836](https://github.com/PyTorchLightning/pytorch-lightning/pull/1836)) + +### Changed + +- Enable `non-blocking` for device transfers to GPU ([#1843](https://github.com/PyTorchLightning/pytorch-lightning/pull/1843)) +- Replace mata_tags.csv with hparams.yaml ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271)) +- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609)) +- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577)) +- Don't convert `namedtuple` to `tuple` when transferring the batch to target device ([#1589](https://github.com/PyTorchLightning/pytorch-lightning/pull/1589)) +- Allow passing hparams as keyword argument to LightningModule when loading from checkpoint ([#1639](https://github.com/PyTorchLightning/pytorch-lightning/pull/1639)) +- Args should come after the last positional argument ([#1807](https://github.com/PyTorchLightning/pytorch-lightning/pull/1807)) +- Made ddp the default if no backend specified with multiple GPUs ([#1789](https://github.com/PyTorchLightning/pytorch-lightning/pull/1789)) + +### Deprecated + +- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271)) + +### Fixed + +- Fixed broken link in PR template ([#1675](https://github.com/PyTorchLightning/pytorch-lightning/pull/1675)) +- Fixed ModelCheckpoint not None checking filepath ([#1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654)) +- Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([#1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666)) +- Fixed sampler logic for ddp with iterable dataset ([#1734](https://github.com/PyTorchLightning/pytorch-lightning/pull/1734)) +- Fixed `_reset_eval_dataloader()` for IterableDataset ([#1560](https://github.com/PyTorchLightning/pytorch-lightning/pull/1560)) +- Fixed Horovod distributed backend to set the `root_gpu` property ([#1669](https://github.com/PyTorchLightning/pytorch-lightning/pull/1669)) +- Fixed wandb logger `global_step` affects other loggers ([#1492](https://github.com/PyTorchLightning/pytorch-lightning/pull/1492)) +- Fixed disabling progress bar on non-zero ranks using Horovod backend ([#1709](https://github.com/PyTorchLightning/pytorch-lightning/pull/1709)) +- Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676)) - Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748)) +- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719)) +- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561)) +- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801)) +- Fixed num processes wasn't being set properly and auto sampler was ddp failing ([#1819](https://github.com/PyTorchLightning/pytorch-lightning/pull/1819)) +- Fixed bugs in semantic segmentation example ([#1824](https://github.com/PyTorchLightning/pytorch-lightning/pull/1824)) +- Fixed saving native AMP scaler state ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), + [#1777](https://github.com/PyTorchLightning/pytorch-lightning/pull/1777)) +- Fixed native amp + ddp ([#1788](https://github.com/PyTorchLightning/pytorch-lightning/pull/1788)) +- Fixed `hparam` logging with metrics ([#1647](https://github.com/PyTorchLightning/pytorch-lightning/pull/1647)) ## [0.7.5] - 2020-04-27 @@ -74,7 +1543,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158)) - Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529)) - Added support for 8 core distributed training on Kaggle TPU's ([#1568](https://github.com/PyTorchLightning/pytorch-lightning/pull/1568)) -- Added support for native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), [#1580](https://github.com/PyTorchLightning/pytorch-lightning/pull/1580)) +- Added support for native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), + [#1580](https://github.com/PyTorchLightning/pytorch-lightning/pull/1580)) ### Changed @@ -105,7 +1575,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Allow use of sweeps with `WandbLogger` ([#1512](https://github.com/PyTorchLightning/pytorch-lightning/pull/1512)) - Fixed a bug that caused the `callbacks` Trainer argument to reference a global variable ([#1534](https://github.com/PyTorchLightning/pytorch-lightning/pull/1534)). - Fixed a bug that set all boolean CLI arguments from `Trainer.add_argparse_args` always to True ([#1571](https://github.com/PyTorchLightning/pytorch-lightning/pull/1571)) -- Fixed do not copy the batch when training on a single GPU ([#1576](https://github.com/PyTorchLightning/pytorch-lightning/pull/1576), [#1579](https://github.com/PyTorchLightning/pytorch-lightning/pull/1579)) +- Fixed do not copy the batch when training on a single GPU ([#1576](https://github.com/PyTorchLightning/pytorch-lightning/pull/1576), + [#1579](https://github.com/PyTorchLightning/pytorch-lightning/pull/1579)) - Fixed soft checkpoint removing on DDP ([#1408](https://github.com/PyTorchLightning/pytorch-lightning/pull/1408)) - Fixed automatic parser bug ([#1585](https://github.com/PyTorchLightning/pytorch-lightning/pull/1585)) - Fixed bool conversion from string ([#1606](https://github.com/PyTorchLightning/pytorch-lightning/pull/1606)) @@ -195,11 +1666,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)). - Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)) - Fixed `WandbLogger.watch` - use of the watch method without importing `wandb` ([#1311](https://github.com/PyTorchLightning/pytorch-lightning/pull/1311)) -- Fixed `WandbLogger` to be used with 'ddp' - allow reinits in sub-processes ([#1149](https://github.com/PyTorchLightning/pytorch-lightning/pull/1149), [#1360](https://github.com/PyTorchLightning/pytorch-lightning/pull/1360)) +- Fixed `WandbLogger` to be used with 'ddp' - allow reinits in sub-processes ([#1149](https://github.com/PyTorchLightning/pytorch-lightning/pull/1149), + [#1360](https://github.com/PyTorchLightning/pytorch-lightning/pull/1360)) - Made `training_epoch_end` behave like `validation_epoch_end` ([#1357](https://github.com/PyTorchLightning/pytorch-lightning/pull/1357)) - Fixed `fast_dev_run` running validation twice ([#1365](https://github.com/PyTorchLightning/pytorch-lightning/pull/1365)) - Fixed pickle error from quick patch `__code__` ([#1352](https://github.com/PyTorchLightning/pytorch-lightning/pull/1352)) -- Fixed memory leak on GPU0 ([#1094](https://github.com/PyTorchLightning/pytorch-lightning/pull/1094), [#1349](https://github.com/PyTorchLightning/pytorch-lightning/pull/1349)) +- Fixed memory leak on GPU0 ([#1094](https://github.com/PyTorchLightning/pytorch-lightning/pull/1094), + [#1349](https://github.com/PyTorchLightning/pytorch-lightning/pull/1349)) - Fixed checkpointing interval ([#1272](https://github.com/PyTorchLightning/pytorch-lightning/pull/1272)) - Fixed validation and training loops run the partial dataset ([#1192](https://github.com/PyTorchLightning/pytorch-lightning/pull/1192)) - Fixed running `on_validation_end` only on main process in DDP ([#1125](https://github.com/PyTorchLightning/pytorch-lightning/pull/1125)) @@ -231,16 +1704,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added option to specify `step` key when logging metrics ([#808](https://github.com/PyTorchLightning/pytorch-lightning/pull/808)) - Added `train_dataloader`, `val_dataloader` and `test_dataloader` arguments to `Trainer.fit()`, for alternative data parsing ([#759](https://github.com/PyTorchLightning/pytorch-lightning/pull/759)) - Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868)) -- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876), [#881](https://github.com/PyTorchLightning/pytorch-lightning/pull/881)) +- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876), + [#881](https://github.com/PyTorchLightning/pytorch-lightning/pull/881)) - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849)) - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950)) - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903)) - Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941)) - Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029)) - Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041)) -- Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019)) +- Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), + [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019)) - Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912), ) -- Added default `argparser` for `Trainer` ([#952](https://github.com/PyTorchLightning/pytorch-lightning/pull/1023), [#1023](https://github.com/PyTorchLightning/pytorch-lightning/pull/1023)) +- Added default `argparser` for `Trainer` ([#952](https://github.com/PyTorchLightning/pytorch-lightning/pull/1023), + [#1023](https://github.com/PyTorchLightning/pytorch-lightning/pull/1023)) - Added TPU gradient clipping ([#963](https://github.com/PyTorchLightning/pytorch-lightning/pull/963)) - Added max/min number of steps in `Trainer` ([#728](https://github.com/PyTorchLightning/pytorch-lightning/pull/728)) @@ -264,9 +1740,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated - Deprecated `pytorch_lightning.logging` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767)) -- Deprecated `LightningModule.load_from_metrics` in favour of `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995), [#1079](https://github.com/PyTorchLightning/pytorch-lightning/pull/1079)) +- Deprecated `LightningModule.load_from_metrics` in favour of `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995), + [#1079](https://github.com/PyTorchLightning/pytorch-lightning/pull/1079)) - Deprecated `@data_loader` decorator ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) -- Deprecated model steps `training_end`, `validation_end` and `test_end` ([#1051](https://github.com/PyTorchLightning/pytorch-lightning/pull/1051), [#1056](https://github.com/PyTorchLightning/pytorch-lightning/pull/1056)) +- Deprecated model steps `training_end`, `validation_end` and `test_end` ([#1051](https://github.com/PyTorchLightning/pytorch-lightning/pull/1051), + [#1056](https://github.com/PyTorchLightning/pytorch-lightning/pull/1056)) ### Removed @@ -582,16 +2060,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `Experiment` object was not process safe, potentially causing logs to be overwritten -## [0.3.5] - 2019-MM-DD +## [0.3.5] - 2019-07-25 -## [0.3.4] - 2019-MM-DD +## [0.3.4] - 2019-07-22 -## [0.3.3] - 2019-MM-DD +## [0.3.3] - 2019-07-22 -## [0.3.2] - 2019-MM-DD +## [0.3.2] - 2019-07-21 -## [0.3.1] - 2019-MM-DD +## [0.3.1] - 2019-07-21 -## [0.2.x] - YYYY-MM-DD +## [0.2.x] - 2019-07-09 -## [0.1.x] - YYYY-MM-DD +## [0.1.x] - 2019-06-DD diff --git a/LICENSE b/LICENSE index b9181e1a6e5d83..2e66bec2e791c2 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2018-2020 William Falcon + Copyright 2018-2021 William Falcon Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/MANIFEST.in b/MANIFEST.in index bae6726c70bf01..b1e7613831fe82 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html graft wheelhouse @@ -10,12 +24,14 @@ include *.md include LICENSE exclude *.sh -exclude *.toml -exclude *.svg +exclude *.svg recursive-include pytorch_lightning *.py +# Include marker file for PEP 561 +include pytorch_lightning/py.typed + # include examples -recursive-include pl_examples *.py *.md *.sh *.txt +recursive-include pl_examples *.py *.md *.sh *.txt *.toml # exclude tests from package recursive-exclude tests * @@ -25,15 +41,26 @@ exclude tests # Exclude the documentation files recursive-exclude docs * exclude docs -recursive-include docs/source/_images/logos/ * -recursive-include docs/source/_images/general/ pl_overview* tf_* +recursive-include docs/source/_static/images/logos/ * +recursive-include docs/source/_static/images/general/ pl_overview* tf_* tutorial_* PTL101_* # Include the Requirements +recursive-include requirements *.txt +recursive-exclude requirements *.sh *.py include requirements.txt -include requirements-extra.txt +include pyproject.toml # Exclude build configs exclude *.yml +exclude *.yaml +exclude *.jsonnet +exclude .yapfignore + +# Exclude pyright config +exclude .pyrightconfig.json + +# Exclude Makefile +exclude Makefile prune .git prune .github @@ -42,5 +69,5 @@ prune notebook* prune temp* prune test* prune benchmark* -prune docker - +prune dockers +prune legacy diff --git a/Makefile b/Makefile new file mode 100644 index 00000000000000..04b08fa2d27d1a --- /dev/null +++ b/Makefile @@ -0,0 +1,32 @@ +.PHONY: test clean docs + +# to imitate SLURM set only single node +export SLURM_LOCALID=0 +# assume you have installed need packages +export SPHINX_MOCK_REQUIREMENTS=1 + +clean: + # clean all temp runs + rm -rf $(shell find . -name "mlruns") + rm -rf $(shell find . -name "lightning_log") + rm -rf $(shell find . -name "lightning_logs") + rm -rf _ckpt_* + rm -rf .mypy_cache + rm -rf .pytest_cache + rm -rf ./docs/build + rm -rf ./docs/source/generated + rm -rf ./docs/source/*/generated + rm -rf ./docs/source/api + +test: clean + # Review the CONTRIBUTING docmentation for other ways to test. + pip install -r requirements/devel.txt + # install APEX, see https://github.com/NVIDIA/apex#linux + + # run tests with coverage + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests pl_examples -v + python -m coverage report + +docs: clean + pip install --quiet -r requirements/docs.txt + python -m sphinx -b html -W --keep-going docs/source docs/build diff --git a/README.md b/README.md index afe61c131e2ca3..fd4e3219fbed4e 100644 --- a/README.md +++ b/README.md @@ -1,402 +1,424 @@
-![Logo](docs/source/_images/logos/lightning_logo.svg) + -# PyTorch Lightning -**The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.** +**The lightweight PyTorch wrapper for high-performance AI research. +Scale your models, not the boilerplate.** +--- +

+ Website • + Key Features • + How To Use • + Docs • + Examples • + Community • + Grid AI • + License +

+ + +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pytorch-lightning)](https://pypi.org/project/pytorch-lightning/) [![PyPI Status](https://badge.fury.io/py/pytorch-lightning.svg)](https://badge.fury.io/py/pytorch-lightning) [![PyPI Status](https://pepy.tech/badge/pytorch-lightning)](https://pepy.tech/project/pytorch-lightning) +[![Conda](https://img.shields.io/conda/v/conda-forge/pytorch-lightning?label=conda&color=success)](https://anaconda.org/conda-forge/pytorch-lightning) +[![DockerHub](https://img.shields.io/docker/pulls/pytorchlightning/pytorch_lightning.svg)](https://hub.docker.com/r/pytorchlightning/pytorch_lightning) [![codecov](https://codecov.io/gh/PyTorchLightning/pytorch-lightning/branch/master/graph/badge.svg)](https://codecov.io/gh/PyTorchLightning/pytorch-lightning) -[![CodeFactor](https://www.codefactor.io/repository/github/pytorchlightning/pytorch-lightning/badge)](https://www.codefactor.io/repository/github/pytorchlightning/pytorch-lightning) -[![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=0.7.5)](https://pytorch-lightning.readthedocs.io/en/stable/) -[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-lightning/shared_invite/enQtODU5ODIyNTUzODQwLTFkMDg5Mzc1MDBmNjEzMDgxOTVmYTdhYjA1MDdmODUyOTg2OGQ1ZWZkYTQzODhhNzdhZDA3YmNhMDhlMDY4YzQ) +[![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=stable)](https://pytorch-lightning.readthedocs.io/en/stable/) +[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/PytorchLightning/pytorch-lightning/blob/master/LICENSE) -[![Next Release](https://img.shields.io/badge/Next%20Release-May%2006-.svg)](https://shields.io/)
+###### *Codecov is > 90%+ but build delays may show less + --- -## Continuous Integration -
-| System / PyTorch ver. | 1.1 (min. reg) | 1.2 | 1.3 | 1.4 | 1.5 (latest) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Linux py3.6 [CPU] | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | [![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning) | -| Linux py3.7 [GPU] | - | - | - | - | [![Build Status](http://35.192.60.23/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://35.192.60.23/PyTorchLightning/pytorch-lightning) | -| Linux py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | -| OSX py3.6 / py3.7 / py3.8| [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | -| Windows py3.6 / py3.7 / py3.8 | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | +## PyTorch Lightning is just organized PyTorch +Lightning disentangles PyTorch code to decouple the science from the engineering. +![PT to PL](docs/source/_static/images/general/pl_quick_start_full_compressed.gif) -
+--- -Simple installation from PyPI -```bash -pip install pytorch-lightning -``` +## Lightning Design Philosophy +Lightning structures PyTorch code with these principles: -## Docs -- [master](https://pytorch-lightning.readthedocs.io/en/latest) -- [0.7.5](https://pytorch-lightning.readthedocs.io/en/0.7.5/) -- [0.7.3](https://pytorch-lightning.readthedocs.io/en/0.7.3/) -- [0.7.1](https://pytorch-lightning.readthedocs.io/en/0.7.1/) -- [0.6.0](https://pytorch-lightning.readthedocs.io/en/0.6.0/) -- [0.5.3.2](https://pytorch-lightning.readthedocs.io/en/0.5.3.2/) +
+ +
-## Demo -[MNIST, GAN, BERT, DQN on COLAB!](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg) -[MNIST on TPUs](https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3) +Lightning forces the following structure to your code which makes it reusable and shareable: -## What is it? -[READ THIS QUICK START PAGE](https://pytorch-lightning.readthedocs.io/en/stable/new-project.html) +- Research code (the LightningModule). +- Engineering code (you delete, and is handled by the Trainer). +- Non-essential research code (logging, etc... this goes in Callbacks). +- Data (use PyTorch Dataloaders or organize them into a LightningDataModule). -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. -It's more of a PyTorch style-guide than a framework. +Once you do this, you can train on multiple-GPUs, TPUs, CPUs and even in 16-bit precision without changing your code! -In Lightning, you organize your code into 3 distinct categories: +Get started with our [2 step guide](https://pytorch-lightning.readthedocs.io/en/latest/starter/new-project.html) -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc... this goes in Callbacks). +--- -Here's an example of how to refactor your research code into a [LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html). +## Continuous Integration +Lightning is rigurously tested across multiple GPUs, TPUs CPUs and against major Python and PyTorch versions. -![PT to PL](docs/source/_images/lightning_module/pt_to_pl.png) +
+ Current build statuses -The rest of the code is automated by the [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html)! -![PT to PL](docs/source/_images/lightning_module/pt_trainer.png) +
-## Testing Rigour -All the automated code by the Trainer is [tested rigorously with every new PR](https://github.com/PyTorchLightning/pytorch-lightning/tree/master/tests). + | System / PyTorch ver. | 1.4 (min. req.)* | 1.5 | 1.6 | 1.7 (latest) | 1.8 (nightly) | + | :---: | :---: | :---: | :---: | :---: | :---: | + | Conda py3.7 [linux] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | + | Linux py3.7 [GPUs**] | - | - | [![Build Status](https://dev.azure.com/PytorchLightning/pytorch-lightning/_apis/build/status/PyTorchLightning.pytorch-lightning?branchName=master)](https://dev.azure.com/PytorchLightning/pytorch-lightning/_build/latest?definitionId=2&branchName=master) | - | - | + | Linux py3.{6,7} [TPUs***] | - | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | + | Linux py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | + | OSX py3.{6,7,8,9} | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | + | Windows py3.{6,7,8,9} | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=master&event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | -In fact, we also train a few models using a vanilla PyTorch loop and compare with the same model trained using the Trainer to make sure we achieve the EXACT same results. [Check out the parity tests here](https://github.com/PyTorchLightning/pytorch-lightning/tree/master/benchmarks). + - _\** tests run on two NVIDIA K80_ + - _\*** tests run on Google GKE TPUv2/3_ + - _TPU w/ py3.6/py3.7 means we support Colab and Kaggle env._ -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. +
+
-## How flexible is it? -As you see, you're just organizing your PyTorch code - there's no abstraction. +--- -And for the stuff that the Trainer abstracts out you can [override any part](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#extensibility) you want to do things like implement your own distributed training, 16-bit precision, or even a custom backwards pass. +## How To Use -For example, here you could do your own backward pass +### Step 0: Install -```python -class LitModel(LightningModule): - def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, - second_order_closure=None): - optimizer.step() - optimizer.zero_grad() +Simple installation from PyPI +```bash +pip install pytorch-lightning ``` -For anything else you might need, we have an extensive [callback system](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks) you can use to add arbitrary functionality not implemented by our team in the Trainer. + +
+ Other installation options + -## Who is Lightning for? -- Professional researchers -- PhD students -- Corporate production teams + #### Install with optional dependencies -If you're just getting into deep learning, we recommend you learn PyTorch first! Once you've implemented a few models, come back and use all the advanced features of Lightning :) + ```bash + pip install pytorch-lightning['extra'] + ``` -## What does lightning control for me? + #### Conda -Everything in Blue! -This is how lightning separates the science (red) from the engineering (blue). + ```bash + conda install pytorch-lightning -c conda-forge + ``` -![Overview](docs/source/_images/general/pl_overview.gif) + #### Install stable 1.2.x -## How much effort is it to convert? -If your code is not a huge mess you should be able to organize it into a LightningModule in less than 1 hour. -If your code IS a mess, then you needed to clean up anyhow ;) + the actual status of 1.2 [stable] is following: -[Check out this step-by-step guide](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09). + ![CI base testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20base%20testing/badge.svg?branch=release%2F1.2.x&event=push) + ![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?branch=release%2F1.2.x&event=push) + ![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg?branch=release%2F1.2.x&event=push) + ![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg?branch=release%2F1.2.x&event=push) + ![Docs check](https://github.com/PyTorchLightning/pytorch-lightning/workflows/Docs%20check/badge.svg?branch=release%2F1.2.x&event=push) + Install future release from the source + ```bash + pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.2.x --upgrade + ``` -## Starting a new project? -[Use our seed-project aimed at reproducibility!](https://github.com/PytorchLightning/pytorch-lightning-conference-seed) + #### Install bleeding-edge - future 1.3 -## Why do I want to use lightning? -Although your research/production project might start simple, once you add things like GPU AND TPU training, 16-bit precision, etc, you end up spending more time engineering than researching. Lightning automates AND rigorously tests those parts for you. + Install nightly from the source (no guarantees) + ```bash + pip install https://github.com/PyTorchLightning/pytorch-lightning/archive/master.zip + ``` -## Support -- [8 core contributors](https://pytorch-lightning.readthedocs.io/en/latest/governance.html) who are all a mix of professional engineers, Research Scientists, PhD students from top AI labs. -- 100+ community contributors. + or from testing PyPI + ```bash + pip install -iU https://test.pypi.org/simple/ pytorch-lightning + ``` +
+ -Lightning is also part of the [PyTorch ecosystem](https://pytorch.org/ecosystem/) which requires projects to have solid testing, documentation and support. +### Step 1: Add these imports ---- +```python +import os +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.datasets import MNIST +from torch.utils.data import DataLoader, random_split +from torchvision import transforms +import pytorch_lightning as pl +``` -## README Table of Contents -- [How do I use it](https://github.com/PytorchLightning/pytorch-lightning#how-do-i-do-use-it) -- [What lightning automates](https://github.com/PytorchLightning/pytorch-lightning#what-does-lightning-control-for-me) -- [Tensorboard integration](https://github.com/PytorchLightning/pytorch-lightning#tensorboard) -- [Lightning features](https://github.com/PytorchLightning/pytorch-lightning#lightning-automates-all-of-the-following-each-is-also-configurable) -- [Examples](https://github.com/PytorchLightning/pytorch-lightning#examples) -- [Tutorials](https://github.com/PytorchLightning/pytorch-lightning#tutorials) -- [Asking for help](https://github.com/PytorchLightning/pytorch-lightning#asking-for-help) -- [Contributing](https://github.com/PytorchLightning/pytorch-lightning/blob/master/.github/CONTRIBUTING.md) -- [Bleeding edge install](https://github.com/PytorchLightning/pytorch-lightning#bleeding-edge) -- [Lightning Design Principles](https://github.com/PytorchLightning/pytorch-lightning#lightning-design-principles) -- [Lightning team](https://github.com/PytorchLightning/pytorch-lightning#lightning-team) -- [FAQ](https://github.com/PytorchLightning/pytorch-lightning#faq) +### Step 2: Define a LightningModule (nn.Module subclass) +A LightningModule defines a full *system* (ie: a GAN, autoencoder, BERT or a simple Image Classifier). ---- +```python +class LitAutoEncoder(pl.LightningModule): + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) + self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding + + def training_step(self, batch, batch_idx): + # training_step defined the train loop. It is independent of forward + x, y = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer +``` -## Realistic example -Here's how you would organize a realistic PyTorch project into Lightning. +**Note: Training_step defines the training loop. Forward defines how the LightningModule behaves during inference/prediction.** -![PT to PL](docs/source/_images/mnist_imgs/pt_to_pl.jpg) +### Step 3: Train! -The LightningModule defines a *system* such as seq-2-seq, GAN, etc... -It can ALSO define a simple classifier. +```python +dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) +train, val = random_split(dataset, [55000, 5000]) -In summary, you: +autoencoder = LitAutoEncoder() +trainer = pl.Trainer() +trainer.fit(autoencoder, DataLoader(train), DataLoader(val)) +``` -1. Define a [LightningModule](https://pytorch-lightning.rtfd.io/en/latest/lightning-module.html) -```python - class LitSystem(pl.LightningModule): +## Advanced features +Lightning has over [40+ advanced features](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags) designed for professional AI research at scale. - def __init__(self): - super().__init__() - # not the best model... - self.l1 = torch.nn.Linear(28 * 28, 10) +Here are some examples: - def forward(self, x): - return torch.relu(self.l1(x.view(x.size(0), -1))) +
+ +
- def training_step(self, batch, batch_idx): - ... -``` +
+ Highlighted feature code snippets -2. Fit it with a [Trainer](https://pytorch-lightning.rtfd.io/en/latest/pytorch_lightning.trainer.html) - ```python - from pytorch_lightning import Trainer + ```python + # 8 GPUs + # no code changes needed + trainer = Trainer(max_epochs=1, gpus=8) - model = LitSystem() + # 256 GPUs + trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32) + ``` - # most basic trainer, uses good defaults - trainer = Trainer() - trainer.fit(model) - ``` + Train on TPUs without code changes -[Check out the COLAB demo here](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg) + ```python + # no code changes needed + trainer = Trainer(tpu_cores=8) + ``` -## What types of research works? -Anything! Remember, that this is just organized PyTorch code. -The Training step defines the core complexity found in the training loop. + 16-bit precision -#### Could be as complex as a seq2seq + ```python + # no code changes needed + trainer = Trainer(precision=16) + ``` -```python -# define what happens for training here -def training_step(self, batch, batch_idx): - x, y = batch - - # define your own forward and loss calculation - hidden_states = self.encoder(x) - - # even as complex as a seq-2-seq + attn model - # (this is just a toy, non-working example to illustrate) - start_token = '' - last_hidden = torch.zeros(...) - loss = 0 - for step in range(max_seq_len): - attn_context = self.attention_nn(hidden_states, start_token) - pred = self.decoder(start_token, attn_context, last_hidden) - last_hidden = pred - pred = self.predict_nn(pred) - loss += self.loss(last_hidden, y[step]) - - #toy example as well - loss = loss / max_seq_len - return {'loss': loss} -``` + Experiment managers -#### Or as basic as CNN image classification + ```python + from pytorch_lightning import loggers -```python -# define what happens for validation here -def validation_step(self, batch, batch_idx): - x, y = batch - - # or as basic as a CNN classification - out = self(x) - loss = my_loss(out, y) - return {'loss': loss} -``` + # tensorboard + trainer = Trainer(logger=TensorBoardLogger('logs/')) -And without changing a single line of code, you could run on CPUs -```python -trainer = Trainer(max_epochs=1) -``` + # weights and biases + trainer = Trainer(logger=loggers.WandbLogger()) + # comet + trainer = Trainer(logger=loggers.CometLogger()) -Or GPUs -```python -# 8 GPUs -trainer = Trainer(max_epochs=1, gpus=8) + # mlflow + trainer = Trainer(logger=loggers.MLFlowLogger()) -# 256 GPUs -trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32) -``` + # neptune + trainer = Trainer(logger=loggers.NeptuneLogger()) -Or TPUs -```python -trainer = Trainer(num_tpu_cores=8) -``` + # ... and dozens more + ``` -When you're done training, run the test accuracy -```python -trainer.test() -``` + EarlyStopping -## Visualization -Lightning has out-of-the-box integration with the popular logging/visualizing frameworks + ```python + es = EarlyStopping(monitor='val_loss') + trainer = Trainer(callbacks=[es]) + ``` -- [Tensorboard](https://pytorch.org/docs/stable/tensorboard.html) -- [MLFlow](https://mlflow.org/) -- [Neptune.ai](https://neptune.ai/) -- [Comet.ml](https://www.comet.ml/site/) -- [Wandb](https://www.wandb.com/) -- [Trains](https://github.com/allegroai/trains) -- ... + Checkpointing -![tensorboard-support](docs/source/_images/general/tf_loss.png) + ```python + checkpointing = ModelCheckpoint(monitor='val_loss') + trainer = Trainer(callbacks=[checkpointing]) + ``` + Export to torchscript (JIT) (production use) -## Lightning automates 40+ parts of DL/ML research -- GPU training -- Distributed GPU (cluster) training -- TPU training -- EarlyStopping -- Logging/Visualizing -- Checkpointing -- Experiment management -- [Full list here](https://pytorch-lightning.readthedocs.io/en/latest/#common-use-cases) + ```python + # torchscript + autoencoder = LitAutoEncoder() + torch.jit.save(autoencoder.to_torchscript(), "model.pt") + ``` + Export to ONNX (production use) -## Examples -Check out this awesome list of research papers and implementations done with Lightning. - -- [Contextual Emotion Detection (DoubleDistilBert)](https://github.com/PyTorchLightning/emotion_transformer) -- [Generative Adversarial Network](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=TyYOdg8g77P0) -- [Hyperparameter optimization with Optuna](https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py) -- [Image Inpainting using Partial Convolutions](https://github.com/ryanwongsa/Image-Inpainting) -- [MNIST on TPU](https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3#scrollTo=BHBz1_AnamN_) -- [NER (transformers, TPU, huggingface)](https://colab.research.google.com/drive/1dBN-wwYUngLYVt985wGs_OKPlK_ANB9D) -- [NeuralTexture (CVPR)](https://github.com/PyTorchLightning/neuraltexture) -- [Recurrent Attentive Neural Process](https://github.com/PyTorchLightning/attentive-neural-processes) -- [Siamese Nets for One-shot Image Recognition](https://github.com/PyTorchLightning/Siamese-Neural-Networks) -- [Speech Transformers](https://github.com/PyTorchLightning/speech-transformer-pytorch_lightning) -- [Transformers transfer learning (Huggingface)](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=yr7eaxkF-djf) -- [Transformers text classification](https://github.com/ricardorei/lightning-text-classification) -- [VAE Library of over 18+ VAE flavors](https://github.com/AntixK/PyTorch-VAE) - -## Tutorials -Check out our [introduction guide](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) to get started. -Or jump straight into [our tutorials](https://pytorch-lightning.readthedocs.io/en/latest/#tutorials). + ```python + # onnx + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile: + autoencoder = LitAutoEncoder() + input_sample = torch.randn((1, 64)) + autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True) + os.path.isfile(tmpfile.name) + ``` +
+ +### Pro-level control of training loops (advanced users) +For complex/professional level work, you have optional full control of the training loop and optimizers. +```python +class LitAutoEncoder(pl.LightningModule): + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # access your optimizers with use_pl_optimizer=False. Default is True + opt_a, opt_b = self.optimizers(use_pl_optimizer=True) + + loss_a = ... + self.manual_backward(loss_a, opt_a) + opt_a.step() + opt_a.zero_grad() + + loss_b = ... + self.manual_backward(loss_b, opt_b, retain_graph=True) + self.manual_backward(loss_b, opt_b) + opt_b.step() + opt_b.zero_grad() +``` --- -## Asking for help -Welcome to the Lightning community! +## Advantages over unstructured PyTorch -If you have any questions, feel free to: -1. [read the docs](https://pytorch-lightning.rtfd.io/en/latest/). -2. [Search through the issues](https://github.com/PytorchLightning/pytorch-lightning/issues?utf8=%E2%9C%93&q=my++question). -3. [Ask on stackoverflow](https://stackoverflow.com/questions/ask?guided=false) with the tag pytorch-lightning. -4. [Join our slack](https://join.slack.com/t/pytorch-lightning/shared_invite/enQtODU5ODIyNTUzODQwLTFkMDg5Mzc1MDBmNjEzMDgxOTVmYTdhYjA1MDdmODUyOTg2OGQ1ZWZkYTQzODhhNzdhZDA3YmNhMDhlMDY4YzQ). +* Models become hardware agnostic +* Code is clear to read because engineering code is abstracted away +* Easier to reproduce +* Make fewer mistakes because lightning handles the tricky engineering +* Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate +* Lightning has dozens of integrations with popular machine learning tools. +* [Tested rigorously with every new PR](https://github.com/PyTorchLightning/pytorch-lightning/tree/master/tests). We test every combination of PyTorch and Python supported versions, every OS, multi GPUs and even TPUs. +* Minimal running speed overhead (about 300 ms per epoch compared with pure PyTorch). --- -## FAQ -**How do I use Lightning for rapid research?** -[Here's a walk-through](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) -**Why was Lightning created?** -Lightning has 3 goals in mind: +## Examples -1. Maximal flexibility while abstracting out the common boilerplate across research projects. -2. Reproducibility. If all projects use the LightningModule template, it will be much much easier to understand what's going on and where to look! It will also mean every implementation follows a standard format. -3. Democratizing PyTorch power user features. Distributed training? 16-bit? know you need them but don't want to take the time to implement? All good... these come built into Lightning. +###### Hello world +- [MNIST hello world](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) +- [MNIST on TPUs](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb) -**How does Lightning compare with Ignite and fast.ai?** -[Here's a thorough comparison](https://medium.com/@_willfalcon/pytorch-lightning-vs-pytorch-ignite-vs-fast-ai-61dc7480ad8a). +###### Contrastive Learning +- [BYOL](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol) +- [CPC v2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#cpc-v2) +- [Moco v2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#moco-v2) +- [SIMCLR](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr) -**Is this another library I have to learn?** -Nope! We use pure Pytorch everywhere and don't add unnecessary abstractions! +###### NLP +- [BERT](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) +- [GPT-2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2) -**Are there plans to support Python 2?** -Nope. -**Are there plans to support virtualenv?** -Nope. Please use anaconda or miniconda. +###### Reinforcement Learning +- [DQN](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dqn-models) +- [Dueling-DQN](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dueling-dqn) +- [Reinforce](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#reinforce) -**Which PyTorch versions do you support?** -- **PyTorch 1.1.0** - ```bash - # install pytorch 1.1.0 using the official instructions +###### Vision +- [GAN](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) - # install test-tube 0.6.7.6 which supports 1.1.0 - pip install test-tube==0.6.7.6 +###### Classic ML +- [Logistic Regression](https://pytorch-lightning-bolts.readthedocs.io/en/latest/classic_ml.html#logistic-regression) +- [Linear Regression](https://pytorch-lightning-bolts.readthedocs.io/en/latest/classic_ml.html#linear-regression) - # install latest Lightning version without upgrading deps - pip install -U --no-deps pytorch-lightning - ``` -- **PyTorch 1.2.0, 1.3.0,** - Install via pip as normal +--- -## Custom installation +## Community -### Bleeding edge +The lightning community is maintained by +- [10+ core contributors](https://pytorch-lightning.readthedocs.io/en/latest/governance.html) who are all a mix of professional engineers, Research Scientists, and Ph.D. students from top AI labs. +- 400+ community contributors. -If you can't wait for the next release, install the most up to date code with: -* using GIT (locally clone whole repo with full history) - ```bash - pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade - ``` -* using instant zip (last state of the repo without git history) - ```bash - pip install https://github.com/PytorchLightning/pytorch-lightning/archive/master.zip --upgrade - ``` +Lightning is also part of the [PyTorch ecosystem](https://pytorch.org/ecosystem/) which requires projects to have solid testing, documentation and support. -### Any release installation +### Asking for help +If you have any questions please: +1. [Read the docs](https://pytorch-lightning.rtfd.io/en/latest). +2. [Search through existing Discussions](https://github.com/PyTorchLightning/pytorch-lightning/discussions), or [add a new question](https://github.com/PyTorchLightning/pytorch-lightning/discussions/new) +3. [Join our slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A). +### Funding +[We're venture funded](https://techcrunch.com/2020/10/08/grid-ai-raises-18-6m-series-a-to-help-ai-researchers-and-engineers-bring-their-models-to-production/) to make sure we can provide around the clock support, hire a full-time staff, attend conferences, and move faster through implementing features you request. -You can also install any past release `0.X.Y` from this repository: -```bash -pip install https://github.com/PytorchLightning/pytorch-lightning/archive/0.X.Y.zip --upgrade +--- + +## Grid AI +Grid AI is our native platform for training models at scale on the cloud! + +**Sign up for [early access here](https://www.grid.ai/)** + +To use grid, take your regular command: + +``` + python my_model.py --learning_rate 1e-6 --layers 2 --gpus 4 ``` -### Lightning team +And change it to use the grid train command: -#### Leads -- William Falcon [(williamFalcon)](https://github.com/williamFalcon) (Lightning founder) -- Jirka Borovec [(Borda)](https://github.com/Borda) (ghost :) -- Ethan Harris [(ethanwharris)](https://github.com/ethanwharris) (Torchbearer founder) -- Matthew Painter [(MattPainter01)](https://github.com/MattPainter01) (Torchbearer founder) -- Justus Schock [(justusschock)](https://github.com/justusschock) (Former Core Member PyTorch Ignite) +``` + grid train --grid_gpus 4 my_model.py --learning_rate 'uniform(1e-6, 1e-1, 20)' --layers '[2, 4, 8, 16]' +``` + +The above command will launch (20 * 4) experiments each running on 4 GPUs (320 GPUs!) - by making ZERO changes to +your code. + +--- -#### Core Maintainers +## Licence -- Nick Eggert [(neggert)](https://github.com/neggert) -- Jeff Ling [(jeffling)](https://github.com/jeffling) -- Jeremy Jordan [(jeremyjordan)](https://github.com/jeremyjordan) -- Tullie Murrell [(tullie)](https://github.com/tullie) -- Adrian Wälchli [(awaelchli)](https://github.com/awaelchli) +Please observe the Apache 2.0 license that is listed in this repository. In addition +the Lightning framework is Patent Pending. -## Bibtex -If you want to cite the framework feel free to use this (but only if you loved it 😊): +## BibTeX +If you want to cite the framework feel free to use this (but only if you loved it 😊) or [zendo](https://zenodo.org/record/3828935#.YC45Lc9Khqs): ```bibtex @article{falcon2019pytorch, title={PyTorch Lightning}, - author={Falcon, WA}, - journal={GitHub. Note: https://github. com/williamFalcon/pytorch-lightning Cited by}, + author={Falcon, WA and .al}, + journal={GitHub. Note: https://github.com/PyTorchLightning/pytorch-lightning}, volume={3}, year={2019} } diff --git a/azure-pipelines.yml b/azure-pipelines.yml new file mode 100644 index 00000000000000..85664bac74b671 --- /dev/null +++ b/azure-pipelines.yml @@ -0,0 +1,127 @@ +# Python package +# Create and test a Python package on multiple Python versions. +# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more: +# https://docs.microsoft.com/azure/devops/pipelines/languages/python + +trigger: + tags: + include: + - '*' + branches: + include: + - master + - release/* + - refs/tags/* +pr: + - master + - release/* + +jobs: + - job: pytest + # how long to run the job before automatically cancelling + timeoutInMinutes: 45 + # how much time to give 'run always even if cancelled tasks' before stopping them + cancelTimeoutInMinutes: 2 + + pool: gridai-spot-pool + + #strategy: + # matrix: + # PT16: + # torch.version: '1.6' + # python.version: '3.7' + + # ToDo: this need to have installed docker in the base image... + #container: pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6 + #container: "pytorchlightning/pytorch_lightning:base-cuda-py$[ variables['python.version'] ]-torch1.6" + container: + # base ML image: mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04 + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6" + #endpoint: azureContainerRegistryConnection + options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" + + workspace: + clean: all + + steps: + + - bash: | + lspci | egrep 'VGA|3D' + whereis nvidia + nvidia-smi + python --version + pip --version + pip list + displayName: 'Image info & NVIDIA' + + - bash: | + export GIT_TERMINAL_PROMPT=1 + #sudo apt-get install -y cmake + # python -m pip install "pip==20.1" + pip install --requirement requirements.txt + python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)" + python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" + pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed + pip install git+https://$(AUTH_TOKEN)@github.com/PyTorchLightning/lightning-dtrun.git@v0.0.2 --no-cache-dir + pip list + displayName: 'Install dependencies' + + - bash: | + python tests/collect_env_details.py + python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'" + displayName: 'Env details' + + - bash: | + wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/ + unzip -o legacy/checkpoints.zip -d legacy/ + ls -l legacy/checkpoints/ + displayName: 'Get legacy checkpoints' + + - bash: | + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 + displayName: 'Testing: standard' + + - bash: | + bash tests/special_tests.sh + displayName: 'Testing: special' + + - bash: | + python -m coverage report + python -m coverage xml + python -m coverage html + python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure + ls -l + displayName: 'Statistics' + + - task: PublishTestResults@2 + displayName: 'Publish test results' + inputs: + testResultsFiles: '$(Build.StagingDirectory)/test-results.xml' + testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)' + condition: succeededOrFailed() + + - task: PublishCodeCoverageResults@1 + displayName: 'Publish coverage report' + inputs: + codeCoverageTool: 'cobertura' + summaryFileLocation: 'coverage.xml' + reportDirectory: '$(Build.SourcesDirectory)/htmlcov' + testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)' + condition: succeededOrFailed() + + - bash: | + python -m pytest benchmarks -v --maxfail=2 --durations=0 + displayName: 'Testing: benchmarks' + + - script: | + set -e + python -m pytest pl_examples -v --maxfail=2 --durations=0 + pip install . --user --quiet + bash pl_examples/run_examples-args.sh --gpus 1 --max_epochs 1 --batch_size 64 --limit_train_batches 5 --limit_val_batches 3 + bash pl_examples/run_ddp-examples.sh --max_epochs 1 --batch_size 32 --limit_train_batches 2 --limit_val_batches 2 + # cd pl_examples/basic_examples + # bash submit_ddp_job.sh + # bash submit_ddp2_job.sh + env: + PL_USE_MOCKED_MNIST: "1" + displayName: 'Examples' diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index e69de29bb2d1d6..734288b07235d5 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +BENCHMARK_ROOT = os.path.dirname(__file__) +PROJECT_ROOT = os.path.dirname(BENCHMARK_ROOT) diff --git a/benchmarks/generate_comparison.py b/benchmarks/generate_comparison.py new file mode 100644 index 00000000000000..fa0e78eb3c742c --- /dev/null +++ b/benchmarks/generate_comparison.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import matplotlib.pylab as plt +import pandas as pd + +from benchmarks.test_basic_parity import measure_loops +from tests.helpers.advanced_models import ParityModuleMNIST, ParityModuleRNN + +NUM_EPOCHS = 20 +NUM_RUNS = 50 +MODEL_CLASSES = (ParityModuleRNN, ParityModuleMNIST) +PATH_HERE = os.path.dirname(__file__) +FIGURE_EXTENSION = '.png' + + +def _main(): + fig, axarr = plt.subplots(nrows=len(MODEL_CLASSES)) + + for i, cls_model in enumerate(MODEL_CLASSES): + path_csv = os.path.join(PATH_HERE, f'dump-times_{cls_model.__name__}.csv') + if os.path.isfile(path_csv): + df_time = pd.read_csv(path_csv, index_col=0) + else: + # todo: kind="Vanilla PT" -> use_lightning=False + vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS) + lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS) + + df_time = pd.DataFrame({'vanilla PT': vanilla['durations'][1:], 'PT Lightning': lightning['durations'][1:]}) + df_time /= NUM_RUNS + df_time.to_csv(os.path.join(PATH_HERE, f'dump-times_{cls_model.__name__}.csv')) + # todo: add also relative X-axis ticks to see both: relative and absolute time differences + df_time.plot.hist( + ax=axarr[i], + bins=20, + alpha=0.5, + title=cls_model.__name__, + legend=True, + grid=True, + ) + axarr[i].set(xlabel='time [seconds]') + + path_fig = os.path.join(PATH_HERE, f'figure-parity-times{FIGURE_EXTENSION}') + fig.tight_layout() + fig.savefig(path_fig) + + +if __name__ == '__main__': + _main() diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py new file mode 100644 index 00000000000000..53f303693ffdb9 --- /dev/null +++ b/benchmarks/test_basic_parity.py @@ -0,0 +1,177 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import time + +import numpy as np +import pytest +import torch +from tqdm import tqdm + +from pytorch_lightning import LightningModule, seed_everything, Trainer +from tests.helpers.advanced_models import ParityModuleMNIST, ParityModuleRNN + + +def assert_parity_relative(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.1): + # assert speeds + diffs = np.asarray(pl_values) - np.mean(pt_values) + # norm by vanilla time + diffs = diffs / norm_by + # relative to mean reference value + diffs = diffs / np.mean(pt_values) + assert np.mean(diffs) < max_diff, f"Lightning diff {diffs} was worse than vanilla PT (threshold {max_diff})" + + +def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.55): + # assert speeds + diffs = np.asarray(pl_values) - np.mean(pt_values) + # norm by event count + diffs = diffs / norm_by + assert np.mean(diffs) < max_diff, f"Lightning {diffs} was worse than vanilla PT (threshold {max_diff})" + + +# ParityModuleMNIST runs with num_workers=1 +@pytest.mark.parametrize( + 'cls_model,max_diff_speed,max_diff_memory', + [ + (ParityModuleRNN, 0.05, 0.0), + (ParityModuleMNIST, 0.25, 0.0), # todo: lower this thr + ] +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_pytorch_parity( + tmpdir, + cls_model: LightningModule, + max_diff_speed: float, + max_diff_memory: float, + num_epochs: int = 4, + num_runs: int = 3, +): + """ + Verify that the same pytorch and lightning models achieve the same results + """ + lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs) + vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=num_epochs, num_runs=num_runs) + + # make sure the losses match exactly to 5 decimal places + print(f"Losses are for... \n vanilla: {vanilla['losses']} \n lightning: {lightning['losses']}") + for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']): + np.testing.assert_almost_equal(pl_out, pt_out, 5) + + # drop the first run for initialize dataset (download & filter) + assert_parity_absolute( + lightning['durations'][1:], vanilla['durations'][1:], norm_by=num_epochs, max_diff=max_diff_speed + ) + + assert_parity_relative(lightning['memory'], vanilla['memory'], max_diff=max_diff_memory) + + +def _hook_memory(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + used_memory = torch.cuda.max_memory_allocated() + else: + used_memory = np.nan + return used_memory + + +def measure_loops(cls_model, kind, num_runs=10, num_epochs=10): + """ + Returns an array with the last loss from each epoch for each run + """ + hist_losses = [] + hist_durations = [] + hist_memory = [] + + device_type = "cuda" if torch.cuda.is_available() else "cpu" + torch.backends.cudnn.deterministic = True + for i in tqdm(range(num_runs), desc=f'{kind} with {cls_model.__name__}'): + gc.collect() + if device_type == 'cuda': + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_cached() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_accumulated_memory_stats() + torch.cuda.reset_peak_memory_stats() + time.sleep(1) + + time_start = time.perf_counter() + + _loop = lightning_loop if kind == "PT Lightning" else vanilla_loop + final_loss, used_memory = _loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs) + + time_end = time.perf_counter() + + hist_losses.append(final_loss) + hist_durations.append(time_end - time_start) + hist_memory.append(used_memory) + + return { + 'losses': hist_losses, + 'durations': hist_durations, + 'memory': hist_memory, + } + + +def vanilla_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10): + device = torch.device(device_type) + # set seed + seed_everything(idx) + + # init model parts + model = cls_model() + dl = model.train_dataloader() + optimizer = model.configure_optimizers() + + # model to GPU + model = model.to(device) + + epoch_losses = [] + # as the first run is skipped, no need to run it long + for epoch in range(num_epochs if idx > 0 else 1): + + # run through full training set + for j, batch in enumerate(dl): + batch = [x.to(device) for x in batch] + loss_dict = model.training_step(batch, j) + loss = loss_dict['loss'] + loss.backward() + optimizer.step() + optimizer.zero_grad() + + # track last epoch loss + epoch_losses.append(loss.item()) + + return epoch_losses[-1], _hook_memory() + + +def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10): + seed_everything(idx) + + model = cls_model() + # init model parts + trainer = Trainer( + # as the first run is skipped, no need to run it long + max_epochs=num_epochs if idx > 0 else 1, + progress_bar_refresh_rate=0, + weights_summary=None, + gpus=1 if device_type == 'cuda' else 0, + checkpoint_callback=False, + deterministic=True, + logger=False, + replace_sampler_ddp=False, + ) + trainer.fit(model) + + return trainer.train_loop.running_loss.last().item(), _hook_memory() diff --git a/benchmarks/test_rnn_parity.py b/benchmarks/test_rnn_parity.py deleted file mode 100644 index b1aa96088a8962..00000000000000 --- a/benchmarks/test_rnn_parity.py +++ /dev/null @@ -1,158 +0,0 @@ -import time - -import numpy as np -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import Dataset, DataLoader -import tests.base.utils as tutils - -from pytorch_lightning import Trainer, LightningModule - - -class AverageDataset(Dataset): - def __init__(self, dataset_len=300, sequence_len=100): - self.dataset_len = dataset_len - self.sequence_len = sequence_len - self.input_seq = torch.randn(dataset_len, sequence_len, 10) - top, bottom = self.input_seq.chunk(2, -1) - self.output_seq = top + bottom.roll(shifts=1, dims=-1) - - def __len__(self): - return self.dataset_len - - def __getitem__(self, item): - return self.input_seq[item], self.output_seq[item] - - -class ParityRNN(LightningModule): - def __init__(self): - super(ParityRNN, self).__init__() - self.rnn = nn.LSTM(10, 20, batch_first=True) - self.linear_out = nn.Linear(in_features=20, out_features=5) - - def forward(self, x): - seq, last = self.rnn(x) - return self.linear_out(seq) - - def training_step(self, batch, batch_nb): - x, y = batch - y_hat = self(x) - loss = F.mse_loss(y_hat, y) - return {'loss': loss} - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.02) - - def train_dataloader(self): - return DataLoader(AverageDataset(), batch_size=30) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_pytorch_parity(tmpdir): - """ - Verify that the same pytorch and lightning models achieve the same results - :param tmpdir: - :return: - """ - num_epochs = 2 - num_rums = 3 - - lightning_outs, pl_times = lightning_loop(ParityRNN, num_rums, num_epochs) - manual_outs, pt_times = vanilla_loop(ParityRNN, num_rums, num_epochs) - # make sure the losses match exactly to 5 decimal places - for pl_out, pt_out in zip(lightning_outs, manual_outs): - np.testing.assert_almost_equal(pl_out, pt_out, 8) - - tutils.assert_speed_parity(pl_times, pt_times, num_epochs) - - -def set_seed(seed): - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - - -def vanilla_loop(MODEL, num_runs=10, num_epochs=10): - """ - Returns an array with the last loss from each epoch for each run - """ - device = torch.device('cuda' if torch.cuda.is_available() else "cpu") - errors = [] - times = [] - - for i in range(num_runs): - time_start = time.perf_counter() - - # set seed - seed = i - set_seed(seed) - - # init model parts - model = MODEL() - dl = model.train_dataloader() - optimizer = model.configure_optimizers() - - # model to GPU - model = model.to(device) - - epoch_losses = [] - for epoch in range(num_epochs): - - # run through full training set - for j, batch in enumerate(dl): - x, y = batch - x = x.cuda(0) - y = y.cuda(0) - batch = (x, y) - - loss_dict = model.training_step(batch, j) - loss = loss_dict['loss'] - loss.backward() - optimizer.step() - optimizer.zero_grad() - - # track last epoch loss - epoch_losses.append(loss.item()) - - time_end = time.perf_counter() - times.append(time_end - time_start) - - errors.append(epoch_losses[-1]) - - return errors, times - - -def lightning_loop(MODEL, num_runs=10, num_epochs=10): - errors = [] - times = [] - - for i in range(num_runs): - time_start = time.perf_counter() - - # set seed - seed = i - set_seed(seed) - - # init model parts - model = MODEL() - trainer = Trainer( - max_epochs=num_epochs, - progress_bar_refresh_rate=0, - weights_summary=None, - gpus=1, - early_stop_callback=False, - checkpoint_callback=False, - distributed_backend='dp', - ) - trainer.fit(model) - - final_loss = trainer.running_loss.last().item() - errors.append(final_loss) - - time_end = time.perf_counter() - times.append(time_end - time_start) - - return errors, times diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py new file mode 100644 index 00000000000000..28cbd7828b1087 --- /dev/null +++ b/benchmarks/test_sharded_parity.py @@ -0,0 +1,225 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Type + +import pytest +import torch + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.plugins import DDPSpawnShardedPlugin +from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.runif import RunIf + + +class SeedTrainLoaderModel(BoringModel): + """ + Overrides training loader to ensure we enforce the same seed for all DDP processes. + """ + + def train_dataloader(self): + seed_everything(42) + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +class SeedTrainLoaderManualModel(SeedTrainLoaderModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + # manual + # access your optimizers with use_pl_optimizer=False. Default is True + (opt_a, opt_b) = self.optimizers(use_pl_optimizer=True) + loss_1 = self.step(batch) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b) + # todo: understand why synchronization breaks there. + # self.manual_backward(loss_2, opt_a, retain_graph=True) + opt_b.step() + + assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + @property + def automatic_optimization(self) -> bool: + return False + + +class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {'loss': loss} + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + +def record_ddp_fit_model_stats(trainer, model, use_cuda): + """ + Helper to calculate wall clock time for fit + max allocated memory. + + Args: + trainer: The trainer object. + model: The model to fit. + use_cuda: Whether to sync CUDA kernels. + + Returns: + Max Memory if using GPUs, and total wall clock time. + """ + max_memory = None + + time_start = time.perf_counter() + if use_cuda: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + trainer.fit(model) + + if use_cuda: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2**20 + + total_time = time.perf_counter() - time_start + + return max_memory, total_time + + +def plugin_parity_test( + model_cls: Type[SeedTrainLoaderModel], + seed: int = 42, + gpus: int = 0, + precision: int = 32, + max_percent_speed_diff: float = 0.1, +): + """ + Ensures that the trained model is identical to the standard DDP implementation. + Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + + Args: + model_cls: Model class to use for test. + seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. + gpus: Number of GPUS to enable. + precision: Whether to use AMP or normal FP32 training. + max_percent_speed_diff: The maximum speed difference compared to normal DDP training. + This is more a safety net for variability in CI which can vary in speed, not for benchmarking. + + """ + + # Train normal DDP + seed_everything(seed) + ddp_model = model_cls() + use_cuda = gpus > 0 + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator='ddp_spawn', + ) + + max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda) + + # Reset and train Custom DDP + seed_everything(seed) + custom_plugin_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator='ddp_sharded_spawn', + ) + assert isinstance(trainer.training_type_plugin, DDPSpawnShardedPlugin) + + max_memory_custom, custom_model_time = record_ddp_fit_model_stats( + trainer=trainer, model=custom_plugin_model, use_cuda=use_cuda + ) + + # Assert model parameters are identical after fit + for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()): + assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin' + + # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold + percent_diff = (custom_model_time - ddp_time) / custom_model_time + + assert ( + percent_diff <= max_percent_speed_diff + ), f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}' + + if use_cuda: + # Assert CUDA memory parity + assert max_memory_custom <= max_memory_ddp, ( + 'Custom plugin used too much memory compared to DDP, ' + f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}' + ) + + +@RunIf(skip_windows=True, fairscale=True) +@pytest.mark.parametrize( + 'kwargs', + [ + pytest.param(dict(gpus=1, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=1)), + pytest.param( + dict(gpus=1, precision=16, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=1, amp_native=True) + ), + pytest.param(dict(gpus=2, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=2)), + pytest.param( + dict(gpus=2, precision=16, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=2, amp_native=True) + ), + pytest.param( + dict(gpus=2, model_cls=SeedTrainLoaderMultipleOptimizersModel), + marks=[ + RunIf(min_gpus=2), + pytest.mark.skip(reason='TODO: Current issue with multiple optimizers and FairScale.'), + ], + ), + pytest.param( + dict(gpus=2, model_cls=SeedTrainLoaderManualModel), + marks=[ + RunIf(min_gpus=2), + pytest.mark.skip(reason='TODO: Current issue with multiple optimizers and FairScale.'), + ], + ), + ], +) +def test_ddp_spawn_sharded_plugin(kwargs): + if kwargs['gpus'] > 1: + # TODO: decrease speed diff since only 2 GPUs sharding 2 optimizers + kwargs['max_percent_speed_diff'] = 0.25 + plugin_parity_test(**kwargs) diff --git a/benchmarks/test_trainer_parity.py b/benchmarks/test_trainer_parity.py deleted file mode 100644 index 4c0e89d107b7d9..00000000000000 --- a/benchmarks/test_trainer_parity.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import time - -import numpy as np -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader -from torchvision import transforms -import tests.base.utils as tutils - -from pytorch_lightning import Trainer, LightningModule -from tests.base.datasets import TrialMNIST - - -class ParityMNIST(LightningModule): - - def __init__(self): - super(ParityMNIST, self).__init__() - self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128) - self.c_d1_bn = nn.BatchNorm1d(128) - self.c_d1_drop = nn.Dropout(0.3) - self.c_d2 = nn.Linear(in_features=128, out_features=10) - - def forward(self, x): - x = x.view(x.size(0), -1) - x = self.c_d1(x) - x = torch.tanh(x) - x = self.c_d1_bn(x) - x = self.c_d1_drop(x) - x = self.c_d2(x) - return x - - def training_step(self, batch, batch_nb): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - return {'loss': loss} - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.02) - - def train_dataloader(self): - return DataLoader(TrialMNIST(train=True, - download=True, - num_samples=500, - digits=list(range(5))), - batch_size=128) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_pytorch_parity(tmpdir): - """ - Verify that the same pytorch and lightning models achieve the same results - :param tmpdir: - :return: - """ - num_epochs = 2 - num_rums = 3 - lightning_outs, pl_times = lightning_loop(ParityMNIST, num_rums, num_epochs) - manual_outs, pt_times = vanilla_loop(ParityMNIST, num_rums, num_epochs) - - # make sure the losses match exactly to 5 decimal places - for pl_out, pt_out in zip(lightning_outs, manual_outs): - np.testing.assert_almost_equal(pl_out, pt_out, 5) - - # the fist run initialize dataset (download & filter) - tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs) - - -def _set_seed(seed): - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - - -def vanilla_loop(MODEL, num_runs=10, num_epochs=10): - """ - Returns an array with the last loss from each epoch for each run - """ - device = torch.device('cuda' if torch.cuda.is_available() else "cpu") - errors = [] - times = [] - - for i in range(num_runs): - time_start = time.perf_counter() - - # set seed - seed = i - _set_seed(seed) - - # init model parts - model = MODEL() - dl = model.train_dataloader() - optimizer = model.configure_optimizers() - - # model to GPU - model = model.to(device) - - epoch_losses = [] - for epoch in range(num_epochs): - - # run through full training set - for j, batch in enumerate(dl): - x, y = batch - x = x.cuda(0) - y = y.cuda(0) - batch = (x, y) - - loss_dict = model.training_step(batch, j) - loss = loss_dict['loss'] - loss.backward() - optimizer.step() - optimizer.zero_grad() - - # track last epoch loss - epoch_losses.append(loss.item()) - - time_end = time.perf_counter() - times.append(time_end - time_start) - - errors.append(epoch_losses[-1]) - - return errors, times - - -def lightning_loop(MODEL, num_runs=10, num_epochs=10): - errors = [] - times = [] - - for i in range(num_runs): - time_start = time.perf_counter() - - # set seed - seed = i - _set_seed(seed) - - # init model parts - model = MODEL() - trainer = Trainer( - max_epochs=num_epochs, - progress_bar_refresh_rate=0, - weights_summary=None, - gpus=1, - early_stop_callback=False, - checkpoint_callback=False - ) - trainer.fit(model) - - final_loss = trainer.running_loss.last().item() - errors.append(final_loss) - - time_end = time.perf_counter() - times.append(time_end - time_start) - - return errors, times diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index f6720e1f8e87a3..00000000000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,42 +0,0 @@ -ARG CUDA_VERSION=10.1 -FROM nvidia/cuda:${CUDA_VERSION}-base - -# install versions -ARG PYTHON_VERSION=3.7 -ARG PYTORCH_VERSION=1.4 -ARG LIGHTNING_VERSION=master - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - curl \ - ca-certificates - -# add non-root user -RUN useradd --create-home --shell /bin/bash containeruser -USER containeruser -WORKDIR /home/containeruser - - -# install conda and python -RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /home/containeruser/conda && \ - rm ~/miniconda.sh && \ - /home/containeruser/conda/bin/conda clean -ya && \ - /home/containeruser/conda/bin/conda install -y python=$PYTHON_VERSION - -# add conda to path -ENV PATH /home/containeruser/conda/bin:$PATH - -# install dependencies -RUN pip install torch==$PYTORCH_VERSION -RUN git clone https://github.com/PyTorchLightning/pytorch-lightning.git --single-branch --branch $LIGHTNING_VERSION && \ - pip install ./pytorch-lightning && \ - pip install -r pytorch-lightning/requirements-extra.txt && \ - rm -rf pytorch-lightning - -RUN python -c "import pytorch_lightning as pl; print(pl.__version__)" - -CMD ["/bin/bash"] diff --git a/docker/README.md b/docker/README.md deleted file mode 100644 index 9844cd5b66e267..00000000000000 --- a/docker/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## Builds - -You can build it on your own, note it takes lots of time, be prepared. -```bash -git clone -docker image build -t pytorch-lightning:py36 -f docker/Dockerfile --build-arg PYTHON_VERSION=3.6 . -``` -To build other versions, select different Dockerfile. -```bash -docker image list -docker run --rm -it pytorch-lightning:py36 bash -docker image rm pytorch-lightning:py36 -``` \ No newline at end of file diff --git a/dockers/README.md b/dockers/README.md new file mode 100644 index 00000000000000..89647ac443dde7 --- /dev/null +++ b/dockers/README.md @@ -0,0 +1,65 @@ +# Docker images + +## Builds images form attached Dockerfiles + +You can build it on your own, note it takes lots of time, be prepared. + +```bash +git clone +docker image build -t pytorch-lightning:latest -f dockers/conda/Dockerfile . +``` + +or with specific arguments + +```bash +git clone +docker image build \ + -t pytorch-lightning:py3.8-pt1.6 \ + -f dockers/base-cuda/Dockerfile \ + --build-arg PYTHON_VERSION=3.8 \ + --build-arg PYTORCH_VERSION=1.6 \ + . +``` +or nightly version from Coda +```bash +git clone +docker image build \ + -t pytorch-lightning:py3.7-pt1.8 \ + -f dockers/base-conda/Dockerfile \ + --build-arg PYTHON_VERSION=3.7 \ + --build-arg PYTORCH_VERSION=1.8 \ + . +``` + +To run your docker use + +```bash +docker image list +docker run --rm -it pytorch-lightning:latest bash +``` + +and if you do not need it anymore, just clean it: + +```bash +docker image list +docker image rm pytorch-lightning:latest +``` + +### Run docker image with GPUs + +To run docker image with access to you GPUs you need to install +```bash +# Add the package repositories +distribution=$(. /etc/os-release;echo $ID$VERSION_ID) +curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - +curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list + +sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit +sudo systemctl restart docker +``` + +and later run the docker image with `--gpus all` so for example + +``` +docker run --rm -it --gpus all pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6 +``` diff --git a/dockers/base-conda/Dockerfile b/dockers/base-conda/Dockerfile new file mode 100644 index 00000000000000..585aa1768ffd77 --- /dev/null +++ b/dockers/base-conda/Dockerfile @@ -0,0 +1,128 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Existing images: +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.8 +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.6 +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.5 +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.4 + +ARG CUDNN_VERSION=8 +ARG CUDA_VERSION=10.2 + +# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu18.04 +# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu18.04 + +ARG PYTHON_VERSION=3.7 +ARG PYTORCH_VERSION=1.6 +ARG CONDA_VERSION=4.9.2 + +SHELL ["/bin/bash", "-c"] + +ENV PATH="$PATH:/root/.local/bin" + +RUN apt-get update -qq && \ + apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + git \ + wget \ + curl \ + unzip \ + ca-certificates \ + libopenmpi-dev \ + && \ + +# Install conda and python. +# NOTE new Conda does not forward the exit status... https://github.com/conda/conda/issues/8385 + curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_${CONDA_VERSION}-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b && \ + rm ~/miniconda.sh && \ + +# Cleaning + apt-get autoremove -y && \ + apt-get clean && \ + rm -rf /root/.cache && \ + rm -rf /var/lib/apt/lists/* + +ENV \ + PATH="/root/miniconda3/bin:$PATH" \ + LD_LIBRARY_PATH="/root/miniconda3/lib:$LD_LIBRARY_PATH" \ + CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \ + MKL_THREADING_LAYER=GNU \ + HOROVOD_GPU_OPERATIONS=NCCL \ + HOROVOD_WITH_PYTORCH=1 \ + HOROVOD_WITHOUT_TENSORFLOW=1 \ + HOROVOD_WITHOUT_MXNET=1 \ + HOROVOD_WITH_GLOO=1 \ + HOROVOD_WITHOUT_MPI=1 \ + # MAKEFLAGS="-j$(nproc)" \ + MAKEFLAGS="-j1" \ + TORCH_CUDA_ARCH_LIST="3.7;5.0;6.0;7.0;7.5" \ + CONDA_ENV=lightning + +COPY environment.yml environment.yml + +# conda init +RUN conda create -y --name $CONDA_ENV python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION} cudatoolkit=${CUDA_VERSION} -c pytorch -c pytorch-test -c pytorch-nightly && \ + conda init bash && \ + # NOTE: this requires that the channel is presented in the yaml before packages + # replace channel to nigtly if needed, fix PT version and remove Horovod as it will be installed later + python -c "import re ; fname = 'environment.yml' ; req = re.sub(r'- python[>=]+[\d\.]+', '# - python=${PYTHON_VERSION}', open(fname).read()) ; open(fname, 'w').write(req)" && \ + python -c "import re ; fname = 'environment.yml' ; req = re.sub(r'- pytorch[>=]+[\d\.]+', '# - pytorch=${PYTORCH_VERSION}', open(fname).read()) ; open(fname, 'w').write(req)" && \ + python -c "import re ; fname = 'environment.yml' ; req = re.sub(r'- horovod[>=]+[\d\.]+', '# - horovod', open(fname).read()) ; open(fname, 'w').write(req)" && \ + python -c "fname = 'environment.yml' ; req = open(fname).readlines() ; open(fname, 'w').writelines([ln for ln in req if 'horovod' not in ln])" && \ + cat environment.yml && \ + conda env update --name $CONDA_ENV --file environment.yml && \ + conda clean -ya && \ + rm environment.yml + +ENV \ + PATH /root/miniconda3/envs/${CONDA_ENV}/bin:$PATH \ + LD_LIBRARY_PATH="/root/miniconda3/envs/${CONDA_ENV}/lib:$LD_LIBRARY_PATH" \ + # if you want this environment to be the default one, uncomment the following line: + CONDA_DEFAULT_ENV=${CONDA_ENV} + +COPY ./requirements/extra.txt requirements-extra.txt +COPY ./requirements/test.txt requirements-test.txt +COPY ./requirements/adjust_versions.py requirements_adjust_versions.py + +RUN \ + pip list | grep torch && \ + python -c "import torch; print(torch.__version__)" && \ + python requirements_adjust_versions.py requirements-extra.txt && \ + # Install remaining requirements + pip install -r requirements-extra.txt --no-cache-dir && \ + pip install -r requirements-test.txt --no-cache-dir && \ + rm requirements* + +RUN \ + # install DALI, needed for examples + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda${CUDA_VERSION%%.*}0 + +RUN \ + # install NVIDIA AMP + git clone https://github.com/NVIDIA/apex && \ + pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \ + rm -rf apex + +RUN \ + # Show what we have + pip --version && \ + conda info && \ + pip list && \ + python -c "import sys; assert sys.version[:3] == '$PYTHON_VERSION', sys.version" && \ + python -c "import torch; assert torch.__version__[:3] == '$PYTORCH_VERSION', torch.__version__" diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile new file mode 100644 index 00000000000000..843e47ca912894 --- /dev/null +++ b/dockers/base-cuda/Dockerfile @@ -0,0 +1,127 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Existing images: +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.7 --build-arg CUDA_VERSION=10.2 +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.6 --build-arg CUDA_VERSION=10.2 +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.5 --build-arg CUDA_VERSION=10.2 +# --build-arg PYTHON_VERSION=3.7 --build-arg PYTORCH_VERSION=1.4 --build-arg CUDA_VERSION=10.1 + +ARG CUDNN_VERSION=8 +ARG CUDA_VERSION=10.2 + +# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu18.04 +# FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu18.04 + +ARG PYTHON_VERSION=3.7 +ARG PYTORCH_VERSION=1.6 +ARG CMAKE_VERSION=3.18.4 + +SHELL ["/bin/bash", "-c"] +# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/ +ENV \ + DEBIAN_FRONTEND=noninteractive \ + TZ=Europe/Prague \ + PATH="$PATH:/root/.local/bin" \ + CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \ + MKL_THREADING_LAYER=GNU + +RUN apt-get update -qq && \ + apt-get install -y --no-install-recommends \ + build-essential \ + pkg-config \ + cmake \ + git \ + wget \ + curl \ + unzip \ + ca-certificates \ + software-properties-common \ + libopenmpi-dev \ + && \ + +# Install python + add-apt-repository ppa:deadsnakes/ppa && \ + apt-get install -y \ + python${PYTHON_VERSION} \ + python${PYTHON_VERSION}-distutils \ + python${PYTHON_VERSION}-dev \ + && \ + + update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \ + update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 && \ + +# Cleaning + apt-get autoremove -y && \ + apt-get clean && \ + rm -rf /root/.cache && \ + rm -rf /var/lib/apt/lists/* + +ENV \ + HOROVOD_GPU_OPERATIONS=NCCL \ + HOROVOD_WITH_PYTORCH=1 \ + HOROVOD_WITHOUT_TENSORFLOW=1 \ + HOROVOD_WITHOUT_MXNET=1 \ + HOROVOD_WITH_GLOO=1 \ + HOROVOD_WITHOUT_MPI=1 \ + # MAKEFLAGS="-j$(nproc)" \ + MAKEFLAGS="-j1" \ + TORCH_CUDA_ARCH_LIST="3.7;5.0;6.0;7.0;7.5" + +COPY ./requirements.txt requirements.txt +COPY ./requirements/ ./requirements/ + +# conda init +RUN \ + wget https://bootstrap.pypa.io/get-pip.py --progress=bar:force:noscroll --no-check-certificate && \ + python${PYTHON_VERSION} get-pip.py && \ + rm get-pip.py && \ + + # Disable cache + pip config set global.cache-dir false && \ + # eventualy use pre-release + #pip install "torch==${PYTORCH_VERSION}.*" --pre && \ + # set particular PyTorch version + python ./requirements/adjust_versions.py requirements.txt ${PYTORCH_VERSION} && \ + python ./requirements/adjust_versions.py requirements/extra.txt ${PYTORCH_VERSION} && \ + python ./requirements/adjust_versions.py requirements/examples.txt ${PYTORCH_VERSION} && \ + # Install all requirements + # todo: find a way how to install nightly PT version + # --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${cuda_ver[0]}${cuda_ver[1]}/torch_nightly.html + pip install -r requirements/devel.txt --no-cache-dir && \ + rm -rf requirements.* requirements/ + +RUN \ + # install DALI, needed for examples + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda${CUDA_VERSION%%.*}0 + +RUN \ + # install NVIDIA AMP + git clone https://github.com/NVIDIA/apex && \ + pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \ + rm -rf apex + +RUN \ + # install DeepSpeed from source. + # todo: swap to pypi release once DeepSpeed releases a new version >= 0.3.10 + pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb + +RUN \ + # Show what we have + pip --version && \ + pip list && \ + python -c 'from nvidia.dali.pipeline import Pipeline' && \ + python -c "import sys; assert sys.version[:3] == '$PYTHON_VERSION', sys.version" && \ + python -c "import torch; assert torch.__version__[:3] == '$PYTORCH_VERSION', torch.__version__" diff --git a/dockers/base-xla/Dockerfile b/dockers/base-xla/Dockerfile new file mode 100644 index 00000000000000..7f7e74bba75a60 --- /dev/null +++ b/dockers/base-xla/Dockerfile @@ -0,0 +1,119 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM google/cloud-sdk:slim + +MAINTAINER PyTorchLightning + +# CALL: docker image build -t pytorch-lightning:XLA-extras-py3.6 -f dockers/base-xla/Dockerfile . --build-arg PYTHON_VERSION=3.6 +# This Dockerfile installs pytorch/xla 3.7 wheels. There are also 3.6 wheels available; see below. +ARG PYTHON_VERSION=3.7 +ARG XLA_VERSION=1.6 + +SHELL ["/bin/bash", "-c"] + +ARG CONDA_VERSION=4.9.2 +# for skipping configurations +ENV \ + DEBIAN_FRONTEND=noninteractive \ + CONDA_ENV=lightning + +# show system inforation +RUN lsb_release -a && cat /etc/*-release + +RUN apt-get update -qq && \ + apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + wget \ + curl \ + unzip \ + ca-certificates \ + libomp5 \ + && \ +# Install conda and python. +# NOTE new Conda does not forward the exit status... https://github.com/conda/conda/issues/8385 + curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_${CONDA_VERSION}-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b && \ + rm ~/miniconda.sh && \ +# Cleaning + apt-get autoremove -y && \ + apt-get clean && \ + rm -rf /root/.cache && \ + rm -rf /var/lib/apt/lists/* + +ENV \ + PATH="/root/miniconda3/bin:$PATH" \ + LD_LIBRARY_PATH="/root/miniconda3/lib:$LD_LIBRARY_PATH" +COPY environment.yml environment.yml + +RUN conda create -y --name $CONDA_ENV && \ + conda init bash && \ + # replace channel to nigtly if neede, fix PT version and remove Horovod as it will be installe later + python -c "import re ; fname = 'environment.yml' ; req = re.sub(r'python>=[\d\.]+', 'python=${PYTHON_VERSION}', open(fname).read()) ; open(fname, 'w').write(req)" && \ + python -c "fname = 'environment.yml' ; req = open(fname).readlines() ; open(fname, 'w').writelines([ln for ln in req if not any(n in ln for n in ['pytorch>', 'horovod'])])" && \ + cat environment.yml && \ + conda env update --file environment.yml && \ + conda clean -ya && \ + rm environment.yml + +ENV \ + PATH=/root/miniconda3/envs/${CONDA_ENV}/bin:$PATH \ + LD_LIBRARY_PATH="/root/miniconda3/envs/${CONDA_ENV}/lib:$LD_LIBRARY_PATH" \ + # if you want this environment to be the default one, uncomment the following line: + CONDA_DEFAULT_ENV=${CONDA_ENV} + +# Disable cache +RUN pip --version && \ + pip config set global.cache-dir false && \ + conda remove pytorch torchvision && \ +# Install Pytorch XLA + py_version=${PYTHON_VERSION/./} && \ + # Python 3.7 wheels are available. Replace cp36-cp36m with cp37-cp37m + gsutil cp "gs://tpu-pytorch/wheels/torch-${XLA_VERSION}-cp${py_version}-cp${py_version}m-linux_x86_64.whl" . && \ + gsutil cp "gs://tpu-pytorch/wheels/torch_xla-${XLA_VERSION}-cp${py_version}-cp${py_version}m-linux_x86_64.whl" . && \ + gsutil cp "gs://tpu-pytorch/wheels/torchvision-${XLA_VERSION}-cp${py_version}-cp${py_version}m-linux_x86_64.whl" . && \ + pip install *.whl && \ + rm *.whl + +# Get package +COPY ./ ./pytorch-lightning/ + +# Install pytorch-lightning dependencies. +RUN \ + python --version && \ +# Install PL dependencies + cd pytorch-lightning && \ + # drop Torch as it was installed with XLA + python -c "fname = 'requirements.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('torch')] ; open(fname, 'w').writelines(lines)" && \ + # drop Horovod as it is not needed + python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" && \ + # drop fairscale as it is not needed + python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)" && \ + # drop TorchVision as it was installed with XLA + python -c "fname = 'requirements/examples.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('torchvision')] ; open(fname, 'w').writelines(lines)" && \ + python ./requirements/adjust_versions.py ./requirements/extra.txt && \ + pip install --requirement ./requirements/devel.txt --no-cache-dir && \ + cd .. && \ + rm -rf pytorch-lightning && \ + rm -rf /root/.cache + +RUN \ + # Show what we have + pip --version && \ + conda info && \ + pip list && \ + python -c "import sys; assert sys.version[:3] == '$PYTHON_VERSION', sys.version" && \ + python -c "import torch; ver = '$XLA_VERSION' ; ver = dict(nightly='1.9').get(ver, ver) ; assert torch.__version__[:3] == ver, torch.__version__" diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile new file mode 100644 index 00000000000000..ad1169c4450dd0 --- /dev/null +++ b/dockers/nvidia/Dockerfile @@ -0,0 +1,84 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM nvcr.io/nvidia/cuda:11.1.1-runtime-ubuntu20.04 + +MAINTAINER PyTorchLightning + +ARG LIGHTNING_VERSION="" + +SHELL ["/bin/bash", "-c"] +# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/ +ENV \ + DEBIAN_FRONTEND=noninteractive \ + TZ=Europe/Prague \ + PATH="$PATH:/root/.local/bin" \ + CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \ + MKL_THREADING_LAYER=GNU + +RUN apt-get update -qq && \ + apt-get install -y --no-install-recommends \ + build-essential \ + python3 \ + python3-distutils \ + python3-dev \ + pkg-config \ + cmake \ + git \ + wget \ + unzip \ + ca-certificates \ + && \ + +# Cleaning + apt-get autoremove -y && \ + apt-get clean && \ + rm -rf /root/.cache && \ + rm -rf /var/lib/apt/lists/* && \ + +# Setup PIP + update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \ + wget https://bootstrap.pypa.io/get-pip.py --progress=bar:force:noscroll --no-check-certificate && \ + python get-pip.py && \ + rm get-pip.py && \ + pip --version + +COPY ./ /home/pytorch-lightning/ + +RUN \ + cd /home && \ + mv pytorch-lightning/notebooks . && \ + mv pytorch-lightning/pl_examples . && \ + # replace by specific version if asked + if [ ! -z "$LIGHTNING_VERSION" ] ; then \ + rm -rf pytorch-lightning ; \ + wget https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --progress=bar:force:noscroll ; \ + unzip ${LIGHTNING_VERSION}.zip ; \ + mv pytorch-lightning-*/ pytorch-lightning ; \ + rm *.zip ; \ + fi && \ + +# Installations + python -c "fname = './pytorch-lightning/requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" && \ + pip install -r ./pytorch-lightning/requirements/extra.txt -U --no-cache-dir && \ + pip install -r ./pytorch-lightning/requirements/examples.txt -U --no-cache-dir && \ + pip install ./pytorch-lightning --no-cache-dir && \ + rm -rf pytorch-lightning + +RUN python --version && \ + pip --version && \ + pip list && \ + python -c "import pytorch_lightning as pl; print(pl.__version__)" + +# CMD ["/bin/bash"] diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile new file mode 100644 index 00000000000000..5cd53385f660bb --- /dev/null +++ b/dockers/release/Dockerfile @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG PYTHON_VERSION=3.7 +ARG PYTORCH_VERSION=1.5 + +FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION} + +MAINTAINER PyTorchLightning + +ARG LIGHTNING_VERSION="" + +COPY ./ /home/pytorch-lightning/ + +# install dependencies +RUN \ + cd /home && \ + mv pytorch-lightning/notebooks . && \ + mv pytorch-lightning/pl_examples . && \ + # replace by specific version if asked + if [ ! -z "$LIGHTNING_VERSION" ] ; then \ + rm -rf pytorch-lightning ; \ + wget https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --progress=bar:force:noscroll ; \ + unzip ${LIGHTNING_VERSION}.zip ; \ + mv pytorch-lightning-*/ pytorch-lightning ; \ + rm *.zip ; \ + fi && \ + pip install ./pytorch-lightning["extra"] --no-cache-dir && \ + rm -rf pytorch-lightning + +RUN python --version && \ + pip --version && \ + pip list && \ + python -c "import pytorch_lightning as pl; print(pl.__version__)" + +# CMD ["/bin/bash"] diff --git a/dockers/tpu-tests/Dockerfile b/dockers/tpu-tests/Dockerfile new file mode 100644 index 00000000000000..93d6244121891d --- /dev/null +++ b/dockers/tpu-tests/Dockerfile @@ -0,0 +1,48 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG PYTHON_VERSION=3.7 +ARG PYTORCH_VERSION=1.6 + +FROM pytorchlightning/pytorch_lightning:base-xla-py${PYTHON_VERSION}-torch${PYTORCH_VERSION} + +MAINTAINER PyTorchLightning + +#SHELL ["/bin/bash", "-c"] + +COPY ./ ./pytorch-lightning/ + +# Pull the legacy checkpoints +RUN cd pytorch-lightning && \ + wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/ && \ + unzip -o legacy/checkpoints.zip -d legacy/ && \ + ls -l legacy/checkpoints/ + +# If using this image for tests, intall more dependencies and don"t delete the source code where the tests live. +RUN \ + # Install pytorch-lightning at the current PR, plus dependencies. + #pip install -r pytorch-lightning/requirements.txt --no-cache-dir && \ + # drop Horovod as it is not needed + python -c "fname = 'pytorch-lightning/requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" && \ + # drop fairscale as it is not needed + python -c "fname = 'pytorch-lightning/requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)" && \ + pip install -r pytorch-lightning/requirements/devel.txt --no-cache-dir + +#RUN python -c "import pytorch_lightning as pl; print(pl.__version__)" + +COPY ./dockers/tpu-tests/docker-entrypoint.sh /usr/local/bin/ +RUN chmod +x /usr/local/bin/docker-entrypoint.sh + +ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] +CMD ["bash"] diff --git a/dockers/tpu-tests/docker-entrypoint.sh b/dockers/tpu-tests/docker-entrypoint.sh new file mode 100644 index 00000000000000..57abc703c8aceb --- /dev/null +++ b/dockers/tpu-tests/docker-entrypoint.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# source ~/.bashrc +echo "running docker-entrypoint.sh" +# conda activate container +echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS +echo "printed TPU info" +export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" +exec "$@" diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet new file mode 100644 index 00000000000000..8c3f3693fda50a --- /dev/null +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -0,0 +1,37 @@ +local base = import 'templates/base.libsonnet'; +local tpus = import 'templates/tpus.libsonnet'; +local utils = import "templates/utils.libsonnet"; + +local tputests = base.BaseTest { + frameworkPrefix: 'pl', + modelName: 'tpu-tests', + mode: 'postsubmit', + configMaps: [], + + timeout: 900, # 15 minutes, in seconds. + + image: std.extVar('image'), + imageTag: std.extVar('image-tag'), + + tpuSettings+: { + softwareVersion: 'pytorch-VERSION', + }, + accelerator: tpus.v3_8, + + command: utils.scriptCommand( + ||| + cd pytorch-lightning + coverage run --source=pytorch_lightning -m pytest -v --capture=no \ + pytorch_lightning/utilities/xla_device_utils.py \ + tests/accelerators/test_tpu_backend.py \ + tests/models/test_tpu.py + test_exit_code=$? + echo "\n||| END PYTEST LOGS |||\n" + coverage xml + cat coverage.xml | tr -d '\t' + test $test_exit_code -eq 0 + ||| + ), +}; + +tputests.oneshotJob diff --git a/docs/.build_docs.sh b/docs/.build_docs.sh deleted file mode 100644 index 691f7fc22905b4..00000000000000 --- a/docs/.build_docs.sh +++ /dev/null @@ -1 +0,0 @@ -make clean ; make html --debug --jobs 2 SPHINXOPTS="-W" \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 69fe55ecfa9aad..ba501f6f5b1bfd 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -16,4 +16,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 0cbf169d706901..00000000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -sphinx>=2.0, <3.0 -recommonmark # fails with badges -m2r # fails with multi-line text -nbsphinx -pandoc -docutils -sphinxcontrib-fulltoc -sphinxcontrib-mockautodoc -git+https://github.com/PytorchLightning/lightning_sphinx_theme.git -# pip_shims -sphinx-autodoc-typehints -sphinx-paramlinks<0.4.0 diff --git a/docs/source/_images/general/tf_loss.png b/docs/source/_images/general/tf_loss.png deleted file mode 100644 index 5bc631e44dfe4e..00000000000000 Binary files a/docs/source/_images/general/tf_loss.png and /dev/null differ diff --git a/docs/source/_images/general/tf_tags.png b/docs/source/_images/general/tf_tags.png deleted file mode 100644 index 269c3bf6c7b80b..00000000000000 Binary files a/docs/source/_images/general/tf_tags.png and /dev/null differ diff --git a/docs/source/_images/lightning_module/pt_to_pl.png b/docs/source/_images/lightning_module/pt_to_pl.png deleted file mode 100644 index c5b24093f8311d..00000000000000 Binary files a/docs/source/_images/lightning_module/pt_to_pl.png and /dev/null differ diff --git a/docs/source/_images/lightning_module/pt_trainer.png b/docs/source/_images/lightning_module/pt_trainer.png deleted file mode 100644 index a6c5d5ee7441bf..00000000000000 Binary files a/docs/source/_images/lightning_module/pt_trainer.png and /dev/null differ diff --git a/docs/source/_images/logos/lightning_icon.svg b/docs/source/_images/logos/lightning_icon.svg deleted file mode 100644 index 5ab3512c0491ee..00000000000000 --- a/docs/source/_images/logos/lightning_icon.svg +++ /dev/null @@ -1,62 +0,0 @@ - - - - - - image/svg+xml - - - - - - - - - - diff --git a/docs/source/_images/logos/lightning_logo-large.svg b/docs/source/_images/logos/lightning_logo-large.svg deleted file mode 100644 index 4a6cd73fa4e10c..00000000000000 --- a/docs/source/_images/logos/lightning_logo-large.svg +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - image/svg+xml - - - - - - - - - diff --git a/docs/source/_images/logos/lightning_logo-name.svg b/docs/source/_images/logos/lightning_logo-name.svg deleted file mode 100755 index d684eb0fdcaadd..00000000000000 --- a/docs/source/_images/logos/lightning_logo-name.svg +++ /dev/null @@ -1,80 +0,0 @@ - - - - - - image/svg+xml - - - - - - - - long - Created with Sketch. - - - - PyTorch Lightning - - - diff --git a/docs/source/_images/logos/lightning_logo.png b/docs/source/_images/logos/lightning_logo.png deleted file mode 100644 index a28606b541632d..00000000000000 Binary files a/docs/source/_images/logos/lightning_logo.png and /dev/null differ diff --git a/docs/source/_images/logos/lightning_logo.svg b/docs/source/_images/logos/lightning_logo.svg deleted file mode 100644 index d8b56e58f9a4fe..00000000000000 --- a/docs/source/_images/logos/lightning_logo.svg +++ /dev/null @@ -1,62 +0,0 @@ - - - - - - image/svg+xml - - - - - - - - - - diff --git a/docs/source/_static/copybutton.js b/docs/source/_static/copybutton.js new file mode 100644 index 00000000000000..453363ce9e1aed --- /dev/null +++ b/docs/source/_static/copybutton.js @@ -0,0 +1,64 @@ +/* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */ +$(document).ready(function() { + /* Add a [>>>] button on the top-right corner of code samples to hide + * the >>> and ... prompts and the output and thus make the code + * copyable. */ + var div = $('.highlight-python .highlight,' + + '.highlight-python3 .highlight,' + + '.highlight-pycon .highlight,' + + '.highlight-default .highlight'); + var pre = div.find('pre'); + + // get the styles from the current theme + pre.parent().parent().css('position', 'relative'); + var hide_text = 'Hide the prompts and output'; + var show_text = 'Show the prompts and output'; + var border_width = pre.css('border-top-width'); + var border_style = pre.css('border-top-style'); + var border_color = pre.css('border-top-color'); + var button_styles = { + 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', + 'border-color': border_color, 'border-style': border_style, + 'border-width': border_width, 'color': border_color, 'text-size': '75%', + 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', + 'border-radius': '0 3px 0 0' + } + + // create and add the button to all the code blocks that contain >>> + div.each(function(index) { + var jthis = $(this); + if (jthis.find('.gp').length > 0) { + var button = $('>>>'); + button.css(button_styles) + button.attr('title', hide_text); + button.data('hidden', 'false'); + jthis.prepend(button); + } + // tracebacks (.gt) contain bare text elements that need to be + // wrapped in a span to work with .nextUntil() (see later) + jthis.find('pre:has(.gt)').contents().filter(function() { + return ((this.nodeType == 3) && (this.data.trim().length > 0)); + }).wrap(''); + }); + + // define the behavior of the button when it's clicked + $('.copybutton').click(function(e){ + e.preventDefault(); + var button = $(this); + if (button.data('hidden') === 'false') { + // hide the code output + button.parent().find('.go, .gp, .gt').hide(); + button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); + button.css('text-decoration', 'line-through'); + button.attr('title', show_text); + button.data('hidden', 'true'); + } else { + // show the code output + button.parent().find('.go, .gp, .gt').show(); + button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); + button.css('text-decoration', 'none'); + button.attr('title', hide_text); + button.data('hidden', 'false'); + } + }); +}); diff --git a/docs/source/_static/images/benchmarks/figure-parity-times.png b/docs/source/_static/images/benchmarks/figure-parity-times.png new file mode 100644 index 00000000000000..2e8c5899020d99 Binary files /dev/null and b/docs/source/_static/images/benchmarks/figure-parity-times.png differ diff --git a/docs/source/_static/images/general/PTL101_youtube_thumbnail.jpg b/docs/source/_static/images/general/PTL101_youtube_thumbnail.jpg new file mode 100644 index 00000000000000..a09dc43d47bb71 Binary files /dev/null and b/docs/source/_static/images/general/PTL101_youtube_thumbnail.jpg differ diff --git a/docs/source/_static/images/general/fast_2.gif b/docs/source/_static/images/general/fast_2.gif new file mode 100644 index 00000000000000..77c6f85c65cece Binary files /dev/null and b/docs/source/_static/images/general/fast_2.gif differ diff --git a/docs/source/_images/general/pl_overview.gif b/docs/source/_static/images/general/pl_overview.gif similarity index 100% rename from docs/source/_images/general/pl_overview.gif rename to docs/source/_static/images/general/pl_overview.gif diff --git a/docs/source/_images/general/pl_overview_flat.jpg b/docs/source/_static/images/general/pl_overview_flat.jpg similarity index 100% rename from docs/source/_images/general/pl_overview_flat.jpg rename to docs/source/_static/images/general/pl_overview_flat.jpg diff --git a/docs/source/_static/images/general/pl_quick_start_full_compressed.gif b/docs/source/_static/images/general/pl_quick_start_full_compressed.gif new file mode 100644 index 00000000000000..f7136d0a299739 Binary files /dev/null and b/docs/source/_static/images/general/pl_quick_start_full_compressed.gif differ diff --git a/docs/source/_static/images/general/tf_loss.jpg b/docs/source/_static/images/general/tf_loss.jpg new file mode 100644 index 00000000000000..869947f4c6eafa Binary files /dev/null and b/docs/source/_static/images/general/tf_loss.jpg differ diff --git a/docs/source/_static/images/general/tf_tags.jpg b/docs/source/_static/images/general/tf_tags.jpg new file mode 100644 index 00000000000000..40918ec0011d9f Binary files /dev/null and b/docs/source/_static/images/general/tf_tags.jpg differ diff --git a/docs/source/_static/images/general/tutorial_cover.jpg b/docs/source/_static/images/general/tutorial_cover.jpg new file mode 100644 index 00000000000000..1c0e7f31d53b69 Binary files /dev/null and b/docs/source/_static/images/general/tutorial_cover.jpg differ diff --git a/docs/source/_static/images/icon.svg b/docs/source/_static/images/icon.svg new file mode 100644 index 00000000000000..bed8f14dc1086b --- /dev/null +++ b/docs/source/_static/images/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/docs/source/_static/images/lightning_module/pt_to_pl.png b/docs/source/_static/images/lightning_module/pt_to_pl.png new file mode 100644 index 00000000000000..5135ec214b1de7 Binary files /dev/null and b/docs/source/_static/images/lightning_module/pt_to_pl.png differ diff --git a/docs/source/_static/images/lightning_module/pt_trainer.png b/docs/source/_static/images/lightning_module/pt_trainer.png new file mode 100644 index 00000000000000..f465d43a503bc2 Binary files /dev/null and b/docs/source/_static/images/lightning_module/pt_trainer.png differ diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png new file mode 100644 index 00000000000000..82c6fd106b1f1b Binary files /dev/null and b/docs/source/_static/images/logo.png differ diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg new file mode 100644 index 00000000000000..dca54b36403f84 --- /dev/null +++ b/docs/source/_static/images/logo.svg @@ -0,0 +1,70 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/docs/source/_images/mnist_imgs/mnist_cpu_bar.png b/docs/source/_static/images/mnist_imgs/mnist_cpu_bar.png similarity index 100% rename from docs/source/_images/mnist_imgs/mnist_cpu_bar.png rename to docs/source/_static/images/mnist_imgs/mnist_cpu_bar.png diff --git a/docs/source/_images/mnist_imgs/mnist_gpu.png b/docs/source/_static/images/mnist_imgs/mnist_gpu.png similarity index 100% rename from docs/source/_images/mnist_imgs/mnist_gpu.png rename to docs/source/_static/images/mnist_imgs/mnist_gpu.png diff --git a/docs/source/_images/mnist_imgs/mnist_tb.png b/docs/source/_static/images/mnist_imgs/mnist_tb.png similarity index 100% rename from docs/source/_images/mnist_imgs/mnist_tb.png rename to docs/source/_static/images/mnist_imgs/mnist_tb.png diff --git a/docs/source/_images/mnist_imgs/pt_to_pl.jpg b/docs/source/_static/images/mnist_imgs/pt_to_pl.jpg similarity index 100% rename from docs/source/_images/mnist_imgs/pt_to_pl.jpg rename to docs/source/_static/images/mnist_imgs/pt_to_pl.jpg diff --git a/docs/source/_images/mnist_imgs/restart_runtime.png b/docs/source/_static/images/mnist_imgs/restart_runtime.png similarity index 100% rename from docs/source/_images/mnist_imgs/restart_runtime.png rename to docs/source/_static/images/mnist_imgs/restart_runtime.png diff --git a/docs/source/_images/mnist_imgs/runtime_tpu.png b/docs/source/_static/images/mnist_imgs/runtime_tpu.png similarity index 100% rename from docs/source/_images/mnist_imgs/runtime_tpu.png rename to docs/source/_static/images/mnist_imgs/runtime_tpu.png diff --git a/docs/source/_images/mnist_imgs/tpu_fast.png b/docs/source/_static/images/mnist_imgs/tpu_fast.png similarity index 100% rename from docs/source/_images/mnist_imgs/tpu_fast.png rename to docs/source/_static/images/mnist_imgs/tpu_fast.png diff --git a/docs/source/_images/mnist_imgs/tpu_start.png b/docs/source/_static/images/mnist_imgs/tpu_start.png similarity index 100% rename from docs/source/_images/mnist_imgs/tpu_start.png rename to docs/source/_static/images/mnist_imgs/tpu_start.png diff --git a/docs/source/_images/trainer/lr_finder.png b/docs/source/_static/images/trainer/lr_finder.png similarity index 100% rename from docs/source/_images/trainer/lr_finder.png rename to docs/source/_static/images/trainer/lr_finder.png diff --git a/docs/source/_static/main.css b/docs/source/_static/main.css new file mode 100644 index 00000000000000..82aa8b338ad397 --- /dev/null +++ b/docs/source/_static/main.css @@ -0,0 +1,3 @@ +col { + width: 50% !important; +} diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst new file mode 100644 index 00000000000000..704ac61bf773d2 --- /dev/null +++ b/docs/source/_templates/autosummary/module.rst @@ -0,0 +1,41 @@ +{{ name | escape | underline }} + +.. currentmodule:: {{ fullname }} + +{% block functions %} +{% if functions %} +.. rubric:: Functions + +.. autosummary:: + :nosignatures: +{% for item in functions %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block classes %} +{% if classes %} +.. rubric:: Classes + +.. autosummary:: + :nosignatures: +{% for item in classes %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block exceptions %} +{% if exceptions %} +.. rubric:: Exceptions + +.. autosummary:: + :nosignatures: +{% for item in exceptions %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +.. automodule:: {{ fullname }} diff --git a/docs/source/_templates/classtemplate.rst b/docs/source/_templates/classtemplate.rst new file mode 100644 index 00000000000000..398a0ec07cb05d --- /dev/null +++ b/docs/source/_templates/classtemplate.rst @@ -0,0 +1,14 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline }} + +.. autoclass:: {{ name }} + :members: + + +.. + autogenerated from source/_templates/classtemplate.rst + note it does not have :inherited-members: diff --git a/docs/source/_templates/theme_variables.jinja b/docs/source/_templates/theme_variables.jinja index 6c6ac509bb68da..d2f00702fb6553 100644 --- a/docs/source/_templates/theme_variables.jinja +++ b/docs/source/_templates/theme_variables.jinja @@ -1,17 +1,17 @@ {%- set external_urls = { - 'github': 'https://github.com/PytorchLightning/pytorch-lightning', - 'github_issues': 'https://github.com/PytorchLightning/pytorch-lightning/issues', - 'contributing': 'https://github.com/PytorchLightning/pytorch-lightning/blob/master/CONTRIBUTING.md', - 'governance': 'https://github.com/PytorchLightning/pytorch-lightning/blob/master/governance.md', + 'github': 'https://github.com/PyTorchLightning/pytorch-lightning', + 'github_issues': 'https://github.com/PyTorchLightning/pytorch-lightning/issues', + 'contributing': 'https://github.com/PyTorchLightning/pytorch-lightning/blob/master/CONTRIBUTING.md', + 'governance': 'https://github.com/PyTorchLightning/pytorch-lightning/blob/master/governance.md', 'docs': 'https://pytorch-lightning.rtfd.io/en/latest', 'twitter': 'https://twitter.com/PyTorchLightnin', 'discuss': 'https://pytorch-lightning.slack.com', 'tutorials': 'https://pytorch-lightning.readthedocs.io/en/latest/#tutorials', 'previous_pytorch_versions': 'https://pytorch-lightning.rtfd.io/en/latest/', 'home': 'https://pytorch-lightning.rtfd.io/en/latest/', - 'get_started': 'https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html', + 'get_started': 'https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html', 'features': 'https://pytorch-lightning.rtfd.io/en/latest/', - 'blog': 'https://towardsdatascience.com/@_willfalcon', + 'blog': 'https://www.pytorchlightning.ai/blog', 'resources': 'https://pytorch-lightning.readthedocs.io/en/latest/#community-examples', 'support': 'https://pytorch-lightning.rtfd.io/en/latest/', } diff --git a/docs/source/advanced/amp.rst b/docs/source/advanced/amp.rst new file mode 100644 index 00000000000000..d42f1c8c2928d7 --- /dev/null +++ b/docs/source/advanced/amp.rst @@ -0,0 +1,97 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + +.. _amp: + +16-bit training +================= +Lightning offers 16-bit training for CPUs, GPUs, and TPUs. + +.. raw:: html + + + +| + + +---------- + +GPU 16-bit +---------- +16-bit precision can cut your memory footprint by half. +If using volta architecture GPUs it can give a dramatic training speed-up as well. + +.. note:: PyTorch 1.6+ is recommended for 16-bit + +Native torch +^^^^^^^^^^^^ +When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit. + +.. testcode:: + :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available() + + # turn on 16-bit + trainer = Trainer(precision=16, gpus=1) + +Apex 16-bit +^^^^^^^^^^^ +If you are using an earlier version of PyTorch Lightning uses Apex to support 16-bit. + +Follow these instructions to install Apex. +To use 16-bit precision, do two things: + +1. Install Apex +2. Set the "precision" trainer flag. + +.. code-block:: bash + + $ git clone https://github.com/NVIDIA/apex + $ cd apex + + # ------------------------ + # OPTIONAL: on your cluster you might need to load CUDA 10 or 9 + # depending on how you installed PyTorch + + # see available modules + module avail + + # load correct CUDA before install + module load cuda-10.0 + # ------------------------ + + # make sure you've loaded a cuda version > 4.0 and < 7.0 + module load gcc-6.1.0 + + $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + +.. warning:: NVIDIA Apex and DDP have instability problems. We recommend native 16-bit in PyTorch 1.6+ + +Enable 16-bit +^^^^^^^^^^^^^ + +.. testcode:: + :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available() + + # turn on 16-bit + trainer = Trainer(amp_level='O2', precision=16) + +If you need to configure the apex init for your particular use case or want to use a different way of doing +16-bit training, override :meth:`pytorch_lightning.core.LightningModule.configure_apex`. + +---------- + +TPU 16-bit +---------- +16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag + +.. testcode:: + :skipif: not _TPU_AVAILABLE + + # DEFAULT + trainer = Trainer(tpu_cores=8, precision=32) + + # turn on 16-bit + trainer = Trainer(tpu_cores=8, precision=16) diff --git a/docs/source/advanced/cluster.rst b/docs/source/advanced/cluster.rst new file mode 100644 index 00000000000000..34538cb85b4746 --- /dev/null +++ b/docs/source/advanced/cluster.rst @@ -0,0 +1,57 @@ + +.. _non-slurm: + +Computing cluster +================= + +With Lightning it is easy to run your training script on a computing cluster without almost any modifications to the script. +This guide shows how to run a training job on a general purpose cluster. + +Also, check :doc:`../extensions/accelerators` as a new and more general approach to a cluster setup. + +-------- + + +Cluster setup +------------- + +To setup a multi-node computing cluster you need: + +1) Multiple computers with PyTorch Lightning installed +2) A network connectivity between them with firewall rules that allow traffic flow on a specified *MASTER_PORT*. +3) Defined environment variables on each node required for the PyTorch Lightning multi-node distributed training + +PyTorch Lightning follows the design of `PyTorch distributed communication package `_. and requires the following environment variables to be defined on each node: + +- *MASTER_PORT* - required; has to be a free port on machine with NODE_RANK 0 +- *MASTER_ADDR* - required (except for NODE_RANK 0); address of NODE_RANK 0 node +- *WORLD_SIZE* - required; how many nodes are in the cluster +- *NODE_RANK* - required; id of the node in the cluster + + +Training script design +---------------------- + +To train a model using multiple nodes, do the following: + +1. Design your :ref:`lightning_module` (no need to add anything specific here). + +2. Enable DDP in the trainer + + .. code-block:: python + + # train on 32 GPUs across 4 nodes + trainer = Trainer(gpus=8, num_nodes=4, accelerator='ddp') + + +Submit a job to the cluster +--------------------------- + +To submit a training job to the cluster you need to run the same training script on each node of the cluster. +This means that you need to: + +1. Copy all third-party libraries to each node (usually means - distribute requirements.txt file and install it). + +2. Copy all your import dependencies and the script itself to each node. + +3. Run the script on each node. diff --git a/docs/source/advanced/lr_finder.rst b/docs/source/advanced/lr_finder.rst new file mode 100644 index 00000000000000..9a0749b36ad4ab --- /dev/null +++ b/docs/source/advanced/lr_finder.rst @@ -0,0 +1,114 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule + +.. _lr_finder: + +Learning Rate Finder +-------------------- + +.. raw:: html + + + +| + +For training deep neural networks, selecting a good learning rate is essential +for both better performance and faster convergence. Even optimizers such as +`Adam` that are self-adjusting the learning rate can benefit from more optimal +choices. + +To reduce the amount of guesswork concerning choosing a good initial learning +rate, a `learning rate finder` can be used. As described in this `paper `_ +a learning rate finder does a small run where the learning rate is increased +after each processed batch and the corresponding loss is logged. The result of +this is a `lr` vs. `loss` plot that can be used as guidance for choosing a optimal +initial lr. + +.. warning:: + For the moment, this feature only works with models having a single optimizer. + LR Finder support for DDP is not implemented yet, it is coming soon. + +---------- + +Using Lightning's built-in LR finder +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to have a ``learning_rate`` or ``lr`` property. +Then, set ``Trainer(auto_lr_find=True)`` during trainer construction, +and then call ``trainer.tune(model)`` to run the LR finder. The suggested ``learning_rate`` +will be written to the console and will be automatically set to your :doc:`lightning module <../common/lightning_module>`, +which can be accessed via ``self.learning_rate`` or ``self.lr``. + +.. code-block:: python + + class LitModel(LightningModule): + + def __init__(self, learning_rate): + self.learning_rate = learning_rate + + def configure_optimizers(self): + return Adam(self.parameters(), lr=(self.lr or self.learning_rate)) + + model = LitModel() + + # finds learning rate automatically + # sets hparams.lr or hparams.learning_rate to that learning rate + trainer = Trainer(auto_lr_find=True) + + trainer.tune(model) + +If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value as ``auto_lr_find``: + +.. code-block:: python + + model = LitModel() + + # to set to your own hparams.my_value + trainer = Trainer(auto_lr_find='my_value') + + trainer.tune(model) + + +If you want to inspect the results of the learning rate finder or just play around +with the parameters of the algorithm, this can be done by invoking the ``lr_find`` +method of the trainer. A typical example of this would look like + +.. code-block:: python + + model = MyModelClass(hparams) + trainer = Trainer() + + # Run learning rate finder + lr_finder = trainer.tuner.lr_find(model) + + # Results can be found in + lr_finder.results + + # Plot with + fig = lr_finder.plot(suggest=True) + fig.show() + + # Pick point based on plot, or get suggestion + new_lr = lr_finder.suggestion() + + # update hparams of the model + model.hparams.lr = new_lr + + # Fit model + trainer.fit(model) + +The figure produced by ``lr_finder.plot()`` should look something like the figure +below. It is recommended to not pick the learning rate that achieves the lowest +loss, but instead something in the middle of the sharpest downward slope (red point). +This is the point returned py ``lr_finder.suggestion()``. + +.. figure:: ../_static/images/trainer/lr_finder.png + +The parameters of the algorithm can be seen below. + +.. autofunction:: pytorch_lightning.tuner.lr_finder.lr_find + :noindex: diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst new file mode 100644 index 00000000000000..5cdb0b377f2b75 --- /dev/null +++ b/docs/source/advanced/multi_gpu.rst @@ -0,0 +1,1032 @@ +.. testsetup:: * + + import torch + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule + +.. _multi_gpu: + +Multi-GPU training +================== +Lightning supports multiple ways of doing distributed training. + +.. raw:: html + + + +| + +---------- + +Preparing your code +------------------- +To train on CPU/GPU/TPU without changing your code, we need to build a few good habits :) + +Delete .cuda() or .to() calls +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Delete any calls to .cuda() or .to(device). + +.. testcode:: + + # before lightning + def forward(self, x): + x = x.cuda(0) + layer_1.cuda(0) + x_hat = layer_1(x) + + # after lightning + def forward(self, x): + x_hat = layer_1(x) + +Init tensors using type_as and register_buffer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +When you need to create a new tensor, use `type_as`. +This will make your code scale to any arbitrary number of GPUs or TPUs with Lightning. + +.. testcode:: + + # before lightning + def forward(self, x): + z = torch.Tensor(2, 3) + z = z.cuda(0) + + # with lightning + def forward(self, x): + z = torch.Tensor(2, 3) + z = z.type_as(x) + +The :class:`~pytorch_lightning.core.lightning.LightningModule` knows what device it is on. You can access the reference via ``self.device``. +Sometimes it is necessary to store tensors as module attributes. However, if they are not parameters they will +remain on the CPU even if the module gets moved to a new device. To prevent that and remain device agnostic, +register the tensor as a buffer in your modules's ``__init__`` method with :meth:`~torch.nn.Module.register_buffer`. + +.. testcode:: + + class LitModel(LightningModule): + + def __init__(self): + ... + self.register_buffer("sigma", torch.eye(3)) + # you can now access self.sigma anywhere in your module + + +Remove samplers +^^^^^^^^^^^^^^^ +In PyTorch, you must use :class:`~torch.utils.data.distributed.DistributedSampler` +for multi-node or TPU training. The sampler makes sure each GPU sees the appropriate part of your data. + +.. testcode:: + + # without lightning + def train_dataloader(self): + dataset = MNIST(...) + sampler = None + + if self.on_tpu: + sampler = DistributedSampler(dataset) + + return DataLoader(dataset, sampler=sampler) + +Lightning adds the correct samplers when needed, so no need to explicitly add samplers. + +.. testcode:: + + # with lightning + def train_dataloader(self): + dataset = MNIST(...) + return DataLoader(dataset) + +.. note:: + By default it will add ``shuffle=True`` for train sampler and ``shuffle=False`` for val/test sampler. + ``drop_last`` in :class:`~torch.utils.data.distributed.DistributedSampler` will be set to its default value in PyTorch. + +.. note:: You can disable this behavior with ``Trainer(replace_sampler_ddp=False)`` + +.. note:: For iterable datasets, we don't do this automatically. + + +Synchronize validation and test logging +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes. +This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. +This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers. + +Note if you use any built in metrics or custom metrics that use the :doc:`Metrics API <../extensions/metrics>`, these do not need to be updated and are automatically handled for you. + +.. testcode:: + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.loss(logits, y) + # Add sync_dist=True to sync logging across all GPU workers + self.log('validation_loss', loss, on_step=True, on_epoch=True, sync_dist=True) + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.loss(logits, y) + # Add sync_dist=True to sync logging across all GPU workers + self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True) + + +Make models pickleable +^^^^^^^^^^^^^^^^^^^^^^ +It's very likely your code is already `pickleable `_, +in that case no change in necessary. +However, if you run a distributed model and get the following error: + +.. code-block:: + + self._launch(process_obj) + File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47, + in _launch reduction.dump(process_obj, fp) + File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/reduction.py", line 60, in dump + ForkingPickler(file, protocol).dump(obj) + _pickle.PicklingError: Can't pickle at 0x2b599e088ae8>: + attribute lookup on __main__ failed + +This means something in your model definition, transforms, optimizer, dataloader or callbacks cannot be pickled, and the following code will fail: + +.. code-block:: python + + import pickle + pickle.dump(some_object) + +This is a limitation of using multiple processes for distributed training within PyTorch. +To fix this issue, find your piece of code that cannot be pickled. The end of the stacktrace +is usually helpful. +ie: in the stacktrace example here, there seems to be a lambda function somewhere in the code +which cannot be pickled. + +.. code-block:: + + self._launch(process_obj) + File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47, + in _launch reduction.dump(process_obj, fp) + File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/reduction.py", line 60, in dump + ForkingPickler(file, protocol).dump(obj) + _pickle.PicklingError: Can't pickle [THIS IS THE THING TO FIND AND DELETE]: + attribute lookup on __main__ failed + +---------- + +Select GPU devices +------------------ + +You can select the GPU devices using ranges, a list of indices or a string containing +a comma separated list of GPU ids: + +.. testsetup:: + + k = 1 + +.. testcode:: + :skipif: torch.cuda.device_count() < 2 + + # DEFAULT (int) specifies how many GPUs to use per node + Trainer(gpus=k) + + # Above is equivalent to + Trainer(gpus=list(range(k))) + + # Specify which GPUs to use (don't use when running on cluster) + Trainer(gpus=[0, 1]) + + # Equivalent using a string + Trainer(gpus='0, 1') + + # To use all available GPUs put -1 or '-1' + # equivalent to list(range(torch.cuda.device_count())) + Trainer(gpus=-1) + +The table below lists examples of possible input formats and how they are interpreted by Lightning. +Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`. + ++---------------+-----------+---------------------+---------------------------------+ +| `gpus` | Type | Parsed | Meaning | ++===============+===========+=====================+=================================+ +| None | NoneType | None | CPU | ++---------------+-----------+---------------------+---------------------------------+ +| 0 | int | None | CPU | ++---------------+-----------+---------------------+---------------------------------+ +| 3 | int | [0, 1, 2] | first 3 GPUs | ++---------------+-----------+---------------------+---------------------------------+ +| -1 | int | [0, 1, 2, ...] | all available GPUs | ++---------------+-----------+---------------------+---------------------------------+ +| [0] | list | [0] | GPU 0 | ++---------------+-----------+---------------------+---------------------------------+ +| [1, 3] | list | [1, 3] | GPUs 1 and 3 | ++---------------+-----------+---------------------+---------------------------------+ +| "0" | str | [0] | GPU 0 | ++---------------+-----------+---------------------+---------------------------------+ +| "3" | str | [3] | GPU 3 | ++---------------+-----------+---------------------+---------------------------------+ +| "1, 3" | str | [1, 3] | GPUs 1 and 3 | ++---------------+-----------+---------------------+---------------------------------+ +| "-1" | str | [0, 1, 2, ...] | all available GPUs | ++---------------+-----------+---------------------+---------------------------------+ + +.. note:: + + When specifying number of gpus as an integer ``gpus=k``, setting the trainer flag + ``auto_select_gpus=True`` will automatically help you find ``k`` gpus that are not + occupied by other processes. This is especially useful when GPUs are configured + to be in "exclusive mode", such that only one process at a time can access them. + For more details see the :doc:`trainer guide <../common/trainer>`. + + +Select torch distributed backend +-------------------------------- + +By default, Lightning will select the ``nccl`` backend over ``gloo`` when running on GPUs. +Find more information about PyTorch's supported backends `here `__. + +Lightning exposes an environment variable ``PL_TORCH_DISTRIBUTED_BACKEND`` for the user to change the backend. + +.. code-block:: bash + + PL_TORCH_DISTRIBUTED_BACKEND=gloo python train.py ... + + +---------- + +Distributed modes +----------------- +Lightning allows multiple ways of training + +- Data Parallel (``accelerator='dp'``) (multiple-gpus, 1 machine) +- DistributedDataParallel (``accelerator='ddp'``) (multiple-gpus across many machines (python script based)). +- DistributedDataParallel (``accelerator='ddp_spawn'``) (multiple-gpus across many machines (spawn based)). +- DistributedDataParallel 2 (``accelerator='ddp2'``) (DP in a machine, DDP across machines). +- Horovod (``accelerator='horovod'``) (multi-machine, multi-gpu, configured at runtime) +- TPUs (``tpu_cores=8|x``) (tpu or TPU pod) + +.. note:: + If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used. + +For a deeper understanding of what Lightning is doing, feel free to read this +`guide `_. + + + +Data Parallel +^^^^^^^^^^^^^ +:class:`~torch.nn.DataParallel` (DP) splits a batch across k GPUs. +That is, if you have a batch of 32 and use DP with 2 gpus, each GPU will process 16 samples, +after which the root node will aggregate the results. + +.. warning:: DP use is discouraged by PyTorch and Lightning. Use DDP which is more stable and at least 3x faster + +.. testcode:: + :skipif: torch.cuda.device_count() < 2 + + # train on 2 GPUs (using DP mode) + trainer = Trainer(gpus=2, accelerator='dp') + +Distributed Data Parallel +^^^^^^^^^^^^^^^^^^^^^^^^^ +:class:`~torch.nn.parallel.DistributedDataParallel` (DDP) works as follows: + +1. Each GPU across each node gets its own process. + +2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset. + +3. Each process inits the model. + +4. Each process performs a full forward and backward pass in parallel. + +5. The gradients are synced and averaged across all processes. + +6. Each process updates its optimizer. + +.. code-block:: python + + # train on 8 GPUs (same machine (ie: node)) + trainer = Trainer(gpus=8, accelerator='ddp') + + # train on 32 GPUs (4 nodes) + trainer = Trainer(gpus=8, accelerator='ddp', num_nodes=4) + +This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment +variables: + +.. code-block:: bash + + # example for 3 GPUs DDP + MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --gpus 3 --etc + MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python my_file.py --gpus 3 --etc + MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python my_file.py --gpus 3 --etc + +We use DDP this way because `ddp_spawn` has a few limitations (due to Python and PyTorch): + +1. Since `.spawn()` trains the model in subprocesses, the model on the main process does not get updated. +2. Dataloader(num_workers=N), where N is large, bottlenecks training with DDP... ie: it will be VERY slow or won't work at all. This is a PyTorch limitation. +3. Forces everything to be picklable. + +There are cases in which it is NOT possible to use DDP. Examples are: + +- Jupyter Notebook, Google COLAB, Kaggle, etc. +- You have a nested script without a root package + +In these situations you should use `dp` or `ddp_spawn` instead. + +Distributed Data Parallel 2 +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +In certain cases, it's advantageous to use all batches on the same machine instead of a subset. +For instance, you might want to compute a NCE loss where it pays to have more negative samples. + +In this case, we can use DDP2 which behaves like DP in a machine and DDP across nodes. DDP2 does the following: + +1. Copies a subset of the data to each node. + +2. Inits a model on each node. + +3. Runs a forward and backward pass using DP. + +4. Syncs gradients across nodes. + +5. Applies the optimizer updates. + +.. code-block:: python + + # train on 32 GPUs (4 nodes) + trainer = Trainer(gpus=8, accelerator='ddp2', num_nodes=4) + +Distributed Data Parallel Spawn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +`ddp_spawn` is exactly like `ddp` except that it uses .spawn to start the training processes. + +.. warning:: It is STRONGLY recommended to use `DDP` for speed and performance. + +.. code-block:: python + + mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, )) + +If your script does not support being called from the command line (ie: it is nested without a root +project module) you can use the following method: + +.. code-block:: python + + # train on 8 GPUs (same machine (ie: node)) + trainer = Trainer(gpus=8, accelerator='ddp_spawn') + +We STRONGLY discourage this use because it has limitations (due to Python and PyTorch): + +1. The model you pass in will not update. Please save a checkpoint and restore from there. +2. Set Dataloader(num_workers=0) or it will bottleneck training. + +`ddp` is MUCH faster than `ddp_spawn`. We recommend you + +1. Install a top-level module for your project using setup.py + +.. code-block:: python + + # setup.py + #!/usr/bin/env python + + from setuptools import setup, find_packages + + setup(name='src', + version='0.0.1', + description='Describe Your Cool Project', + author='', + author_email='', + url='https://github.com/YourSeed', # REPLACE WITH YOUR OWN GITHUB PROJECT LINK + install_requires=[ + 'pytorch-lightning' + ], + packages=find_packages() + ) + +2. Setup your project like so: + +.. code-block:: bash + + /project + /src + some_file.py + /or_a_folder + setup.py + +3. Install as a root-level package + +.. code-block:: bash + + cd /project + pip install -e . + +You can then call your scripts anywhere + +.. code-block:: bash + + cd /project/src + python some_file.py --accelerator 'ddp' --gpus 8 + + +Horovod +^^^^^^^ +`Horovod `_ allows the same training script to be used for single-GPU, +multi-GPU, and multi-node training. + +Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed +subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass, +then synchronously applied before beginning the next step. + +The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In +the training script, Horovod will detect the number of workers from the environment, and automatically +scale the learning rate to compensate for the increased total batch size. + +Horovod can be configured in the training script to run with any number of GPUs / processes as follows: + +.. code-block:: python + + # train Horovod on GPU (number of GPUs / machines provided on command-line) + trainer = Trainer(accelerator='horovod', gpus=1) + + # train Horovod on CPU (number of processes / machines provided on command-line) + trainer = Trainer(accelerator='horovod') + +When starting the training job, the driver application will then be used to specify the total +number of worker processes: + +.. code-block:: bash + + # run training with 4 GPUs on a single machine + horovodrun -np 4 python train.py + + # run training with 8 GPUs on two machines (4 GPUs each) + horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py + +See the official `Horovod documentation `_ for details +on installation and performance tuning. + +DP/DDP2 caveats +^^^^^^^^^^^^^^^ +In DP and DDP2 each GPU within a machine sees a portion of a batch. +DP and ddp2 roughly do the following: + +.. testcode:: + + def distributed_forward(batch, model): + batch = torch.Tensor(32, 8) + gpu_0_batch = batch[:8] + gpu_1_batch = batch[8:16] + gpu_2_batch = batch[16:24] + gpu_3_batch = batch[24:] + + y_0 = model_copy_gpu_0(gpu_0_batch) + y_1 = model_copy_gpu_1(gpu_1_batch) + y_2 = model_copy_gpu_2(gpu_2_batch) + y_3 = model_copy_gpu_3(gpu_3_batch) + + return [y_0, y_1, y_2, y_3] + +So, when Lightning calls any of the `training_step`, `validation_step`, `test_step` +you will only be operating on one of those pieces. + +.. testcode:: + + # the batch here is a portion of the FULL batch + def training_step(self, batch, batch_idx): + y_0 = batch + +For most metrics, this doesn't really matter. However, if you want +to add something to your computational graph (like softmax) +using all batch parts you can use the `training_step_end` step. + +.. testcode:: + + def training_step_end(self, outputs): + # only use when on dp + outputs = torch.cat(outputs, dim=1) + softmax = softmax(outputs, dim=1) + out = softmax.mean() + return out + +In pseudocode, the full sequence is: + +.. code-block:: python + + # get data + batch = next(dataloader) + + # copy model and data to each gpu + batch_splits = split_batch(batch, num_gpus) + models = copy_model_to_gpus(model) + + # in parallel, operate on each batch chunk + all_results = [] + for gpu_num in gpus: + batch_split = batch_splits[gpu_num] + gpu_model = models[gpu_num] + out = gpu_model(batch_split) + all_results.append(out) + + # use the full batch for something like softmax + full out = model.training_step_end(all_results) + +To illustrate why this is needed, let's look at DataParallel + +.. testcode:: + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(batch) + + # on dp or ddp2 if we did softmax now it would be wrong + # because batch is actually a piece of the full batch + return y_hat + + def training_step_end(self, batch_parts_outputs): + # batch_parts_outputs has outputs of each part of the batch + + # do softmax here + outputs = torch.cat(outputs, dim=1) + softmax = softmax(outputs, dim=1) + out = softmax.mean() + + return out + +If `training_step_end` is defined it will be called regardless of TPU, DP, DDP, etc... which means +it will behave the same regardless of the backend. + +Validation and test step have the same option when using DP. + +.. testcode:: + + def validation_step_end(self, batch_parts_outputs): + ... + + def test_step_end(self, batch_parts_outputs): + ... + + +Distributed and 16-bit precision +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Due to an issue with Apex and DataParallel (PyTorch and NVIDIA issue), Lightning does +not allow 16-bit and DP training. We tried to get this to work, but it's an issue on their end. + +Below are the possible configurations we support. + ++-------+---------+----+-----+--------+------------------------------------------------------------+ +| 1 GPU | 1+ GPUs | DP | DDP | 16-bit | command | ++=======+=========+====+=====+========+============================================================+ +| Y | | | | | `Trainer(gpus=1)` | ++-------+---------+----+-----+--------+------------------------------------------------------------+ +| Y | | | | Y | `Trainer(gpus=1, precision=16)` | ++-------+---------+----+-----+--------+------------------------------------------------------------+ +| | Y | Y | | | `Trainer(gpus=k, accelerator='dp')` | ++-------+---------+----+-----+--------+------------------------------------------------------------+ +| | Y | | Y | | `Trainer(gpus=k, accelerator='ddp')` | ++-------+---------+----+-----+--------+------------------------------------------------------------+ +| | Y | | Y | Y | `Trainer(gpus=k, accelerator='ddp', precision=16)` | ++-------+---------+----+-----+--------+------------------------------------------------------------+ + + +Implement Your Own Distributed (DDP) training +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +If you need your own way to init PyTorch DDP you can override :meth:`pytorch_lightning.plugins.training_type.ddp.DDPPlugin.init_ddp_connection`. + +If you also need to use your own DDP implementation, override :meth:`pytorch_lightning.plugins.training_type.ddp.DDPPlugin.configure_ddp`. + + +---------- + +.. _model-parallelism: + +Model Parallelism [BETA] +------------------------ + +Model Parallelism tackles training large models on distributed systems, by modifying distributed communications and memory management of the model. +Unlike data parallelism, the model is partitioned in various ways across the GPUs, in most cases to reduce the memory overhead when training large models. +This is useful when dealing with large Transformer based models, or in environments where GPU memory is limited. + +Lightning currently offers the following methods to leverage model parallelism: + +- Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**) +- Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential ` module across multiple GPUs, leverage checkpointing and microbatching for further memory improvements and device utilization) + +.. _sharded: + +Sharded Training +^^^^^^^^^^^^^^^^ +Lightning integration of optimizer sharded training provided by `FairScale `_. +The technique can be found within `DeepSpeed ZeRO `_ and +`ZeRO-2 `_, +however the implementation is built from the ground up to be pytorch compatible and standalone. +Sharded Training allows you to maintain GPU scaling efficiency, whilst reducing memory overhead drastically. In short, expect normal linear scaling, and significantly reduced memory usage when training large models. + +Sharded Training still utilizes Data Parallel Training under the hood, except optimizer states and gradients are sharded across GPUs. +This means the memory overhead per GPU is lower, as each GPU only has to maintain a partition of your optimizer state and gradients. + +The benefits vary by model and parameter sizes, but we've recorded up to a 63% memory reduction per GPU allowing us to double our model sizes. Because of extremely efficient communication, +these benefits in multi-GPU setups are almost free and throughput scales well with multi-node setups. + +Below we use the `NeMo Transformer Lightning Language Modeling example `_ to benchmark the maximum batch size and model size that can be fit on 8 A100 GPUs for DDP vs Sharded Training. +Note that the benefits can still be obtained using 2 or more GPUs, and for even larger batch sizes you can scale to multiple nodes. + +**Increase Your Batch Size** + +Use Sharded Training to scale your batch size further using the same compute. This will reduce your overall epoch time. + ++----------------------+-----------------------+----------------+---------------------+ +| Distributed Training | Model Size (Millions) | Max Batch Size | Percentage Gain (%) | ++======================+=======================+================+=====================+ +| Native DDP | 930 | 32 | - | ++----------------------+-----------------------+----------------+---------------------+ +| Sharded DDP | 930 | **52** | **48%** | ++----------------------+-----------------------+----------------+---------------------+ + +**Increase Your Model Size** + +Use Sharded Training to scale your model size further using the same compute. + ++----------------------+------------+---------------------------+---------------------+ +| Distributed Training | Batch Size | Max Model Size (Millions) | Percentage Gain (%) | ++======================+============+===========================+=====================+ +| Native DDP | 32 | 930 | - | ++----------------------+------------+---------------------------+---------------------+ +| Sharded DDP | 32 | **1404** | **41%** | ++----------------------+------------+---------------------------+---------------------+ +| Native DDP | 8 | 1572 | - | ++----------------------+------------+---------------------------+---------------------+ +| Sharded DDP | 8 | **2872** | **59%** | ++----------------------+------------+---------------------------+---------------------+ + +It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models are beneficial (500M+ parameter models). +A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful. +Work within the future will bring optional sharding to activations and model parameters to reduce memory further, but come with a speed cost. + +To use Sharded Training, you need to first install FairScale using the command below. + +.. code-block:: bash + + pip install fairscale + + +.. code-block:: python + + # train using Sharded DDP + trainer = Trainer(accelerator='ddp', plugins='ddp_sharded') + +Sharded Training can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag. + +Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required. + +---------- + +.. _deep_speed: + +DeepSpeed +^^^^^^^^^ + +.. note:: + The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue `_ if you run into any issues. + +`DeepSpeed `_ is a deep learning training optimization library, providing the means to train massive billion parameter models at scale. +Using the DeepSpeed plugin, we were able to **train model sizes of 10 Billion parameters and above**, with a lot of useful information in this `benchmark `_ and the DeepSpeed `docs `_. +DeepSpeed also offers lower level training optimizations, and efficient optimizers such as `1-bit Adam `_. We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models). + +To use DeepSpeed, you first need to install DeepSpeed using the commands below. + +.. code-block:: bash + + pip install deepspeed + +If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). + +.. note:: + Currently ``resume_from_checkpoint`` and manual optimization are not supported. + + DeepSpeed currently only supports single optimizer, single scheduler within the training loop. + +DeepSpeed ZeRO Stage 2 +"""""""""""""""""""""" + +By default, we enable `DeepSpeed ZeRO Stage 2 `_, which partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team. +As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the plugin described in a few sections below. + +.. note:: + To use ZeRO, you must use ``precision=16``. + +.. code-block:: python + + from pytorch_lightning import Trainer + + model = MyModel() + trainer = Trainer(gpus=4, plugins='deepspeed', precision=16) + trainer.fit(model) + + +DeepSpeed ZeRO Stage 2 Offload +"""""""""""""""""""""""""""""" + +Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. + +.. note:: + To use ZeRO-Offload, you must use ``precision=16``. + +.. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16) + trainer.fit(model) + + +This can also be done via the command line using a Pytorch Lightning script: + +.. code-block:: bash + + python train.py --plugins deepspeed --precision 16 --gpus 4 + + +You can also modify the ZeRO-Offload parameters via the plugin as below. + +.. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16) + trainer.fit(model) + + +.. note:: + We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters based on your model size. + These control how large a buffer we limit the model to using when reducing gradients/gathering updated parameters. Smaller values will result in less memory, but tradeoff with speed. + + DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameters. + + The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``. + +For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM called `DeepSpeedCPUAdam `_ to run the offloaded computation, which is faster than the standard PyTorch implementation. + +.. code-block:: python + + import pytorch_lightning + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + from deepspeed.ops.adam import DeepSpeedCPUAdam + + class MyModel(pl.LightningModule): + ... + def configure_optimizers(self): + # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) + return DeepSpeedCPUAdam(self.parameters()) + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16) + trainer.fit(model) + + +Custom DeepSpeed Config +""""""""""""""""""""""" + +In some cases you may want to define your own DeepSpeed Config, to access all parameters defined. We've exposed most of the important parameters, however, there may be debugging parameters to enable. Also, DeepSpeed allows the use of custom DeepSpeed optimizers and schedulers defined within a config file that is supported. + +.. note:: + All plugin default parameters will be ignored when a config object is passed. + All compatible arguments can be seen in the `DeepSpeed docs `_. + +.. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + + deepspeed_config = { + "zero_allow_untested_optimizer": True, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 3e-5, + "betas": [0.998, 0.999], + "eps": 1e-5, + "weight_decay": 1e-9, + "cuda_aware": True, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + }, + "zero_optimization": { + "stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning) + "cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU + "contiguous_gradients": True, # Reduce gradient fragmentation. + "overlap_comm": True, # Overlap reduce/backward operation of gradients for speed. + "allgather_bucket_size": 2e8, # Number of elements to all gather at once. + "reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once. + } + } + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16) + trainer.fit(model) + + +We support taking the config as a json formatted file: + +.. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16) + trainer.fit(model) + + +You can use also use an environment variable via your PyTorch Lightning script: + +.. code-block:: bash + + PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed + + +---------- + +.. _sequential-parallelism: + +Sequential Model Parallelism with Checkpointing +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +PyTorch Lightning integration for Sequential Model Parallelism using `FairScale `_. +Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially. +We also provide auto-balancing techniques through FairScale, to find optimal balances for the model across GPUs. +In addition, we use Gradient Checkpointing to reduce GPU memory requirements further, and micro-batches to minimizing device under-utilization automatically. + +Reference: https://arxiv.org/abs/1811.06965 + +.. note:: RPCSequentialPlugin is currently supported only for Pytorch 1.6. + +To get started, install FairScale using the command below. We install a specific branch which contains PyTorch related fixes for Sequential Parallelism. + +.. code-block:: bash + + pip install https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip + +To use Sequential Model Parallelism, you must define a :class:`nn.Sequential ` module that defines the layers you wish to parallelize across GPUs. +This should be kept within the ``sequential_module`` variable within your ``LightningModule`` like below. + +.. code-block:: python + + from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin + from pytorch_lightning import LightningModule + + class MyModel(LightningModule): + def __init__(self): + ... + self.sequential_module = nn.Sequential(my_layers) + + # Split my module across 4 gpus, one layer each + model = MyModel() + plugin = RPCSequentialPlugin(balance=[1, 1, 1, 1]) + trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin]) + trainer.fit(model) + + +We provide a minimal example of Sequential Model Parallelism using a convolutional model training on cifar10, split onto GPUs `here `_. +To run the example, you need to install `Bolts `_. Install with ``pip install pytorch-lightning-bolts``. + +When running the Sequential Model Parallelism example on 2 GPUS we achieve these memory savings. + +.. list-table:: GPU Memory Utilization + :widths: 25 25 50 + :header-rows: 1 + + * - GPUS + - Without Balancing + - With Balancing + * - Gpu 0 + - 4436 MB + - 1554 MB + * - Gpu 1 + - ~0 + - 994 MB + +To run the example with Sequential Model Parallelism: + +.. code-block:: bash + + python pl_examples/basic_examples/conv_sequential_example.py --batch_size 1024 --gpus 2 --accelerator ddp --use_ddp_sequential + +To run the same example without Sequential Model Parallelism: + +.. code-block:: bash + + python pl_examples/basic_examples/conv_sequential_example.py --batch_size 1024 --gpus 1 + + +Batch size +---------- +When using distributed training make sure to modify your learning rate according to your effective +batch size. + +Let's say you have a batch size of 7 in your dataloader. + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + return Dataset(..., batch_size=7) + +In (DDP, Horovod) your effective batch size will be 7 * gpus * num_nodes. + +.. code-block:: python + + # effective batch size = 7 * 8 + Trainer(gpus=8, accelerator='ddp|horovod') + + # effective batch size = 7 * 8 * 10 + Trainer(gpus=8, num_nodes=10, accelerator='ddp|horovod') + + +In DDP2, your effective batch size will be 7 * num_nodes. +The reason is that the full batch is visible to all GPUs on the node when using DDP2. + +.. code-block:: python + + # effective batch size = 7 + Trainer(gpus=8, accelerator='ddp2') + + # effective batch size = 7 * 10 + Trainer(gpus=8, num_nodes=10, accelerator='ddp2') + + +.. note:: Huge batch sizes are actually really bad for convergence. Check out: + `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour `_ + +---------- + +TorchElastic +-------------- +Lightning supports the use of TorchElastic to enable fault-tolerant and elastic distributed job scheduling. To use it, specify the 'ddp' or 'ddp2' backend and the number of gpus you want to use in the trainer. + +.. code-block:: python + + Trainer(gpus=8, accelerator='ddp') + + +Following the `TorchElastic Quickstart documentation `_, you then need to start a single-node etcd server on one of the hosts: + +.. code-block:: bash + + etcd --enable-v2 + --listen-client-urls http://0.0.0.0:2379,http://127.0.0.1:4001 + --advertise-client-urls PUBLIC_HOSTNAME:2379 + + +And then launch the elastic job with: + +.. code-block:: bash + + python -m torchelastic.distributed.launch + --nnodes=MIN_SIZE:MAX_SIZE + --nproc_per_node=TRAINERS_PER_NODE + --rdzv_id=JOB_ID + --rdzv_backend=etcd + --rdzv_endpoint=ETCD_HOST:ETCD_PORT + YOUR_LIGHTNING_TRAINING_SCRIPT.py (--arg1 ... train script args...) + + +See the official `TorchElastic documentation `_ for details +on installation and more use cases. + +---------- + +Jupyter Notebooks +----------------- +Unfortunately any `ddp_` is not supported in jupyter notebooks. Please use `dp` for multiple GPUs. This is a known +Jupyter issue. If you feel like taking a stab at adding this support, feel free to submit a PR! + +---------- + +Pickle Errors +-------------- +Multi-GPU training sometimes requires your model to be pickled. If you run into an issue with pickling +try the following to figure out the issue + +.. code-block:: python + + import pickle + + model = YourModel() + pickle.dumps(model) + +However, if you use `ddp` the pickling requirement is not there and you should be fine. If you use `ddp_spawn` the +pickling requirement remains. This is a limitation of Python. diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst new file mode 100644 index 00000000000000..1a82641953c3ce --- /dev/null +++ b/docs/source/advanced/multiple_loaders.rst @@ -0,0 +1,179 @@ +.. testsetup:: * + + from pytorch_lightning.core.lightning import LightningModule + +.. _multiple_loaders: + +Multiple Datasets +================= +Lightning supports multiple dataloaders in a few ways. + +1. Create a dataloader that iterates multiple datasets under the hood. +2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning + will automatically combine the batches from different loaders. +3. In the validation and test loop you also have the option to return multiple dataloaders + which lightning will call sequentially. + +---------- + +.. _multiple-training-dataloaders: + +Multiple training dataloaders +----------------------------- +For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class +which wraps your multiple dataloaders (this of course also works for testing and validation +dataloaders). + +(`reference `_) + +.. testcode:: + + class ConcatDataset(torch.utils.data.Dataset): + def __init__(self, *datasets): + self.datasets = datasets + + def __getitem__(self, i): + return tuple(d[i] for d in self.datasets) + + def __len__(self): + return min(len(d) for d in self.datasets) + + class LitModel(LightningModule): + + def train_dataloader(self): + concat_dataset = ConcatDataset( + datasets.ImageFolder(traindir_A), + datasets.ImageFolder(traindir_B) + ) + + loader = torch.utils.data.DataLoader( + concat_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True + ) + return loader + + def val_dataloader(self): + # SAME + ... + + def test_dataloader(self): + # SAME + ... + +However, with lightning you can also return multiple loaders and lightning will take care of batch combination. + +For more details please have a look at :paramref:`~pytorch_lightning.trainer.trainer.Trainer.multiple_trainloader_mode` + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(6), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(15), batch_size=5) + + # pass loaders as a dict. This will create batches like this: + # {'a': batch from loader_a, 'b': batch from loader_b} + loaders = {'a': loader_a, + 'b': loader_b} + + # OR: + # pass loaders as sequence. This will create batches like this: + # [batch from loader_a, batch from loader_b] + loaders = [loader_a, loader_b] + + return loaders + +Furthermore, Lightning also supports that nested lists and dicts (or a combination) can +be returned. + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(16), batch_size=2) + + return {'a': loader_a, 'b': loader_b} + + def training_step(self, batch, batch_idx): + # access a dictionnary with a batch from each dataloader + batch_a = batch["a"] + batch_b = batch["b"] + + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(16), batch_size=4) + loader_c = torch.utils.data.DataLoader(range(32), batch_size=4) + loader_c = torch.utils.data.DataLoader(range(64), batch_size=4) + + # pass loaders as a nested dict. This will create batches like this: + loaders = { + 'loaders_a_b': { + 'a': loader_a, + 'b': loader_b + }, + 'loaders_c_d': { + 'c': loader_c, + 'd': loader_d + } + } + return loaders + + def training_step(self, batch, batch_idx): + # access the data + batch_a_b = batch["loaders_a_b"] + batch_c_d = batch["loaders_c_d"] + + batch_a = batch_a_b["a"] + batch_b = batch_a_b["a"] + + batch_c = batch_c_d["c"] + batch_d = batch_c_d["d"] + +---------- + +Test/Val dataloaders +-------------------- +For validation and test dataloaders, lightning also gives you the additional +option of passing multiple dataloaders back from each call. You can choose to pass +the batches sequentially or simultaneously, as is done for the training step. +The default mode for validation and test dataloaders is sequential. + +See the following for more details for the default sequential option: + +- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader` +- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader` + +.. testcode:: + + def val_dataloader(self): + loader_1 = Dataloader() + loader_2 = Dataloader() + return [loader_1, loader_2] + +To combine batches of multiple test and validation dataloaders simultaneously, one +needs to wrap the dataloaders with `CombinedLoader`. + +.. testcode:: + + from pytorch_lightning.trainer.supporters import CombinedLoader + + def val_dataloader(self): + loader_1 = Dataloader() + loader_2 = Dataloader() + loaders = {'a': loader_a,'b': loader_b} + combined_loaders = CombinedLoader(loaders, "max_size_cycle") + return combined_loaders diff --git a/docs/source/profiler.rst b/docs/source/advanced/profiler.rst similarity index 94% rename from docs/source/profiler.rst rename to docs/source/advanced/profiler.rst index 115aaf27597563..26ad4156551c00 100644 --- a/docs/source/profiler.rst +++ b/docs/source/advanced/profiler.rst @@ -1,6 +1,7 @@ .. role:: hidden :class: hidden-section +.. _profiler: Performance and Bottleneck Profiler =================================== diff --git a/docs/source/advanced/pruning_quantization.rst b/docs/source/advanced/pruning_quantization.rst new file mode 100644 index 00000000000000..cd3ae2065db767 --- /dev/null +++ b/docs/source/advanced/pruning_quantization.rst @@ -0,0 +1,119 @@ +.. testsetup:: * + + import os + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule + +.. _pruning_quantization: + +######################## +Pruning and Quantization +######################## + +Pruning and Quantization are techniques to compress model size for deployment, allowing inference speed up and energy saving without significant accuracy losses. + +******* +Pruning +******* + +.. warning:: + + Pruning is in beta and subject to change. + +Pruning is a technique which focuses on eliminating some of the model weights to reduce the model size and decrease inference requirements. + +Pruning has been shown to achieve significant efficiency improvements while minimizing the drop in model performance (prediction quality). Model pruning is recommended for cloud endpoints, deploying models on edge devices, or mobile inference (among others). + +To enable pruning during training in Lightning, simply pass in the :class:`~pytorch_lightning.callbacks.ModelPruning` callback to the Lightning Trainer. PyTorch's native pruning implementation is used under the hood. + +This callback supports multiple pruning functions: pass any `torch.nn.utils.prune `_ function as a string to select which weights to prune (`random_unstructured `_, `RandomStructured `_, etc) or implement your own by subclassing `BasePruningMethod `_. + +.. code-block:: python + + from pytorch_lightning.callbacks import ModelPruning + + # set the amount to be the fraction of parameters to prune + trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)]) + +You can also perform iterative pruning, apply the `lottery ticket hypothesis `__, and more! + +.. code-block:: python + + def compute_amount(epoch): + # the sum of all returned values need to be smaller than 1 + if epoch == 10: + return 0.5 + + elif epoch == 50: + return 0.25 + + elif 75 < epoch < 99 : + return 0.01 + + # the amount can be also be a callable + trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=compute_amount)]) + + +************ +Quantization +************ + +.. warning :: + Quantization is in beta and subject to change. + +Model quantization is another performance optimization technique that allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating-point precision. This is particularly beneficial during model deployment. + +Quantization Aware Training (QAT) mimics the effects of quantization during training: The computations are carried-out in floating-point precision but the subsequent quantization effect is taken into account. The weights and activations are quantized into lower precision only for inference, when training is completed. + +Quantization is useful when it is required to serve large models on machines with limited memory, or when there's a need to switch between models and reducing the I/O time is important. For example, switching between monolingual speech recognition models across multiple languages. + +Lightning includes :class:`~pytorch_lightning.callbacks.QuantizationAwareTraining` callback (using PyTorch's native quantization, read more `here `__), which allows creating fully quantized models (compatible with torchscript). + +.. code-block:: python + + from pytorch_lightning.callbacks import QuantizationAwareTraining + + class RegressionModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer_0 = nn.Linear(16, 64) + self.layer_0a = torch.nn.ReLU() + self.layer_1 = nn.Linear(64, 64) + self.layer_1a = torch.nn.ReLU() + self.layer_end = nn.Linear(64, 1) + + def forward(self, x): + x = self.layer_0(x) + x = self.layer_0a(x) + x = self.layer_1(x) + x = self.layer_1a(x) + x = self.layer_end(x) + return x + + trainer = Trainer(callbacks=[QuantizationAwareTraining()]) + qmodel = RegressionModel() + trainer.fit(qmodel, ...) + + batch = iter(my_dataloader()).next() + qmodel(qmodel.quant(batch[0])) + + tsmodel = qmodel.to_torchscript() + tsmodel(tsmodel.quant(batch[0])) + +You can further customize the callback: + +.. code-block:: python + + + qcb = QuantizationAwareTraining( + # specification of quant estimation quality + observer_type='histogram', + # specify which layers shall be merged together to increase efficiency + modules_to_fuse=[(f'layer_{i}', f'layer_{i}a') for i in range(2)] + # make your model compatible with all original input/outputs, in such case the model is wrapped in a shell with entry/exit layers. + input_compatible=True + ) + + batch = iter(my_dataloader()).next() + qmodel(batch[0]) diff --git a/docs/source/sequences.rst b/docs/source/advanced/sequences.rst similarity index 64% rename from docs/source/sequences.rst rename to docs/source/advanced/sequences.rst index 857fd08198de85..759a671cc42efe 100644 --- a/docs/source/sequences.rst +++ b/docs/source/advanced/sequences.rst @@ -3,16 +3,19 @@ from torch.utils.data import IterableDataset from pytorch_lightning.trainer.trainer import Trainer +.. _sequences: + Sequential Data ================ Lightning has built in support for dealing with sequential data. +---------- Packed sequences as inputs ----------------------------- +-------------------------- When using PackedSequence, do 2 things: -1. return either a padded tensor in dataset or a list of variable length tensors in the dataloader collate_fn (example above shows the list implementation). +1. Return either a padded tensor in dataset or a list of variable length tensors in the dataloader collate_fn (example shows the list implementation). 2. Pack the sequence in forward or training and validation steps depending on use case. .. testcode:: @@ -28,8 +31,10 @@ When using PackedSequence, do 2 things: x = rnn.pack_sequence(batch[0], enforce_sorted=False) y = rnn.pack_sequence(batch[1], enforce_sorted=False) +---------- + Truncated Backpropagation Through Time ---------------------------------------- +-------------------------------------- There are times when multiple backwards passes are needed for each batch. For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs. @@ -46,18 +51,23 @@ Lightning can handle TBTT automatically via this flag. .. note:: If you need to modify how the batch is split, override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`. -.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include - a `hiddens` arg. +.. note:: Using this feature requires updating your LightningModule's + :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg. + +---------- Iterable Datasets ---------------------------------------- +----------------- Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural option when using sequential data. -.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int - (specifying the number of training batches to run before validation) when initializing the Trainer. - This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate - the validation interval when val_check_interval is less than one. +.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int + (specifying the number of training batches to run before validation) when initializing the Trainer. This is + because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation + interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or + an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches`` + to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception. + Here mode can be train/val/test. .. testcode:: @@ -82,3 +92,9 @@ option when using sequential data. # Set val_check_interval trainer = Trainer(val_check_interval=100) + + # Set limit_val_batches to 0.0 or 0 + trainer = Trainer(limit_val_batches=0.0) + + # Set limit_val_batches as an int + trainer = Trainer(limit_val_batches=100) diff --git a/docs/source/advanced/tpu.rst b/docs/source/advanced/tpu.rst new file mode 100644 index 00000000000000..b9688ce425b5ff --- /dev/null +++ b/docs/source/advanced/tpu.rst @@ -0,0 +1,284 @@ +.. _tpu: + +TPU support +=========== + +.. raw:: html + + + +| + +Lightning supports running on TPUs. At this moment, TPUs are available +on Google Cloud (GCP), Google Colab and Kaggle Environments. For more information on TPUs +`watch this video `_. + +---------------- + +TPU Terminology +--------------- +A TPU is a Tensor processing unit. Each TPU has 8 cores where each +core is optimized for 128x128 matrix multiplies. In general, a single +TPU is about as fast as 5 V100 GPUs! + +A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores! +You can request a full pod from Google cloud or a "slice" which gives you +some subset of those 2048 cores. + +---------------- + +How to access TPUs +------------------ +To access TPUs, there are three main ways. + +1. Using Google Colab. +2. Using Google Cloud (GCP). +3. Using Kaggle. + +---------------- + +Kaggle TPUs +----------- +For starting Kaggle projects with TPUs, refer to this `kernel `_. + +--------- + +Colab TPUs +---------- +Colab is like a jupyter notebook with a free GPU or TPU +hosted on GCP. + +To get a TPU on colab, follow these steps: + +1. Go to `https://colab.research.google.com/ `_. + +2. Click "new notebook" (bottom right of pop-up). + +3. Click runtime > change runtime settings. Select Python 3, and hardware accelerator "TPU". + This will give you a TPU with 8 cores. + +4. Next, insert this code into the first cell and execute. + This will install the xla library that interfaces between PyTorch and the TPU. + + .. code-block:: + + !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py + !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev + +5. Once the above is done, install PyTorch Lightning (v 0.7.0+). + + .. code-block:: + + !pip install pytorch-lightning + +6. Then set up your LightningModule as normal. + +---------------- + +DistributedSamplers +------------------- +Lightning automatically inserts the correct samplers - no need to do this yourself! + +Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right +chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning + +.. note:: Don't add distributedSamplers. Lightning does this automatically + +If for some reason you still need to, this is how to construct the sampler +for TPU use + +.. code-block:: python + + import torch_xla.core.xla_model as xm + + def train_dataloader(self): + dataset = MNIST( + os.getcwd(), + train=True, + download=True, + transform=transforms.ToTensor() + ) + + # required for TPU support + sampler = None + if use_tpu: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=True + ) + + loader = DataLoader( + dataset, + sampler=sampler, + batch_size=32 + ) + + return loader + +Configure the number of TPU cores in the trainer. You can only choose 1 or 8. +To use a full TPU pod skip to the TPU pod section. + +.. code-block:: python + + import pytorch_lightning as pl + + my_model = MyLightningModule() + trainer = pl.Trainer(tpu_cores=8) + trainer.fit(my_model) + +That's it! Your model will train on all 8 TPU cores. + +---------------- + +TPU core training +----------------- + +Lightning supports training on a single TPU core or 8 TPU cores. + +The Trainer parameters ``tpu_cores`` defines how many TPU cores to train on (1 or 8) / Single TPU to train on [1]. + +For Single TPU training, Just pass the TPU core ID [1-8] in a list. + +Single TPU core training. Model will train on TPU core ID 5. + +.. code-block:: python + + trainer = pl.Trainer(tpu_cores=[5]) + +8 TPU cores training. Model will train on 8 TPU cores. + +.. code-block:: python + + trainer = pl.Trainer(tpu_cores=8) + +---------------- + +Distributed Backend with TPU +---------------------------- +The ``accelerator`` option used for GPUs does not apply to TPUs. +TPUs work in DDP mode by default (distributing over each core) + +---------------- + +TPU Pod +------- +To train on more than 8 cores, your code actually doesn't change! +All you need to do is submit the following command: + +.. code-block:: bash + + $ python -m torch_xla.distributed.xla_dist + --tpu=$TPU_POD_NAME + --conda-env=torch-xla-nightly + -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data + +See `this guide `_ +on how to set up the instance groups and VMs needed to run TPU Pods. + +---------------- + +16 bit precision +---------------- +Lightning also supports training in 16-bit precision with TPUs. +By default, TPU training will use 32-bit precision. To enable 16-bit, +set the 16-bit flag. + +.. code-block:: python + + import pytorch_lightning as pl + + my_model = MyLightningModule() + trainer = pl.Trainer(tpu_cores=8, precision=16) + trainer.fit(my_model) + +Under the hood the xla library will use the `bfloat16 type `_. + + +----------------- + +Weight Sharing/Tying +-------------------- +Weight Tying/Sharing is a technique where in the module weights are shared among two or more layers. +This is a common method to reduce memory consumption and is utilized in many State of the Art +architectures today. + +PyTorch XLA requires these weights to be tied/shared after moving the model +to the TPU device. To support this requirement Lightning provides a model hook which is +called after the model is moved to the device. Any weights that require to be tied should +be done in the `on_post_move_to_device` model hook. This will ensure that the weights +among the modules are shared and not copied. + +PyTorch Lightning has an inbuilt check which verifies that the model parameter lengths +match once the model is moved to the device. If the lengths do not match Lightning +throws a warning message. + +Example: + +.. code-block:: python + + from pytorch_lightning.core.lightning import LightningModule + from torch import nn + from pytorch_lightning.trainer.trainer import Trainer + + + class WeightSharingModule(LightningModule): + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(32, 10, bias=False) + self.layer_2 = nn.Linear(10, 32, bias=False) + self.layer_3 = nn.Linear(32, 10, bias=False) + # TPU shared weights are copied independently + # on the XLA device and this line won't have any effect. + # However, it works fine for CPU and GPU. + self.layer_3.weight = self.layer_1.weight + + def forward(self, x): + x = self.layer_1(x) + x = self.layer_2(x) + x = self.layer_3(x) + return x + + def on_post_move_to_device(self): + # Weights shared after the model has been moved to TPU Device + self.layer_3.weight = self.layer_1.weight + + + model = WeightSharingModule() + trainer = Trainer(max_epochs=1, tpu_cores=8) + +See `XLA Documentation `_ + +----------------------- + +Performance considerations +-------------------------- + +The TPU was designed for specific workloads and operations to carry out large volumes of matrix multiplication, +convolution operations and other commonly used ops in applied deep learning. +The specialization makes it a strong choice for NLP tasks, sequential convolutional networks, and under low precision operation. +There are cases in which training on TPUs is slower when compared with GPUs, for possible reasons listed: + +- Too small batch size. +- Explicit evaluation of tensors during training, e.g. ``tensor.item()`` +- Tensor shapes (e.g. model inputs) change often during training. +- Limited resources when using TPU's with PyTorch `Link `_ +- XLA Graph compilation during the initial steps `Reference `_ +- Some tensor ops are not fully supported on TPU, or not supported at all. These operations will be performed on CPU (context switch). +- PyTorch integration is still experimental. Some performance bottlenecks may simply be the result of unfinished implementation. + +The official PyTorch XLA `performance guide `_ +has more detailed information on how PyTorch code can be optimized for TPU. In particular, the +`metrics report `_ allows +one to identify operations that lead to context switching. + + +About XLA +---------- +XLA is the library that interfaces PyTorch with the TPUs. +For more information check out `XLA `_. + +Guide for `troubleshooting XLA `_ diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst new file mode 100644 index 00000000000000..4f7452c2da1de4 --- /dev/null +++ b/docs/source/advanced/training_tricks.rst @@ -0,0 +1,151 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + +.. _training_tricks: + +Training Tricks +================ +Lightning implements various tricks to help during training + +---------- + +Accumulate gradients +-------------------- +Accumulated gradients runs K small batches of size N before doing a backwards pass. +The effect is a large effective batch size of size KxN. + +.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` + +.. testcode:: + + # DEFAULT (ie: no accumulated grads) + trainer = Trainer(accumulate_grad_batches=1) + +---------- + +Gradient Clipping +----------------- +Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient +norm `_ computed over all model parameters together. + +.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` + +.. testcode:: + + # DEFAULT (ie: don't clip) + trainer = Trainer(gradient_clip_val=0) + + # clip gradients with norm above 0.5 + trainer = Trainer(gradient_clip_val=0.5) + +---------- + +Stochastic Weight Averaging +--------------------------- +Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost. +This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making +it harder to end up in a local minimum during optimization. + +For a more detailed explanation of SWA and how it works, +read `this `__ post by the PyTorch team. + +.. seealso:: :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` (Callback) + +.. testcode:: + + # Enable Stochastic Weight Averaging + trainer = Trainer(stochastic_weight_avg=True) + +---------- + +Auto scaling of batch size +-------------------------- +Auto scaling of batch size may be enabled to find the largest batch size that fits into +memory. Larger batch size often yields better estimates of gradients, but may also result in +longer training time. Inspired by https://github.com/BlackHC/toma. + +.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` + +.. code-block:: python + + # DEFAULT (ie: don't scale batch size automatically) + trainer = Trainer(auto_scale_batch_size=None) + + # Autoscale batch size + trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch') + + # find the batch size + trainer.tune(model) + +Currently, this feature supports two modes `'power'` scaling and `'binsearch'` +scaling. In `'power'` scaling, starting from a batch size of 1 keeps doubling +the batch size until an out-of-memory (OOM) error is encountered. Setting the +argument to `'binsearch'` will initially also try doubling the batch size until +it encounters an OOM, after which it will do a binary search that will finetune the +batch size. Additionally, it should be noted that the batch size scaler cannot +search for batch sizes larger than the size of the training dataset. + + +.. note:: + + This feature expects that a `batch_size` field is either located as a model attribute + i.e. `model.batch_size` or as a field in your `hparams` i.e. `model.hparams.batch_size`. + The field should exist and will be overridden by the results of this algorithm. + Additionally, your `train_dataloader()` method should depend on this field + for this feature to work i.e. + + .. code-block:: python + + def train_dataloader(self): + return DataLoader(train_dataset, batch_size=self.batch_size|self.hparams.batch_size) + +.. warning:: + + Due to these constraints, this features does *NOT* work when passing dataloaders directly + to `.fit()`. + +The scaling algorithm has a number of parameters that the user can control by +invoking the trainer method `.scale_batch_size` themself (see description below). + +.. code-block:: python + + # Use default in trainer construction + trainer = Trainer() + tuner = Tuner(trainer) + + # Invoke method + new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here) + + # Override old batch size + model.hparams.batch_size = new_batch_size + + # Fit as normal + trainer.fit(model) + +The algorithm in short works by: + 1. Dumping the current state of the model and trainer + 2. Iteratively until convergence or maximum number of tries `max_trials` (default 25) has been reached: + - Call `fit()` method of trainer. This evaluates `steps_per_trial` (default 3) number of + training steps. Each training step can trigger an OOM error if the tensors + (training batch, weights, gradients, etc.) allocated during the steps have a + too large memory footprint. + - If an OOM error is encountered, decrease batch size else increase it. + How much the batch size is increased/decreased is determined by the chosen + strategy. + 3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size` + 4. Restore the initial state of model and trainer + +.. autoclass:: pytorch_lightning.tuner.tuning.Tuner + :noindex: + :members: scale_batch_size + +.. warning:: Batch size finder is not supported for DDP yet, it is coming soon. + + +Sequential Model Parallelism with Checkpointing +--------------------------------------------------------------------- +PyTorch Lightning integration for Sequential Model Parallelism using `FairScale `_. +Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially. + +For more information, refer to :ref:`sequential-parallelism`. diff --git a/docs/source/transfer_learning.rst b/docs/source/advanced/transfer_learning.rst similarity index 86% rename from docs/source/transfer_learning.rst rename to docs/source/advanced/transfer_learning.rst index 35b7d661f07c42..72d16a9f2bf115 100644 --- a/docs/source/transfer_learning.rst +++ b/docs/source/advanced/transfer_learning.rst @@ -46,24 +46,28 @@ Example: Imagenet (computer Vision) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. testcode:: - :skipif: not TORCHVISION_AVAILABLE + :skipif: not _TORCHVISION_AVAILABLE import torchvision.models as models class ImagenetTransferLearning(LightningModule): def __init__(self): + super().__init__() + # init a pretrained resnet - num_target_classes = 10 - self.feature_extractor = models.resnet50( - pretrained=True, - num_classes=num_target_classes) - self.feature_extractor.eval() + backbone = models.resnet50(pretrained=True) + num_filters = backbone.fc.in_features + layers = list(backbone.children())[:-1] + self.feature_extractor = nn.Sequential(*layers) # use the pretrained model to classify cifar-10 (10 image classes) - self.classifier = nn.Linear(2048, num_target_classes) + num_target_classes = 10 + self.classifier = nn.Linear(num_filters, num_target_classes) def forward(self, x): - representations = self.feature_extractor(x) + self.feature_extractor.eval() + with torch.no_grad(): + representations = self.feature_extractor(x).flatten(1) x = self.classifier(representations) ... @@ -115,4 +119,4 @@ Here's a model that uses `Huggingface transformers 4.0 and < 7.0 - module load gcc-6.1.0 - - $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ - - -Enable 16-bit -^^^^^^^^^^^^^ - -.. testcode:: - - # turn on 16-bit - trainer = Trainer(amp_level='O1', precision=16) - -If you need to configure the apex init for your particular use case or want to use a different way of doing -16-bit training, override :meth:`pytorch_lightning.core.LightningModule.configure_apex`. - -TPU 16-bit ----------- -16-bit on TPus is much simpler. To use 16-bit with TPUs set precision to 16 when using the tpu flag - -.. testcode:: - - # DEFAULT - trainer = Trainer(num_tpu_cores=8, precision=32) - - # turn on 16-bit - trainer = Trainer(num_tpu_cores=8, precision=16) diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst new file mode 100644 index 00000000000000..e9520dea8045fe --- /dev/null +++ b/docs/source/api_references.rst @@ -0,0 +1,97 @@ +API References +============== + +Core API +-------- + +.. currentmodule:: pytorch_lightning.core + +.. autosummary:: + :toctree: api + :nosignatures: + + datamodule + decorators + hooks + lightning + +Callbacks API +------------- + +.. currentmodule:: pytorch_lightning.callbacks + +.. autosummary:: + :toctree: api + :nosignatures: + + base + early_stopping + gpu_stats_monitor + gradient_accumulation_scheduler + lr_monitor + model_checkpoint + progress + +Loggers API +----------- + +.. currentmodule:: pytorch_lightning.loggers + +.. autosummary:: + :toctree: api + :nosignatures: + + base + comet + csv_logs + mlflow + neptune + tensorboard + test_tube + wandb + +Profiler API +------------ + +.. currentmodule:: pytorch_lightning.profiler + +.. autosummary:: + :toctree: api + :nosignatures: + + profilers + +Trainer API +----------- + +.. currentmodule:: pytorch_lightning.trainer + +.. autosummary:: + :toctree: api + :nosignatures: + + trainer + +Tuner API +--------- + +.. currentmodule:: pytorch_lightning.tuner + +.. autosummary:: + :toctree: api + :nosignatures: + + batch_size_scaling + lr_finder + +Utilities API +------------- + +.. currentmodule:: pytorch_lightning.utilities + +.. autosummary:: + :toctree: api + :nosignatures: + + argparse_utils + seed diff --git a/docs/source/benchmarking/benchmarks.rst b/docs/source/benchmarking/benchmarks.rst new file mode 100644 index 00000000000000..f5a2e4e19b7fa8 --- /dev/null +++ b/docs/source/benchmarking/benchmarks.rst @@ -0,0 +1,14 @@ +Benchmark with vanilla PyTorch +============================== + +In this section we set grounds for comparison between vanilla PyTorch and PT Lightning for most common scenarios. + +Time comparison +--------------- + +We have set regular benchmarking against PyTorch vanilla training loop on with RNN and simple MNIST classifier as per of out CI. +In average for simple MNIST CNN classifier we are only about 0.06s slower per epoch, see detail chart bellow. + +.. figure:: ../_static/images/benchmarks/figure-parity-times.png + :alt: Speed parity to vanilla PT, created on 2020-12-16 + :width: 500 diff --git a/docs/source/benchmarking/performance.rst b/docs/source/benchmarking/performance.rst new file mode 100644 index 00000000000000..dbddaad3a5e3c1 --- /dev/null +++ b/docs/source/benchmarking/performance.rst @@ -0,0 +1,199 @@ +.. _performance: + +Fast performance tips +===================== +Lightning builds in all the micro-optimizations we can find to increase your performance. +But we can only automate so much. + +Here are some additional things you can do to increase your performance. + +---------- + +Dataloaders +----------- +When building your DataLoader set ``num_workers > 0`` and ``pin_memory=True`` (only for GPUs). + +.. code-block:: python + + Dataloader(dataset, num_workers=8, pin_memory=True) + +num_workers +^^^^^^^^^^^ +The question of how many ``num_workers`` is tricky. Here's a summary of +some references, [`1 `_], and our suggestions. + +1. ``num_workers=0`` means ONLY the main process will load batches (that can be a bottleneck). +2. ``num_workers=1`` means ONLY one worker (just not the main process) will load data but it will still be slow. +3. The ``num_workers`` depends on the batch size and your machine. +4. A general place to start is to set ``num_workers`` equal to the number of CPUs on that machine. + +.. warning:: Increasing ``num_workers`` will ALSO increase your CPU memory consumption. + +The best thing to do is to increase the ``num_workers`` slowly and stop once you see no more improvement in your training speed. + +Spawn +^^^^^ +When using ``accelerator=ddp_spawn`` (the ddp default) or TPU training, the way multiple GPUs/TPU cores are used is by calling ``.spawn()`` under the hood. +The problem is that PyTorch has issues with ``num_workers > 0`` when using ``.spawn()``. For this reason we recommend you +use ``accelerator=ddp`` so you can increase the ``num_workers``, however your script has to be callable like so: + +.. code-block:: bash + + python my_program.py --gpus X + +---------- + +.item(), .numpy(), .cpu() +------------------------- +Don't call ``.item()`` anywhere in your code. Use ``.detach()`` instead to remove the connected graph calls. Lightning +takes a great deal of care to be optimized for this. + +---------- + +empty_cache() +------------- +Don't call this unnecessarily! Every time you call this ALL your GPUs have to wait to sync. + +---------- + +Construct tensors directly on the device +---------------------------------------- +LightningModules know what device they are on! Construct tensors on the device directly to avoid CPU->Device transfer. + +.. code-block:: python + + # bad + t = torch.rand(2, 2).cuda() + + # good (self is LightningModule) + t = torch.rand(2, 2, device=self.device) + + +For tensors that need to be model attributes, it is best practice to register them as buffers in the modules's +``__init__`` method: + +.. code-block:: python + + # bad + self.t = torch.rand(2, 2, device=self.device) + + # good + self.register_buffer("t", torch.rand(2, 2)) + +---------- + +Use DDP not DP +-------------- +DP performs three GPU transfers for EVERY batch: + +1. Copy model to device. +2. Copy data to device. +3. Copy outputs of each device back to master. + +| + +Whereas DDP only performs 1 transfer to sync gradients. Because of this, DDP is MUCH faster than DP. + +When using DDP set find_unused_parameters=False +----------------------------------------------- + +By default we have enabled find unused parameters to True. This is for compatibility issues that have arisen in the past (see the `discussion `_ for more information). +This by default comes with a performance hit, and can be disabled in most cases. + +.. code-block:: python + + from pytorch_lightning.plugins import DDPPlugin + + trainer = pl.Trainer( + gpus=2, + plugins=DDPPlugin(find_unused_parameters=False), + ) + +---------- + +16-bit precision +---------------- +Use 16-bit to decrease the memory consumption (and thus increase your batch size). On certain GPUs (V100s, 2080tis), 16-bit calculations are also faster. +However, know that 16-bit and multi-processing (any DDP) can have issues. Here are some common problems. + +1. `CUDA error: an illegal memory access was encountered `_. + The solution is likely setting a specific CUDA, CUDNN, PyTorch version combination. +2. ``CUDA error: device-side assert triggered``. This is a general catch-all error. To see the actual error run your script like so: + +.. code-block:: bash + + # won't see what the error is + python main.py + + # will see what the error is + CUDA_LAUNCH_BLOCKING=1 python main.py + +.. tip:: We also recommend using 16-bit native found in PyTorch 1.6. Just install this version and Lightning will automatically use it. + +---------- + +Use Sharded DDP for GPU memory and scaling optimization +------------------------------------------------------- + +Sharded DDP is a lightning integration of `DeepSpeed ZeRO `_ and +`ZeRO-2 `_ +provided by `Fairscale `_. + +When training on multiple GPUs sharded DDP can assist to increase memory efficiency substantially, and in some cases performance on multi-node is better than traditional DDP. +This is due to efficient communication and parallelization under the hood. + +To use Optimizer Sharded Training, refer to :ref:`model-parallelism`. + +Sharded DDP can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag. + +Refer to the :doc:`distributed computing guide for more details <../advanced/multi_gpu>`. + + +Sequential Model Parallelism with Checkpointing +----------------------------------------------- +PyTorch Lightning integration for Sequential Model Parallelism using `FairScale `_. +Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially. + +For more information, refer to :ref:`sequential-parallelism`. + + +Preload Data Into RAM +--------------------- + +When your training or preprocessing requires many operations to be performed on entire dataset(s) it can +sometimes be beneficial to store all data in RAM given there is enough space. +However, loading all data at the beginning of the training script has the disadvantage that it can take a long +time and hence it slows down the development process. Another downside is that in multiprocessing (e.g. DDP) +the data would get copied in each process. +One can overcome these problems by copying the data into RAM in advance. +Most UNIX-based operating systems provide direct access to tmpfs through a mount point typically named ``/dev/shm``. + +0. Increase shared memory if necessary. Refer to the documentation of your OS how to do this. + +1. Copy training data to shared memory: + + .. code-block:: bash + + cp -r /path/to/data/on/disk /dev/shm/ + +2. Refer to the new data root in your script or command line arguments: + + .. code-block:: python + + datamodule = MyDataModule(data_root="/dev/shm/my_data") + + +Zero Grad ``set_to_none=True`` +------------------------------ + +In order to modestly improve performance, once can override :meth:`~pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad`. + +For a more detailed explanation of pros / cons of this technique, +read `this `_ documentation by the PyTorch team. + +.. testcode:: + + class Model(LightningModule): + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad(set_to_none=True) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst deleted file mode 100644 index 744c1f0c5edd67..00000000000000 --- a/docs/source/callbacks.rst +++ /dev/null @@ -1,101 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.callbacks.base import Callback - -.. role:: hidden - :class: hidden-section - -.. _callbacks: - -Callbacks -========= - -Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL -logic that is NOT required for your :class:`~pytorch_lightning.core.LightningModule` to run. - -An overall Lightning system should have: - -1. Trainer for all engineering -2. LightningModule for all research code. -3. Callbacks for non-essential code. - - -Example: - -.. testcode:: - - class MyPrintingCallback(Callback): - - def on_init_start(self, trainer): - print('Starting to init trainer!') - - def on_init_end(self, trainer): - print('trainer is init now') - - def on_train_end(self, trainer, pl_module): - print('do something when training ends') - - trainer = Trainer(callbacks=[MyPrintingCallback()]) - -.. testoutput:: - - Starting to init trainer! - trainer is init now - -We successfully extended functionality without polluting our super clean -:class:`~pytorch_lightning.core.LightningModule` research code. - ---------- - -.. automodule:: pytorch_lightning.callbacks.base - :noindex: - :exclude-members: - _del_model, - _save_model, - _abc_impl, - check_monitor_top_k, - ---------- - -.. automodule:: pytorch_lightning.callbacks.early_stopping - :noindex: - :exclude-members: - _del_model, - _save_model, - _abc_impl, - check_monitor_top_k, - ---------- - -.. automodule:: pytorch_lightning.callbacks.model_checkpoint - :noindex: - :exclude-members: - _del_model, - _save_model, - _abc_impl, - check_monitor_top_k, - ---------- - -.. automodule:: pytorch_lightning.callbacks.gradient_accumulation_scheduler - :noindex: - :exclude-members: - _del_model, - _save_model, - _abc_impl, - check_monitor_top_k, - ---------- - -.. automodule:: pytorch_lightning.callbacks.progress - :noindex: - :exclude-members: - ---------- - -.. automodule:: pytorch_lightning.callbacks.lr_logger - :noindex: - :exclude-members: - _extract_lr, - _find_names \ No newline at end of file diff --git a/docs/source/clouds/cloud_training.rst b/docs/source/clouds/cloud_training.rst new file mode 100644 index 00000000000000..127bee6478dfd6 --- /dev/null +++ b/docs/source/clouds/cloud_training.rst @@ -0,0 +1,29 @@ +################ +AWS/GCP training +################ +Lightning has a native solution for training on AWS/GCP at scale (Lightning-Grid). +Grid is in private early-access now but you can request access at `grid.ai `_. + +We've designed Grid to work for Lightning users without needing to make ANY changes to their code. + +To use grid, take your regular command: + +.. code-block:: bash + + python my_model.py --learning_rate 1e-6 --layers 2 --gpus 4 + +And change it to use the grid train command: + +.. code-block:: bash + + grid train --grid_gpus 4 my_model.py --learning_rate 'uniform(1e-6, 1e-1, 20)' --layers '[2, 4, 8, 16]' + +The above command will launch (20 * 4) experiments each running on 4 GPUs (320 GPUs!) - by making ZERO changes to +your code. + +The `uniform` command is part of our new expressive syntax which lets you construct hyperparameter combinations +using over 20+ distributions, lists, etc. Of course, you can also configure all of this using yamls which +can be dynamically assembled at runtime. + + +.. hint:: Grid supports the search strategy of your choice! (and much more than just sweeps) diff --git a/docs/source/clouds/slurm.rst b/docs/source/clouds/slurm.rst new file mode 100644 index 00000000000000..d482dc77ab456e --- /dev/null +++ b/docs/source/clouds/slurm.rst @@ -0,0 +1,208 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + +.. _slurm: + +Computing cluster (SLURM) +========================= + +Lightning automates the details behind training on a SLURM-powered cluster. + +.. _multi-node: + +---------- + +Multi-node training +------------------- +To train a model using multiple nodes, do the following: + +1. Design your :doc:`lightning module <../common/lightning_module>`. + +2. Enable DDP in the trainer + + .. code-block:: python + + # train on 32 GPUs across 4 nodes + trainer = Trainer(gpus=8, num_nodes=4, accelerator='ddp') + +3. It's a good idea to structure your training script like this: + + .. testcode:: + + # train.py + def main(hparams): + model = LightningTemplateModel(hparams) + + trainer = Trainer( + gpus=8, + num_nodes=4, + accelerator='ddp' + ) + + trainer.fit(model) + + + if __name__ == '__main__': + root_dir = os.path.dirname(os.path.realpath(__file__)) + parent_parser = ArgumentParser(add_help=False) + hyperparams = parser.parse_args() + + # TRAIN + main(hyperparams) + +4. Create the appropriate SLURM job: + + .. code-block:: bash + + # (submit.sh) + #!/bin/bash -l + + # SLURM SUBMIT SCRIPT + #SBATCH --nodes=4 + #SBATCH --gres=gpu:8 + #SBATCH --ntasks-per-node=8 + #SBATCH --mem=0 + #SBATCH --time=0-02:00:00 + + # activate conda env + source activate $1 + + # debugging flags (optional) + export NCCL_DEBUG=INFO + export PYTHONFAULTHANDLER=1 + + # on your cluster you might need these: + # set the network interface + # export NCCL_SOCKET_IFNAME=^docker0,lo + + # might need the latest CUDA + # module load NCCL/2.4.7-1-cuda.10.0 + + # run script from above + srun python3 train.py + +5. If you want auto-resubmit (read below), add this line to the submit.sh script + + .. code-block:: bash + + #SBATCH --signal=SIGUSR1@90 + +6. Submit the SLURM job + + .. code-block:: bash + + sbatch submit.sh + +.. note:: + When running in DDP mode, any errors in your code will show up as an NCCL issue. + Set the `NCCL_DEBUG=INFO` flag to see the ACTUAL error. + + +Normally now you would need to add a +:class:`~torch.utils.data.distributed.DistributedSampler` to your dataset, however +Lightning automates this for you. But if you still need to set a sampler set the Trainer flag +:paramref:`~pytorch_lightning.Trainer.replace_sampler_ddp` to ``False``. + +Here's an example of how to add your own sampler (again, not needed with Lightning). + +.. testcode:: + + # in your LightningModule + def train_dataloader(self): + dataset = MyDataset() + dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + dataloader = Dataloader(dataset, sampler=dist_sampler) + return dataloader + + # in your training script + trainer = Trainer(replace_sampler_ddp=False) + +---------- + +Wall time auto-resubmit +----------------------- +When you use Lightning in a SLURM cluster, it automatically detects when it is about +to run into the wall time and does the following: + +1. Saves a temporary checkpoint. +2. Requeues the job. +3. When the job starts, it loads the temporary checkpoint. + +To get this behavior make sure to add the correct signal to your SLURM script + +.. code-block:: bash + + # 90 seconds before training ends + SBATCH --signal=SIGUSR1@90 + +---------- + +Building SLURM scripts +---------------------- + +Instead of manually building SLURM scripts, you can use the +`SlurmCluster object `_ +to do this for you. The SlurmCluster can also run a grid search if you pass +in a `HyperOptArgumentParser +`_. + +Here is an example where you run a grid search of 9 combinations of hyperparameters. +See also the multi-node examples +`here `__. + +.. code-block:: python + + # grid search 3 values of learning rate and 3 values of number of layers for your net + # this generates 9 experiments (lr=1e-3, layers=16), (lr=1e-3, layers=32), + # (lr=1e-3, layers=64), ... (lr=1e-1, layers=64) + parser = HyperOptArgumentParser(strategy='grid_search', add_help=False) + parser.opt_list('--learning_rate', default=0.001, type=float, + options=[1e-3, 1e-2, 1e-1], tunable=True) + parser.opt_list('--layers', default=1, type=float, options=[16, 32, 64], tunable=True) + hyperparams = parser.parse_args() + + # Slurm cluster submits 9 jobs, each with a set of hyperparams + cluster = SlurmCluster( + hyperparam_optimizer=hyperparams, + log_path='/some/path/to/save', + ) + + # OPTIONAL FLAGS WHICH MAY BE CLUSTER DEPENDENT + # which interface your nodes use for communication + cluster.add_command('export NCCL_SOCKET_IFNAME=^docker0,lo') + + # see the output of the NCCL connection process + # NCCL is how the nodes talk to each other + cluster.add_command('export NCCL_DEBUG=INFO') + + # setting a master port here is a good idea. + cluster.add_command('export MASTER_PORT=%r' % PORT) + + # ************** DON'T FORGET THIS *************** + # MUST load the latest NCCL version + cluster.load_modules(['NCCL/2.4.7-1-cuda.10.0']) + + # configure cluster + cluster.per_experiment_nb_nodes = 12 + cluster.per_experiment_nb_gpus = 8 + + cluster.add_slurm_cmd(cmd='ntasks-per-node', value=8, comment='1 task per gpu') + + # submit a script with 9 combinations of hyper params + # (lr=1e-3, layers=16), (lr=1e-3, layers=32), (lr=1e-3, layers=64), ... (lr=1e-1, layers=64) + cluster.optimize_parallel_cluster_gpu( + main, + nb_trials=9, # how many permutations of the grid search to run + job_name='name_for_squeue' + ) + + +The other option is that you generate scripts on your own via a bash command or use another library. + +---------- + +Self-balancing architecture (COMING SOON) +----------------------------------------- + +Here Lightning distributes parts of your module across available GPUs to optimize for speed and memory. diff --git a/docs/source/child_modules.rst b/docs/source/common/child_modules.rst similarity index 68% rename from docs/source/child_modules.rst rename to docs/source/common/child_modules.rst index 4c2d60cc13246e..b0f3fdfc53e9d6 100644 --- a/docs/source/child_modules.rst +++ b/docs/source/common/child_modules.rst @@ -16,6 +16,8 @@ def val_dataloader(): pass + def test_dataloader(): + pass Child Modules ------------- @@ -23,8 +25,8 @@ Research projects tend to test different approaches to the same dataset. This is very easy to do in Lightning with inheritance. For example, imagine we now want to train an Autoencoder to use as a feature extractor for MNIST images. -Recall that `LitMNIST` already defines all the dataloading etc... The only things -that change in the `Autoencoder` model are the init, forward, training, validation and test step. +We are extending our Autoencoder from the `LitMNIST`-module which already defines all the dataloading. +The only things that change in the `Autoencoder` model are the init, forward, training, validation and test step. .. testcode:: @@ -40,32 +42,33 @@ that change in the `Autoencoder` model are the init, forward, training, validati super().__init__() self.encoder = Encoder() self.decoder = Decoder() + self.metric = MSE() def forward(self, x): - generated = self.decoder(x) + return self.encoder(x) def training_step(self, batch, batch_idx): x, _ = batch representation = self.encoder(x) - x_hat = self(representation) + x_hat = self.decoder(representation) - loss = MSE(x, x_hat) + loss = self.metric(x, x_hat) return loss def validation_step(self, batch, batch_idx): - return self._shared_eval(batch, batch_idx, 'val') + self._shared_eval(batch, batch_idx, 'val') def test_step(self, batch, batch_idx): - return self._shared_eval(batch, batch_idx, 'test') + self._shared_eval(batch, batch_idx, 'test') def _shared_eval(self, batch, batch_idx, prefix): - x, y = batch + x, _ = batch representation = self.encoder(x) - x_hat = self(representation) + x_hat = self.decoder(representation) - loss = F.nll_loss(logits, y) - return {f'{prefix}_loss': loss} + loss = self.metric(x, x_hat) + self.log(f'{prefix}_loss', loss) and we can train this using the same trainer @@ -76,7 +79,7 @@ and we can train this using the same trainer trainer = Trainer() trainer.fit(autoencoder) -And remember that the forward method is to define the practical use of a LightningModule. +And remember that the forward method should define the practical use of a LightningModule. In this case, we want to use the `AutoEncoder` to extract image representations .. code-block:: python diff --git a/docs/source/common/debugging.rst b/docs/source/common/debugging.rst new file mode 100644 index 00000000000000..f3faa72f1e95ec --- /dev/null +++ b/docs/source/common/debugging.rst @@ -0,0 +1,151 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + +.. _debugging: + +Debugging +========= + +.. raw:: html + + + +| + +The following are flags that make debugging much easier. + +---------------- + +fast_dev_run +------------ +This flag runs a "unit test" by running n if set to ``n`` (int) else 1 if set to ``True`` training and validation batch(es). +The point is to detect any bugs in the training/validation loop without having to wait for a full epoch to crash. + +(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run` +argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) + +.. testcode:: + + # runs 1 train, val, test batch and program ends + trainer = Trainer(fast_dev_run=True) + + # runs 7 train, val, test batches and program ends + trainer = Trainer(fast_dev_run=7) + +.. note:: + + This argument will disable tuner, checkpoint callbacks, early stopping callbacks, + loggers and logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. + +---------------- + +Inspect gradient norms +---------------------- +Logs (to a logger), the norm of each weight matrix. + +(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.track_grad_norm` +argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) + +.. testcode:: + + # the 2-norm + trainer = Trainer(track_grad_norm=2) + +---------------- + +Log GPU usage +------------- +Logs (to a logger) the GPU usage for each GPU on the master machine. + +(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.log_gpu_memory` +argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) + +.. testcode:: + + trainer = Trainer(log_gpu_memory=True) + +---------------- + +Make model overfit on subset of data +------------------------------------ + +A good debugging technique is to take a tiny portion of your data (say 2 samples per class), +and try to get your model to overfit. If it can't, it's a sign it won't work with large datasets. + +(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_batches` +argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) + +.. testcode:: + + # use only 1% of training data (and use the same training dataloader (with shuffle off) in val and test) + trainer = Trainer(overfit_batches=0.01) + + # similar, but with a fixed 10 batches no matter the size of the dataset + trainer = Trainer(overfit_batches=10) + +With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler +in the training set to turn off shuffle for you. + +---------------- + +Print a summary of your LightningModule +--------------------------------------- +Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule. +By default it only prints the top-level modules. If you want to show all submodules in your network, use the +`'full'` option: + +.. testcode:: + + trainer = Trainer(weights_summary='full') + +You can also display the intermediate input- and output sizes of all your layers by setting the +``example_input_array`` attribute in your LightningModule. It will print a table like this + +.. code-block:: text + + | Name | Type | Params | In sizes | Out sizes + -------------------------------------------------------------- + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] + 2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512] + +when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers. + +See Also: + - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument + - :class:`~pytorch_lightning.core.memory.ModelSummary` + +---------------- + +Shorten epochs +-------------- +Sometimes it's helpful to only use a percentage of your training, val or test data (or a set number of batches). +For example, you can use 20% of the training set and 1% of the validation set. + +On larger datasets like Imagenet, this can help you debug or test a few things faster than waiting for a full epoch. + +.. testcode:: + + # use only 10% of training data and 1% of val data + trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01) + + # use 10 batches of train and 5 batches of val + trainer = Trainer(limit_train_batches=10, limit_val_batches=5) + +---------------- + +Set the number of validation sanity steps +----------------------------------------- +Lightning runs a few steps of validation in the beginning of training. +This avoids crashing in the validation loop sometime deep into a lengthy training loop. + +(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.num_sanity_val_steps` +argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) + +.. testcode:: + + # DEFAULT + trainer = Trainer(num_sanity_val_steps=2) diff --git a/docs/source/common/early_stopping.rst b/docs/source/common/early_stopping.rst new file mode 100644 index 00000000000000..53bafbf116c81a --- /dev/null +++ b/docs/source/common/early_stopping.rst @@ -0,0 +1,99 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.callbacks.early_stopping import EarlyStopping + +.. _early_stopping: + +************** +Early stopping +************** + +.. raw:: html + + + +| + +Stopping an epoch early +======================= +You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met. + +If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire run. + +---------- + +Early stopping based on metric using the EarlyStopping Callback +=============================================================== +The +:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` +callback can be used to monitor a validation metric and stop the training when no improvement is observed. + +To enable it: + +- Import :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback. +- Log the metric you want to monitor using :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method. +- Init the callback, and set `monitor` to the logged metric of your choice. +- Pass the :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback to the :class:`~pytorch_lightning.trainer.trainer.Trainer` callbacks flag. + +.. code-block:: python + + from pytorch_lightning.callbacks.early_stopping import EarlyStopping + + def validation_step(...): + self.log('val_loss', loss) + + trainer = Trainer(callbacks=[EarlyStopping(monitor='val_loss')]) + +- You can customize the callbacks behaviour by changing its parameters. + +.. testcode:: + + early_stop_callback = EarlyStopping( + monitor='val_accuracy', + min_delta=0.00, + patience=3, + verbose=False, + mode='max' + ) + trainer = Trainer(callbacks=[early_stop_callback]) + +In case you need early stopping in a different part of training, subclass :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` +and change where it is called: + +.. testcode:: + + class MyEarlyStopping(EarlyStopping): + + def on_validation_end(self, trainer, pl_module): + # override this to disable early stopping at the end of val loop + pass + + def on_train_end(self, trainer, pl_module): + # instead, do it at the end of training loop + self._run_early_stopping_check(trainer, pl_module) + +.. note:: + The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback runs + at the end of every validation epoch, + which, under the default configuration, happen after every training epoch. + However, the frequency of validation can be modified by setting various parameters + in the :class:`~pytorch_lightning.trainer.trainer.Trainer`, + for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch` + and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`. + It must be noted that the `patience` parameter counts the number of + validation epochs with no improvement, and not the number of training epochs. + Therefore, with parameters `check_val_every_n_epoch=10` and `patience=3`, the trainer + will perform at least 40 training epochs before being stopped. + +.. seealso:: + - :class:`~pytorch_lightning.trainer.trainer.Trainer` + - :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` + +---------- + +.. seealso:: + - :class:`~pytorch_lightning.trainer.trainer.Trainer` + - :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` diff --git a/docs/source/fast_training.rst b/docs/source/common/fast_training.rst similarity index 60% rename from docs/source/fast_training.rst rename to docs/source/common/fast_training.rst index 208838f58b07c7..2216d234836f24 100644 --- a/docs/source/fast_training.rst +++ b/docs/source/common/fast_training.rst @@ -2,12 +2,15 @@ from pytorch_lightning.trainer.trainer import Trainer +.. _fast_training: Fast Training ============= There are multiple options to speed up different parts of the training by choosing to train on a subset of data. This could be done for speed or debugging purposes. +---------------- + Check validation every n epochs ------------------------------- If you have a small dataset you might want to check validation every n epochs @@ -17,6 +20,8 @@ If you have a small dataset you might want to check validation every n epochs # DEFAULT trainer = Trainer(check_val_every_n_epoch=1) +---------------- + Force training for min or max epochs ------------------------------------ It can be useful to force training for a minimum number of epochs or limit to a max number. @@ -29,12 +34,13 @@ It can be useful to force training for a minimum number of epochs or limit to a # DEFAULT trainer = Trainer(min_epochs=1, max_epochs=1000) +---------------- Set validation check frequency within 1 training epoch ------------------------------------------------------ For large datasets it's often desirable to check validation multiple times within a training loop. -Pass in a float to check that often within 1 training epoch. Pass in an int k to check every k training batches. -Must use an int if using an IterableDataset. +Pass in a float to check that often within 1 training epoch. Pass in an int `k` to check every `k` training batches. +Must use an `int` if using an `IterableDataset`. .. testcode:: @@ -44,29 +50,33 @@ Must use an int if using an IterableDataset. # check every .25 of an epoch trainer = Trainer(val_check_interval=0.25) - # check every 100 train batches (ie: for IterableDatasets or fixed frequency) + # check every 100 train batches (ie: for `IterableDatasets` or fixed frequency) trainer = Trainer(val_check_interval=100) -Use data subset for training, validation and test -------------------------------------------------- +---------------- + +Use data subset for training, validation, and test +-------------------------------------------------- If you don't want to check 100% of the training/validation/test set (for debugging or if it's huge), set these flags. .. testcode:: # DEFAULT trainer = Trainer( - train_percent_check=1.0, - val_percent_check=1.0, - test_percent_check=1.0 + limit_train_batches=1.0, + limit_val_batches=1.0, + limit_test_batches=1.0 ) # check 10%, 20%, 30% only, respectively for training, validation and test set trainer = Trainer( - train_percent_check=0.1, - val_percent_check=0.2, - test_percent_check=0.3 + limit_train_batches=0.1, + limit_val_batches=0.2, + limit_test_batches=0.3 ) -.. note:: ``train_percent_check``, ``val_percent_check`` and ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0. ``val_percent_check`` will be ignored if ``fast_dev_run=True``. +If you also pass ``shuffle=True`` to the dataloader, a different random subset of your dataset will be used for each epoch; otherwise the same subset will be used for all epochs. + +.. note:: ``limit_train_batches``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``. -.. note:: If you set ``val_percent_check=0``, validation will be disabled. +.. note:: If you set ``limit_val_batches=0``, validation will be disabled. diff --git a/docs/source/common/hyperparameters.rst b/docs/source/common/hyperparameters.rst new file mode 100644 index 00000000000000..83398c1d63388c --- /dev/null +++ b/docs/source/common/hyperparameters.rst @@ -0,0 +1,290 @@ +.. testsetup:: * + + import torch + from argparse import ArgumentParser, Namespace + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule + import sys + sys.argv = ['foo'] + +Hyperparameters +--------------- +Lightning has utilities to interact seamlessly with the command line ``ArgumentParser`` +and plays well with the hyperparameter optimization framework of your choice. + +---------- + +ArgumentParser +^^^^^^^^^^^^^^ +Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser + +.. testcode:: + + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('--layer_1_dim', type=int, default=128) + args = parser.parse_args() + +This allows you to call your program like so: + +.. code-block:: bash + + python trainer.py --layer_1_dim 64 + +---------- + +Argparser Best Practices +^^^^^^^^^^^^^^^^^^^^^^^^ +It is best practice to layer your arguments in three sections. + +1. Trainer args (``gpus``, ``num_nodes``, etc...) +2. Model specific arguments (``layer_dim``, ``num_layers``, ``learning_rate``, etc...) +3. Program arguments (``data_path``, ``cluster_email``, etc...) + +| + +We can do this as follows. First, in your ``LightningModule``, define the arguments +specific to that module. Remember that data splits or data paths may also be specific to +a module (i.e.: if your project has a model that trains on Imagenet and another on CIFAR-10). + +.. testcode:: + + class LitModel(LightningModule): + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("LitModel") + parser.add_argument('--encoder_layers', type=int, default=12) + parser.add_argument('--data_path', type=str, default='/some/path') + return parent_parser + +Now in your main trainer file, add the ``Trainer`` args, the program args, and add the model args + +.. testcode:: + + # ---------------- + # trainer_main.py + # ---------------- + from argparse import ArgumentParser + parser = ArgumentParser() + + # add PROGRAM level args + parser.add_argument('--conda_env', type=str, default='some_name') + parser.add_argument('--notification_email', type=str, default='will@email.com') + + # add model specific args + parser = LitModel.add_model_specific_args(parser) + + # add all the available trainer options to argparse + # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli + parser = Trainer.add_argparse_args(parser) + + args = parser.parse_args() + +Now you can call run your program like so: + +.. code-block:: bash + + python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12 + +Finally, make sure to start the training like so: + +.. code-block:: python + + # init the trainer like this + trainer = Trainer.from_argparse_args(args, early_stopping_callback=...) + + # NOT like this + trainer = Trainer(gpus=hparams.gpus, ...) + + # init the model with Namespace directly + model = LitModel(args) + + # or init the model with all the key-value pairs + dict_args = vars(args) + model = LitModel(**dict_args) + +---------- + +LightningModule hyperparameters +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Often times we train many versions of a model. You might share that model or come back to it a few months later +at which point it is very useful to know how that model was trained (i.e.: what learning rate, neural network, etc...). + +Lightning has a few ways of saving that information for you in checkpoints and yaml files. The goal here is to +improve readability and reproducibility. + +1. The first way is to ask lightning to save the values of anything in the __init__ for you to the checkpoint. This also + makes those values available via `self.hparams`. + + .. code-block:: python + + class LitMNIST(LightningModule): + + def __init__(self, layer_1_dim=128, learning_rate=1e-2, **kwargs): + super().__init__() + # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint + self.save_hyperparameters() + + # equivalent + self.save_hyperparameters('layer_1_dim', 'learning_rate') + + # Now possible to access layer_1_dim from hparams + self.hparams.layer_1_dim + + +2. Sometimes your init might have objects or other parameters you might not want to save. + In that case, choose only a few + + .. code-block:: python + + class LitMNIST(LightningModule): + + def __init__(self, loss_fx, generator_network, layer_1_dim=128 **kwargs): + super().__init__() + self.layer_1_dim = layer_1_dim + self.loss_fx = loss_fx + + # call this to save (layer_1_dim=128) to the checkpoint + self.save_hyperparameters('layer_1_dim') + + # to load specify the other args + model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator()) + + +3. Assign to `self.hparams`. Anything assigned to `self.hparams` will also be saved automatically. + + .. code-block:: python + + # using a argparse.Namespace + class LitMNIST(LightningModule): + def __init__(self, hparams, *args, **kwargs): + super().__init__() + self.hparams = hparams + self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim) + self.layer_2 = nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim) + self.layer_3 = nn.Linear(self.hparams.layer_2_dim, 10) + def train_dataloader(self): + return DataLoader(mnist_train, batch_size=self.hparams.batch_size) + + +4. You can also save full objects such as `dict` or `Namespace` to the checkpoint. + + .. code-block:: python + + # using a argparse.Namespace + class LitMNIST(LightningModule): + + def __init__(self, conf, *args, **kwargs): + super().__init__() + self.save_hyperparameters(conf) + + self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim) + self.layer_2 = nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim) + self.layer_3 = nn.Linear(self.hparams.layer_2_dim, 10) + + conf = OmegaConf.create(...) + model = LitMNIST(conf) + + # Now possible to access any stored variables from hparams + model.hparams.anything + + + +---------- + +Trainer args +^^^^^^^^^^^^ +To recap, add ALL possible trainer flags to the argparser and init the ``Trainer`` this way + +.. code-block:: python + + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + hparams = parser.parse_args() + + trainer = Trainer.from_argparse_args(hparams) + + # or if you need to pass in callbacks + trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...]) + +---------- + +Multiple Lightning Modules +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We often have multiple Lightning Modules where each one has different arguments. Instead of +polluting the ``main.py`` file, the ``LightningModule`` lets you define arguments for each one. + +.. testcode:: + + class LitMNIST(LightningModule): + + def __init__(self, layer_1_dim, **kwargs): + super().__init__() + self.layer_1 = nn.Linear(28 * 28, layer_1_dim) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("LitMNIST") + parser.add_argument('--layer_1_dim', type=int, default=128) + return parent_parser + +.. testcode:: + + class GoodGAN(LightningModule): + + def __init__(self, encoder_layers, **kwargs): + super().__init__() + self.encoder = Encoder(layers=encoder_layers) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("GoodGAN") + parser.add_argument('--encoder_layers', type=int, default=12) + return parent_parser + + +Now we can allow each model to inject the arguments it needs in the ``main.py`` + +.. code-block:: python + + def main(args): + dict_args = vars(args) + + # pick model + if args.model_name == 'gan': + model = GoodGAN(**dict_args) + elif args.model_name == 'mnist': + model = LitMNIST(**dict_args) + + trainer = Trainer.from_argparse_args(args) + trainer.fit(model) + + if __name__ == '__main__': + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + + # figure out which model to use + parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist') + + # THIS LINE IS KEY TO PULL THE MODEL NAME + temp_args, _ = parser.parse_known_args() + + # let the model add what it wants + if temp_args.model_name == 'gan': + parser = GoodGAN.add_model_specific_args(parser) + elif temp_args.model_name == 'mnist': + parser = LitMNIST.add_model_specific_args(parser) + + args = parser.parse_args() + + # train + main(args) + +and now we can train MNIST or the GAN using the command line interface! + +.. code-block:: bash + + $ python main.py --model_name gan --encoder_layers 24 + $ python main.py --model_name mnist --layer_1_dim 128 diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst new file mode 100644 index 00000000000000..d4a6e3ae94dfc8 --- /dev/null +++ b/docs/source/common/lightning_module.rst @@ -0,0 +1,1382 @@ +.. role:: hidden + :class: hidden-section + +.. _lightning_module: + +LightningModule +=============== +A :class:`~LightningModule` organizes your PyTorch code into 5 sections + +- Computations (init). +- Train loop (training_step) +- Validation loop (validation_step) +- Test loop (test_step) +- Optimizers (configure_optimizers) + +| + +.. raw:: html + + + +| + +Notice a few things. + +1. It's the SAME code. +2. The PyTorch code IS NOT abstracted - just organized. +3. All the other code that's not in the :class:`~LightningModule` + has been automated for you by the trainer. + +| + + .. code-block:: python + + net = Net() + trainer = Trainer() + trainer.fit(net) + +4. There are no .cuda() or .to() calls... Lightning does these for you. + +| + + .. code-block:: python + + # don't do in lightning + x = torch.Tensor(2, 3) + x = x.cuda() + x = x.to(device) + + # do this instead + x = x # leave it alone! + + # or to init a new tensor + new_x = torch.Tensor(2, 3) + new_x = new_x.type_as(x) + +5. There are no samplers for distributed, Lightning also does this for you. + +| + + .. code-block:: python + + # Don't do in Lightning... + data = MNIST(...) + sampler = DistributedSampler(data) + DataLoader(data, sampler=sampler) + + # do this instead + data = MNIST(...) + DataLoader(data) + +6. A :class:`~LightningModule` is a :class:`torch.nn.Module` but with added functionality. Use it as such! + +| + + .. code-block:: python + + net = Net.load_from_checkpoint(PATH) + net.freeze() + out = net(x) + +Thus, to use Lightning, you just need to organize your code which takes about 30 minutes, +(and let's be real, you probably should do anyhow). + +------------ + +Minimal Example +--------------- + +Here are the only required methods. + +.. code-block:: python + + >>> import pytorch_lightning as pl + >>> class LitModel(pl.LightningModule): + ... + ... def __init__(self): + ... super().__init__() + ... self.l1 = nn.Linear(28 * 28, 10) + ... + ... def forward(self, x): + ... return torch.relu(self.l1(x.view(x.size(0), -1))) + ... + ... def training_step(self, batch, batch_idx): + ... x, y = batch + ... y_hat = self(x) + ... loss = F.cross_entropy(y_hat, y) + ... return loss + ... + ... def configure_optimizers(self): + ... return torch.optim.Adam(self.parameters(), lr=0.02) + +Which you can train by doing: + +.. code-block:: python + + train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())) + trainer = pl.Trainer() + model = LitModel() + + trainer.fit(model, train_loader) + +The LightningModule has many convenience methods, but the core ones you need to know about are: + +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Name + - Description + * - init + - Define computations here + * - forward + - Use for inference only (separate from training_step) + * - training_step + - the full training loop + * - validation_step + - the full validation loop + * - test_step + - the full test loop + * - configure_optimizers + - define optimizers and LR schedulers + +---------- + +Training +-------- + +Training loop +^^^^^^^^^^^^^ +To add a training loop use the `training_step` method + +.. code-block:: python + + class LitClassifier(pl.LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + return loss + +Under the hood, Lightning does the following (pseudocode): + +.. code-block:: python + + # put model in train mode + model.train() + torch.set_grad_enabled(True) + + losses = [] + for batch in train_dataloader: + # forward + loss = training_step(batch) + losses.append(loss.detach()) + + # clear gradients + optimizer.zero_grad() + + # backward + loss.backward() + + # update parameters + optimizer.step() + + +Training epoch-level metrics +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +If you want to calculate epoch-level metrics and log them, use the `.log` method + +.. code-block:: python + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + + # logs metrics for each training_step, + # and the average across the epoch, to the progress bar and logger + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + +The `.log` object automatically reduces the requested metrics across the full epoch. +Here's the pseudocode of what it does under the hood: + +.. code-block:: python + + outs = [] + for batch in train_dataloader: + # forward + out = training_step(val_batch) + + # clear gradients + optimizer.zero_grad() + + # backward + loss.backward() + + # update parameters + optimizer.step() + + epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs])) + +Train epoch-level operations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +If you need to do something with all the outputs of each `training_step`, override `training_epoch_end` yourself. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + preds = ... + return {'loss': loss, 'other_stuff': preds} + + def training_epoch_end(self, training_step_outputs): + for pred in training_step_outputs: + # do something + +The matching pseudocode is: + +.. code-block:: python + + outs = [] + for batch in train_dataloader: + # forward + out = training_step(val_batch) + + # clear gradients + optimizer.zero_grad() + + # backward + loss.backward() + + # update parameters + optimizer.step() + + training_epoch_end(outs) + +Training with DataParallel +~~~~~~~~~~~~~~~~~~~~~~~~~~ +When training using a `accelerator` that splits data from each batch across GPUs, sometimes you might +need to aggregate them on the master GPU for processing (dp, or ddp2). + +In this case, implement the `training_step_end` method + +.. code-block:: python + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + pred = ... + return {'loss': loss, 'pred': pred} + + def training_step_end(self, batch_parts): + gpu_0_prediction = batch_parts[0]['pred'] + gpu_1_prediction = batch_parts[1]['pred'] + + # do something with both outputs + return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2 + + def training_epoch_end(self, training_step_outputs): + for out in training_step_outputs: + # do something with preds + +The full pseudocode that lighting does under the hood is: + +.. code-block:: python + + outs = [] + for train_batch in train_dataloader: + batches = split_batch(train_batch) + dp_outs = [] + for sub_batch in batches: + # 1 + dp_out = training_step(sub_batch) + dp_outs.append(dp_out) + + # 2 + out = training_step_end(dp_outs) + outs.append(out) + + # do something with the outputs for all batches + # 3 + training_epoch_end(outs) + +------------------ + +Validation loop +^^^^^^^^^^^^^^^ +To add a validation loop, override the `validation_step` method of the :class:`~LightningModule`: + +.. code-block:: python + + class LitModel(pl.LightningModule): + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + self.log('val_loss', loss) + +Under the hood, Lightning does the following: + +.. code-block:: python + + # ... + for batch in train_dataloader: + loss = model.training_step() + loss.backward() + # ... + + if validate_at_some_point: + # disable grads + batchnorm + dropout + torch.set_grad_enabled(False) + model.eval() + + # ----------------- VAL LOOP --------------- + for val_batch in model.val_dataloader: + val_out = model.validation_step(val_batch) + # ----------------- VAL LOOP --------------- + + # enable grads + batchnorm + dropout + torch.set_grad_enabled(True) + model.train() + +Validation epoch-level metrics +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +If you need to do something with all the outputs of each `validation_step`, override `validation_epoch_end`. + +.. code-block:: python + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + pred = ... + return pred + + def validation_epoch_end(self, validation_step_outputs): + for pred in validation_step_outputs: + # do something with a pred + +Validating with DataParallel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +When training using a `accelerator` that splits data from each batch across GPUs, sometimes you might +need to aggregate them on the master GPU for processing (dp, or ddp2). + +In this case, implement the `validation_step_end` method + +.. code-block:: python + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + pred = ... + return {'loss': loss, 'pred': pred} + + def validation_step_end(self, batch_parts): + gpu_0_prediction = batch_parts.pred[0]['pred'] + gpu_1_prediction = batch_parts.pred[1]['pred'] + + # do something with both outputs + return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2 + + def validation_epoch_end(self, validation_step_outputs): + for out in validation_step_outputs: + # do something with preds + +The full pseudocode that lighting does under the hood is: + +.. code-block:: python + + outs = [] + for batch in dataloader: + batches = split_batch(batch) + dp_outs = [] + for sub_batch in batches: + # 1 + dp_out = validation_step(sub_batch) + dp_outs.append(dp_out) + + # 2 + out = validation_step_end(dp_outs) + outs.append(out) + + # do something with the outputs for all batches + # 3 + validation_epoch_end(outs) + +---------------- + +Test loop +^^^^^^^^^ +The process for adding a test loop is the same as the process for adding a validation loop. Please refer to +the section above for details. + +The only difference is that the test loop is only called when `.test()` is used: + +.. code-block:: python + + model = Model() + trainer = Trainer() + trainer.fit() + + # automatically loads the best weights for you + trainer.test(model) + +There are two ways to call `test()`: + +.. code-block:: python + + # call after training + trainer = Trainer() + trainer.fit(model) + + # automatically auto-loads the best weights + trainer.test(test_dataloaders=test_dataloader) + + # or call with pretrained model + model = MyLightningModule.load_from_checkpoint(PATH) + trainer = Trainer() + trainer.test(model, test_dataloaders=test_dataloader) + +---------- + +Inference +--------- +For research, LightningModules are best structured as systems. + +.. code-block:: python + + import pytorch_lightning as pl + import torch + from torch import nn + + class Autoencoder(pl.LightningModule): + + def __init__(self, latent_dim=2): + super().__init__() + self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim)) + self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28)) + + def training_step(self, batch, batch_idx): + x, _ = batch + + # encode + x = x.view(x.size(0), -1) + z = self.encoder(x) + + # decode + recons = self.decoder(z) + + # reconstruction + reconstruction_loss = nn.functional.mse_loss(recons, x) + return reconstruction_loss + + def validation_step(self, batch, batch_idx): + x, _ = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + recons = self.decoder(z) + reconstruction_loss = nn.functional.mse_loss(recons, x) + self.log('val_reconstruction', reconstruction_loss) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.0002) + +Which can be trained like this: + +.. code-block:: python + + autoencoder = Autoencoder() + trainer = pl.Trainer(gpus=1) + trainer.fit(autoencoder, train_dataloader, val_dataloader) + +This simple model generates examples that look like this (the encoders and decoders are too weak) + +.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/ae_docs.png + :width: 300 + +The methods above are part of the lightning interface: + +- training_step +- validation_step +- test_step +- configure_optimizers + +Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code. + +.. code-block:: python + + class Autoencoder(pl.LightningModule): + + def __init__(self, latent_dim=2): + super().__init__() + self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim)) + self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28)) + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log('val_loss', loss) + + def shared_step(self, batch): + x, _ = batch + + # encode + x = x.view(x.size(0), -1) + z = self.encoder(x) + + # decode + recons = self.decoder(z) + + # loss + return nn.functional.mse_loss(recons, x) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.0002) + +We create a new method called `shared_step` that all loops can use. This method name is arbitrary and NOT reserved. + +Inference in research +^^^^^^^^^^^^^^^^^^^^^ +In the case where we want to perform inference with the system we can add a `forward` method to the LightningModule. + +.. code-block:: python + + class Autoencoder(pl.LightningModule): + def forward(self, x): + return self.decoder(x) + +The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, +such as text generation: + +.. code-block:: python + + class Seq2Seq(pl.LightningModule): + + def forward(self, x): + embeddings = self(x) + hidden_states = self.encoder(embeddings) + for h in hidden_states: + # decode + ... + return decoded + +Inference in production +^^^^^^^^^^^^^^^^^^^^^^^ +For cases like production, you might want to iterate different models inside a LightningModule. + +.. code-block:: python + + import pytorch_lightning as pl + from pytorch_lightning.metrics import functional as FM + + class ClassificationTask(pl.LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.cross_entropy(y_hat, y) + acc = FM.accuracy(y_hat, y) + + metrics = {'val_acc': acc, 'val_loss': loss} + self.log_dict(metrics) + return metrics + + def test_step(self, batch, batch_idx): + metrics = self.validation_step(batch, batch_idx) + metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']} + self.log_dict(metrics) + + def configure_optimizers(self): + return torch.optim.Adam(self.model.parameters(), lr=0.02) + +Then pass in any arbitrary model to be fit with this task + +.. code-block:: python + + for model in [resnet50(), vgg16(), BidirectionalRNN()]: + task = ClassificationTask(model) + + trainer = Trainer(gpus=2) + trainer.fit(task, train_dataloader, val_dataloader) + +Tasks can be arbitrarily complex such as implementing GAN training, self-supervised or even RL. + +.. code-block:: python + + class GANTask(pl.LightningModule): + + def __init__(self, generator, discriminator): + super().__init__() + self.generator = generator + self.discriminator = discriminator + ... + +When used like this, the model can be separated from the Task and thus used in production without needing to keep it in +a `LightningModule`. + +- You can export to onnx. +- Or trace using Jit. +- or run in the python runtime. + +.. code-block:: python + + task = ClassificationTask(model) + + trainer = Trainer(gpus=2) + trainer.fit(task, train_dataloader, val_dataloader) + + # use model after training or load weights and drop into the production system + model.eval() + y_hat = model(x) + +----------- + +LightningModule API +------------------- + +Methods +^^^^^^^ + +configure_callbacks +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_callbacks + :noindex: + +configure_optimizers +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_optimizers + :noindex: + +forward +~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.forward + :noindex: + +freeze +~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.freeze + :noindex: + +log +~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.log + :noindex: + +log_dict +~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict + :noindex: + +print +~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.print + :noindex: + +save_hyperparameters +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters + :noindex: + +test_step +~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_step + :noindex: + +test_step_end +~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_step_end + :noindex: + +test_epoch_end +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_epoch_end + :noindex: + +to_onnx +~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.to_onnx + :noindex: + +to_torchscript +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.to_torchscript + :noindex: + +training_step +~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step + :noindex: + +training_step_end +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_step_end + :noindex: + +training_epoch_end +~~~~~~~~~~~~~~~~~~ +.. automethod:: pytorch_lightning.core.lightning.LightningModule.training_epoch_end + :noindex: + +unfreeze +~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.unfreeze + :noindex: + +validation_step +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step + :noindex: + +validation_step_end +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_step_end + :noindex: + +validation_epoch_end +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.validation_epoch_end + :noindex: + +write_prediction +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.write_prediction + :noindex: + +write_prediction_dict +~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.write_prediction_dict + :noindex: + +------------ + +Properties +^^^^^^^^^^ +These are properties available in a LightningModule. + +----------- + +current_epoch +~~~~~~~~~~~~~ +The current epoch + +.. code-block:: python + + def training_step(...): + if self.current_epoch == 0: + +------------- + +device +~~~~~~ +The device the module is on. Use it to keep your code device agnostic + +.. code-block:: python + + def training_step(...): + z = torch.rand(2, 3, device=self.device) + +------------- + +global_rank +~~~~~~~~~~~ +The global_rank of this LightningModule. Lightning saves logs, weights etc only from global_rank = 0. You +normally do not need to use this property + +Global rank refers to the index of that GPU across ALL GPUs. For example, if using 10 machines, each with 4 GPUs, +the 4th GPU on the 10th machine has global_rank = 39 + +------------- + +global_step +~~~~~~~~~~~ +The current step (does not reset each epoch) + +.. code-block:: python + + def training_step(...): + self.logger.experiment.log_image(..., step=self.global_step) + +------------- + +hparams +~~~~~~~ +The arguments saved by calling ``save_hyperparameters`` passed through ``__init__()`` + could be accessed by the ``hparams`` attribute. + +.. code-block:: python + + def __init__(self, learning_rate): + self.save_hyperparameters() + + def configure_optimizers(self): + return Adam(self.parameters(), lr=self.hparams.learning_rate) + +-------------- + +logger +~~~~~~ +The current logger being used (tensorboard or other supported logger) + +.. code-block:: python + + def training_step(...): + # the generic logger (same no matter if tensorboard or other supported logger) + self.logger + + # the particular logger + tensorboard_logger = self.logger.experiment + +-------------- + +local_rank +~~~~~~~~~~~ +The local_rank of this LightningModule. Lightning saves logs, weights etc only from global_rank = 0. You +normally do not need to use this property + +Local rank refers to the rank on that machine. For example, if using 10 machines, the GPU at index 0 on each machine +has local_rank = 0. + + +----------- + +precision +~~~~~~~~~ +The type of precision used: + +.. code-block:: python + + def training_step(...): + if self.precision == 16: + +------------ + +trainer +~~~~~~~ +Pointer to the trainer + +.. code-block:: python + + def training_step(...): + max_steps = self.trainer.max_steps + any_flag = self.trainer.any_flag + +------------ + +use_amp +~~~~~~~ +True if using Automatic Mixed Precision (AMP) + +------------ + +use_ddp +~~~~~~~ +True if using ddp + +------------ + +use_ddp2 +~~~~~~~~ +True if using ddp2 + +------------ + +use_dp +~~~~~~ +True if using dp + +------------ + +use_tpu +~~~~~~~ +True if using TPUs + +-------------- + +automatic_optimization +~~~~~~~~~~~~~~~~~~~~~~ +When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used. + +.. code-block:: python + + def __init__(self): + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + opt = self.optimizers(use_pl_optimizer=True) + + loss = ... + opt.zero_grad() + self.manual_backward(loss) + opt.step() + +This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. + +.. code-block:: python + + def __init__(self): + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # access your optimizers with use_pl_optimizer=False. Default is True + opt_a, opt_b = self.optimizers(use_pl_optimizer=True) + + gen_loss = ... + opt_a.zero_grad() + self.manual_backward(gen_loss) + opt_a.step() + + disc_loss = ... + opt_b.zero_grad() + self.manual_backward(disc_loss) + opt_b.step() + +-------------- + +example_input_array +~~~~~~~~~~~~~~~~~~~ +Set and access example_input_array which is basically a single batch. + +.. code-block:: python + + def __init__(self): + self.example_input_array = ... + self.generator = ... + + def on_train_epoch_end(...): + # generate some images using the example_input_array + gen_images = self.generator(self.example_input_array) + +-------------- + +datamodule +~~~~~~~~~~ +Set or access your datamodule. + +.. code-block:: python + + def configure_optimizers(self): + num_training_samples = len(self.datamodule.train_dataloader()) + ... + +-------------- + +model_size +~~~~~~~~~~ +Get the model file size (in megabytes) using ``self.model_size`` inside LightningModule. + +-------------- + +Hooks +^^^^^ +This is the pseudocode to describe how all the hooks are called during a call to ``.fit()``. + +.. code-block:: python + + def fit(...): + if global_rank == 0: + # prepare data is called on GLOBAL_ZERO only + prepare_data() + + configure_callbacks() + + on_fit_start() + + for gpu/tpu in gpu/tpus: + train_on_device(model.copy()) + + on_fit_end() + + def train_on_device(model): + # setup is called PER DEVICE + setup() + configure_optimizers() + on_pretrain_routine_start() + + for epoch in epochs: + train_loop() + + teardown() + + def train_loop(): + on_epoch_start() + on_train_epoch_start() + train_outs = [] + for train_batch in train_dataloader(): + on_train_batch_start() + + # ----- train_step methods ------- + out = training_step(batch) + train_outs.append(out) + + loss = out.loss + + on_before_zero_grad() + optimizer_zero_grad() + + backward() + on_after_backward() + + optimizer_step() + + on_train_batch_end(out) + + if should_check_val: + val_loop() + + # end training epoch + outs = training_epoch_end(outs) + on_train_epoch_end(outs) + on_epoch_end() + + def val_loop(): + model.eval() + torch.set_grad_enabled(False) + + on_epoch_start() + on_validation_epoch_start() + val_outs = [] + for val_batch in val_dataloader(): + on_validation_batch_start() + + # -------- val step methods ------- + out = validation_step(val_batch) + val_outs.append(out) + + on_validation_batch_end(out) + + validation_epoch_end(val_outs) + on_validation_epoch_end() + on_epoch_end() + + # set up for train + model.train() + torch.set_grad_enabled(True) + +backward +~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.backward + :noindex: + +get_progress_bar_dict +~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict + :noindex: + +manual_backward +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward + :noindex: + + +on_after_backward +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward + :noindex: + +on_before_zero_grad +~~~~~~~~~~~~~~~~~~~ +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad + :noindex: + +on_fit_start +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_start + :noindex: + +on_fit_end +~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_end + :noindex: + + +on_load_checkpoint +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint + :noindex: + +on_save_checkpoint +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint + :noindex: + +on_train_start +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start + :noindex: + +on_train_end +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end + :noindex: + +on_validation_start +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start + :noindex: + +on_validation_end +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end + :noindex: + +on_pretrain_routine_start +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_start + :noindex: + +on_pretrain_routine_end +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_end + :noindex: + +on_test_batch_start +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_start + :noindex: + +on_test_batch_end +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_end + :noindex: + +on_test_epoch_start +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_start + :noindex: + +on_test_epoch_end +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end + :noindex: + +on_test_end +~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end + :noindex: + +on_train_batch_start +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start + :noindex: + +on_train_batch_end +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end + :noindex: + +on_epoch_start +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start + :noindex: + +on_epoch_end +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end + :noindex: + +on_train_epoch_start +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_start + :noindex: + +on_train_epoch_end +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_end + :noindex: + +on_validation_batch_start +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_start + :noindex: + +on_validation_batch_end +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_end + :noindex: + +on_validation_epoch_start +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_start + :noindex: + +on_validation_epoch_end +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end + :noindex: + +on_post_move_to_device +~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device + :noindex: + +on_validation_model_eval +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval + :noindex: + +on_validation_model_train +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train + :noindex: + +on_test_model_eval +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval + :noindex: + +on_test_model_train +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train + :noindex: + +optimizer_step +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_step + :noindex: + +optimizer_zero_grad +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizer_zero_grad + :noindex: + +prepare_data +~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.prepare_data + :noindex: + +setup +~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup + :noindex: + +tbptt_split_batch +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch + :noindex: + +teardown +~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown + :noindex: + +train_dataloader +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader + :noindex: + +val_dataloader +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader + :noindex: + +test_dataloader +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader + :noindex: + +transfer_batch_to_device +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device + :noindex: + +on_before_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer + :noindex: + +on_after_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer + :noindex: diff --git a/docs/source/experiment_logging.rst b/docs/source/common/loggers.rst similarity index 83% rename from docs/source/experiment_logging.rst rename to docs/source/common/loggers.rst index 772efcfc13bc53..c6c5f0d8653c78 100644 --- a/docs/source/experiment_logging.rst +++ b/docs/source/common/loggers.rst @@ -3,12 +3,33 @@ from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.core.lightning import LightningModule +.. _loggers: -Experiment Logging -================== +******* +Loggers +******* + +Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...). TensorBoard is used by default, +but you can pass to the :class:`~pytorch_lightning.trainer.trainer.Trainer` any combination of the following loggers. + +.. note:: + + All loggers log by default to `os.getcwd()`. To change the path without creating a logger set + `Trainer(default_root_dir='/your/path/to/save/checkpoints')` + +Read more about :doc:`logging <../extensions/logging>` options. + +To log arbitrary artifacts like images or audio samples use the `trainer.log_dir` property to resolve +the path. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + img = ... + log_image(img, self.trainer.log_dir) Comet.ml -^^^^^^^^ +======== `Comet.ml `_ is a third-party logger. To use :class:`~pytorch_lightning.loggers.CometLogger` as your logger do the following. @@ -47,8 +68,10 @@ The :class:`~pytorch_lightning.loggers.CometLogger` is available anywhere except .. seealso:: :class:`~pytorch_lightning.loggers.CometLogger` docs. +---------------- + MLflow -^^^^^^ +====== `MLflow `_ is a third-party logger. To use :class:`~pytorch_lightning.loggers.MLFlowLogger` as your logger do the following. @@ -60,7 +83,7 @@ First, install the package: Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`: -.. testcode:: +.. code-block:: python from pytorch_lightning.loggers import MLFlowLogger mlf_logger = MLFlowLogger( @@ -72,8 +95,10 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. .. seealso:: :class:`~pytorch_lightning.loggers.MLFlowLogger` docs. +---------------- + Neptune.ai -^^^^^^^^^^ +========== `Neptune.ai `_ is a third-party logger. To use :class:`~pytorch_lightning.loggers.NeptuneLogger` as your logger do the following. @@ -88,6 +113,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. .. testcode:: from pytorch_lightning.loggers import NeptuneLogger + neptune_logger = NeptuneLogger( api_key='ANONYMOUS', # replace with your own project_name='shared/pytorch-lightning-integration', @@ -110,50 +136,10 @@ The :class:`~pytorch_lightning.loggers.NeptuneLogger` is available anywhere exce .. seealso:: :class:`~pytorch_lightning.loggers.NeptuneLogger` docs. -allegro.ai TRAINS -^^^^^^^^^^^^^^^^^ - -`allegro.ai `_ is a third-party logger. -To use :class:`~pytorch_lightning.loggers.TrainsLogger` as your logger do the following. -First, install the package: - -.. code-block:: bash - - pip install trains - -Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`: - -.. testcode:: - - from pytorch_lightning.loggers import TrainsLogger - trains_logger = TrainsLogger( - project_name='examples', - task_name='pytorch lightning test', - ) - trainer = Trainer(logger=trains_logger) - -.. testoutput:: - :options: +ELLIPSIS, +NORMALIZE_WHITESPACE - :hide: - - TRAINS Task: ... - TRAINS results page: ... - -The :class:`~pytorch_lightning.loggers.TrainsLogger` is available anywhere in your -:class:`~pytorch_lightning.core.lightning.LightningModule`. - -.. testcode:: - - class MyModule(LightningModule): - def __init__(self): - some_img = fake_image() - self.logger.experiment.log_image('debug', 'generated_image_0', some_img, 0) - -.. seealso:: - :class:`~pytorch_lightning.loggers.TrainsLogger` docs. +---------------- Tensorboard -^^^^^^^^^^^ +=========== To use `TensorBoard `_ as your logger do the following. @@ -176,8 +162,10 @@ The :class:`~pytorch_lightning.loggers.TensorBoardLogger` is available anywhere .. seealso:: :class:`~pytorch_lightning.loggers.TensorBoardLogger` docs. +---------------- + Test Tube -^^^^^^^^^ +========= `Test Tube `_ is a `TensorBoard `_ logger but with nicer file structure. @@ -190,7 +178,7 @@ First, install the package: Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`: -.. testcode:: +.. code-block:: python from pytorch_lightning.loggers import TestTubeLogger logger = TestTubeLogger('tb_logs', name='my_model') @@ -209,8 +197,10 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc .. seealso:: :class:`~pytorch_lightning.loggers.TestTubeLogger` docs. +---------------- + Weights and Biases -^^^^^^^^^^^^^^^^^^ +================== `Weights and Biases `_ is a third-party logger. To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following. @@ -222,10 +212,10 @@ First, install the package: Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`: -.. testcode:: +.. code-block:: python from pytorch_lightning.loggers import WandbLogger - wandb_logger = WandbLogger() + wandb_logger = WandbLogger(offline=True) trainer = Trainer(logger=wandb_logger) The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your @@ -243,19 +233,21 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except .. seealso:: :class:`~pytorch_lightning.loggers.WandbLogger` docs. +---------------- + Multiple Loggers -^^^^^^^^^^^^^^^^ +================ Lightning supports the use of multiple loggers, just pass a list to the :class:`~pytorch_lightning.trainer.trainer.Trainer`. -.. testcode:: +.. code-block:: python from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger logger1 = TensorBoardLogger('tb_logs', name='my_model') logger2 = TestTubeLogger('tb_logs', name='my_model') trainer = Trainer(logger=[logger1, logger2]) - + The loggers are available as a list anywhere except ``__init__`` in your :class:`~pytorch_lightning.core.lightning.LightningModule`. diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst new file mode 100644 index 00000000000000..422302ea8987e7 --- /dev/null +++ b/docs/source/common/optimizers.rst @@ -0,0 +1,489 @@ +.. _optimizers: + +************ +Optimization +************ + +Lightning offers two modes for managing the optimization process: + +- automatic optimization (AutoOpt) +- manual optimization + +For the majority of research cases, **automatic optimization** will do the right thing for you and it is what +most users should use. + +For advanced/expert users who want to do esoteric optimization schedules or techniques, use **manual optimization**. + +------ + +Manual optimization +=================== +For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable +to manually manage the optimization process. To do so, do the following: + +* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule`` ``__init__`` function +* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``. + +.. testcode:: python + + from pytorch_lightning import LightningModule + + class MyModel(LightningModule): + + def __init__(self): + super().__init__() + # Important: This property activate ``manual optimization`` for your model + self.automatic_optimization = False + + def training_step(batch, batch_idx): + opt = self.optimizers() + loss = self.compute_loss(batch) + self.manual_backward(loss) + +.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.. + +.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertise. + +.. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do as such. + +.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you. + +.. code-block:: python + + def __init__(self): + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + opt = self.optimizers() + + loss = self.compute_loss(batch) + self.manual_backward(loss) + + # accumulate gradient batches + if batch_idx % 2 == 0: + opt.step() + opt.zero_grad() + +.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward``, ``zero_grad`` and ``backward`` of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs `_. + +Here is the same example as above using a ``closure``. + +.. testcode:: python + + def __init__(self): + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + opt = self.optimizers() + + def closure(): + # Only zero_grad on the first batch to accumulate gradients + is_first_batch_to_accumulate = batch_idx % 2 == 0 + if is_first_batch_to_accumulate: + opt.zero_grad() + + loss = self.compute_loss(batch) + self.manual_backward(loss) + return loss + + opt.step(closure=closure) + +.. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good pratice to call ``zero_grad`` before ``manual_backward``. + +.. testcode:: python + + import torch + from torch import Tensor + from pytorch_lightning import LightningModule + + class SimpleGAN(LightningModule): + + def __init__(self): + super().__init__() + self.G = Generator() + self.D = Discriminator() + + # Important: This property activate ``manual optimization`` for this model + self.automatic_optimization = False + + def sample_z(self, n) -> Tensor: + sample = self._Z.sample((n,)) + return sample + + def sample_G(self, n) -> Tensor: + z = self.sample_z(n) + return self.G(z) + + def training_step(self, batch, batch_idx): + # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html + g_opt, d_opt = self.optimizers() + + X, _ = batch + batch_size = X.shape[0] + + real_label = torch.ones((batch_size, 1), device=self.device) + fake_label = torch.zeros((batch_size, 1), device=self.device) + + g_X = self.sample_G(batch_size) + + ########################### + # Optimize Discriminator # + ########################### + d_opt.zero_grad() + d_x = self.D(X) + errD_real = self.criterion(d_x, real_label) + + d_z = self.D(g_X.detach()) + errD_fake = self.criterion(d_z, fake_label) + + errD = (errD_real + errD_fake) + + self.manual_backward(errD) + d_opt.step() + + ####################### + # Optimize Generator # + ####################### + g_opt.zero_grad() + + d_z = self.D(g_X) + errG = self.criterion(d_z, real_label) + + self.manual_backward(errG) + g_opt.step() + + self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True) + + def configure_optimizers(self): + g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5) + d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5) + return g_opt, d_opt + +.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a ``@context_manager`` for advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting. + +Here is an explanation of what it does: + +Considering the current optimizer as A and all other optimizers as B. +Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``. Their original state will be restored when exiting the context manager. + +When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. +Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed. + + +Here is an example for advanced use-case. + +.. testcode:: python + + # Scenario for a GAN with gradient accumulation every 2 batches and optimized for multiple gpus. + + class SimpleGAN(LightningModule): + + ... + + def __init__(self): + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html + g_opt, d_opt = self.optimizers() + + X, _ = batch + X.requires_grad = True + batch_size = X.shape[0] + + real_label = torch.ones((batch_size, 1), device=self.device) + fake_label = torch.zeros((batch_size, 1), device=self.device) + + accumulated_grad_batches = batch_idx % 2 == 0 + + g_X = self.sample_G(batch_size) + + ########################### + # Optimize Discriminator # + ########################### + with d_opt.toggle_model(sync_grad=accumulated_grad_batches): + d_x = self.D(X) + errD_real = self.criterion(d_x, real_label) + + d_z = self.D(g_X.detach()) + errD_fake = self.criterion(d_z, fake_label) + + errD = (errD_real + errD_fake) + + self.manual_backward(errD) + if accumulated_grad_batches: + d_opt.step() + d_opt.zero_grad() + + ####################### + # Optimize Generator # + ####################### + with g_opt.toggle_model(sync_grad=accumulated_grad_batches): + d_z = self.D(g_X) + errG = self.criterion(d_z, real_label) + + self.manual_backward(errG) + if accumulated_grad_batches: + g_opt.step() + g_opt.zero_grad() + + self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True) + +------ + +Automatic optimization +====================== +With Lightning most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()`` +since Lightning automates that for you. + +.. warning:: + Before 1.2.2, ``.zero_grad()`` was called after ``.backward()`` and ``.step()`` internally. + From 1.2.2, Lightning calls ``.zero_grad()`` before ``.backward()``. + +Under the hood Lightning does the following: + +.. code-block:: python + + for epoch in epochs: + for batch in data: + loss = model.training_step(batch, batch_idx, ...) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for lr_scheduler in lr_schedulers: + lr_scheduler.step() + +In the case of multiple optimizers, Lightning does the following: + +.. code-block:: python + + for epoch in epochs: + for batch in data: + for opt in optimizers: + loss = model.training_step(batch, batch_idx, optimizer_idx) + opt.zero_grad() + loss.backward() + opt.step() + + for lr_scheduler in lr_schedulers: + lr_scheduler.step() + + +Learning rate scheduling +------------------------ +Every optimizer you use can be paired with any `Learning Rate Scheduler `_. +In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers`` method: + +.. testcode:: + + # no LR scheduler + def configure_optimizers(self): + return Adam(...) + + # Adam + LR scheduler + def configure_optimizers(self): + optimizer = Adam(...) + scheduler = LambdaLR(optimizer, ...) + return [optimizer], [scheduler] + + # Two optimizers each with a scheduler + def configure_optimizers(self): + optimizer1 = Adam(...) + optimizer2 = SGD(...) + scheduler1 = LambdaLR(optimizer1, ...) + scheduler2 = LambdaLR(optimizer2, ...) + return [optimizer1, optimizer2], [scheduler1, scheduler2] + +When there are schedulers in which the ``.step()`` method is conditioned on a metric value (for example the +:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler), Lightning requires that the output +from ``configure_optimizers`` should be dicts, one for each optimizer, with the keyword ``monitor`` +set to metric that the scheduler should be conditioned on. + +.. testcode:: + + # The ReduceLROnPlateau scheduler requires a monitor + def configure_optimizers(self): + return { + 'optimizer': Adam(...), + 'lr_scheduler': ReduceLROnPlateau(optimizer, ...), + 'monitor': 'metric_to_track' + } + + # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler + def configure_optimizers(self): + optimizer1 = Adam(...) + optimizer2 = SGD(...) + scheduler1 = ReduceLROnPlateau(optimizer1, ...) + scheduler2 = LambdaLR(optimizer2, ...) + return ( + {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'}, + {'optimizer': optimizer2, 'lr_scheduler': scheduler2}, + ) + +.. note:: + Metrics can be made availble to condition on by simply logging it using ``self.log('metric_to_track', metric_val)`` + in your lightning module. + +By default, all schedulers will be called after each epoch ends. To change this behaviour, a scheduler configuration should be +returned as a dict which can contain the following keywords: + +* ``scheduler`` (required): the actual scheduler object +* ``monitor`` (optional): metric to condition +* ``interval`` (optional): either ``epoch`` (default) for stepping after each epoch ends or ``step`` for stepping + after each optimization step +* ``frequency`` (optional): how many epochs/steps should pass between calls to ``scheduler.step()``. Default is 1, + corresponding to updating the learning rate after every epoch/step. +* ``strict`` (optional): if set to ``True`` will enforce that value specified in ``monitor`` is available while trying + to call ``scheduler.step()``, and stop training if not found. If ``False`` will only give a warning and continue training + (without calling the scheduler). +* ``name`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the + learning rate progress, this keyword can be used to specify a specific name the learning rate should be logged as. + +.. testcode:: + + # Same as the above example with additional params passed to the first scheduler + # In this case the ReduceLROnPlateau will step after every 10 processed batches + def configure_optimizers(self): + optimizers = [Adam(...), SGD(...)] + schedulers = [ + { + 'scheduler': ReduceLROnPlateau(optimizers[0], ...), + 'monitor': 'metric_to_track', + 'interval': 'step', + 'frequency': 10, + 'strict': True, + }, + LambdaLR(optimizers[1], ...) + ] + return optimizers, schedulers + +---------- + +Use multiple optimizers (like GANs) +----------------------------------- +To use multiple optimizers return two or more optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers` + +.. testcode:: + + # one optimizer + def configure_optimizers(self): + return Adam(...) + + # two optimizers, no schedulers + def configure_optimizers(self): + return Adam(...), SGD(...) + + # Two optimizers, one scheduler for adam only + def configure_optimizers(self): + return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'} + +Lightning will call each optimizer sequentially: + +.. code-block:: python + + for epoch in epochs: + for batch in data: + for opt in optimizers: + loss = train_step(batch, batch_idx, optimizer_idx) + opt.zero_grad() + loss.backward() + opt.step() + + for lr_scheduler in lr_schedulers: + lr_scheduler.step() + +---------- + +Step optimizers at arbitrary intervals +-------------------------------------- +To do more interesting things with your optimizers such as learning rate warm-up or odd scheduling, +override the :meth:`optimizer_step` function. + +For example, here step optimizer A every 2 batches and optimizer B every 4 batches + +.. testcode:: + + def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): + optimizer.zero_grad() + + # Alternating schedule for optimizer steps (ie: GANs) + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + # update generator opt every 2 steps + if optimizer_idx == 0: + if batch_nb % 2 == 0 : + optimizer.step(closure=closure) + + # update discriminator opt every 4 steps + if optimizer_idx == 1: + if batch_nb % 4 == 0 : + optimizer.step(closure=closure) + +Here we add a learning-rate warm up + +.. testcode:: + + # learning rate warm-up + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * self.hparams.learning_rate + + # update params + optimizer.step(closure=closure) + +.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches and much more ... + +.. testcode:: + + # function hook in LightningModule + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + optimizer.step(closure=closure) + +.. note:: To access your wrapped Optimizer from ``LightningOptimizer``, do as follow. + +.. testcode:: + + # function hook in LightningModule + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + + # `optimizer is a ``LightningOptimizer`` wrapping the optimizer. + # To access it, do as follow: + optimizer = optimizer.optimizer + + # run step. However, it won't work on TPU, AMP, etc... + optimizer.step(closure=closure) + + +---------- + +Using the closure functions for optimization +-------------------------------------------- + +When using optimization schemes such as LBFGS, the `second_order_closure` needs to be enabled. By default, this function is defined by wrapping the `training_step` and the backward steps as follows + +.. warning:: + Before 1.2.2, ``.zero_grad()`` was called outside the closure internally. + From 1.2.2, the closure calls ``.zero_grad()`` inside, so there is no need to define your own closure + when using similar optimizers to :class:`torch.optim.LBFGS` which requires reevaluation of the loss with the closure in ``optimizer.step()``. + +.. testcode:: + + def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden): + # Model training step on a given batch + result = pl_module.training_step(split_batch, batch_idx, opt_idx, hidden) + + # Model backward pass + pl_module.backward(result, optimizer, opt_idx) + + # on_after_backward callback + pl_module.on_after_backward(result.training_step_output, batch_idx, result.loss) + + return result + + # This default `second_order_closure` function can be enabled by passing it directly into the `optimizer.step` + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + # update params + optimizer.step(second_order_closure) diff --git a/docs/source/common/production_inference.rst b/docs/source/common/production_inference.rst new file mode 100644 index 00000000000000..9fad8bdb402832 --- /dev/null +++ b/docs/source/common/production_inference.rst @@ -0,0 +1,48 @@ +.. _production_inference: + +Inference in Production +======================= +PyTorch Lightning eases the process of deploying models into production. + + +Exporting to ONNX +----------------- +PyTorch Lightning provides a handy function to quickly export your model to ONNX format, which allows the model to be independent of PyTorch and run on an ONNX Runtime. + +To export your model to ONNX format call the ``to_onnx`` function on your Lightning Module with the filepath and input_sample. + +.. code-block:: python + + filepath = 'model.onnx' + model = SimpleModel() + input_sample = torch.randn((1, 64)) + model.to_onnx(filepath, input_sample, export_params=True) + +You can also skip passing the input sample if the ` example_input_array ` property is specified in your LightningModule. + +Once you have the exported model, you can run it on your ONNX runtime in the following way: + +.. code-block:: python + + ort_session = onnxruntime.InferenceSession(filepath) + input_name = ort_session.get_inputs()[0].name + ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)} + ort_outs = ort_session.run(None, ort_inputs) + + +Exporting to TorchScript +------------------------ + +TorchScript allows you to serialize your models in a way that it can be loaded in non-Python environments. +The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` +that returns a scripted module which you can save or directly use. + +.. code-block:: python + + model = SimpleModel() + script = model.to_torchscript() + + # save for use in production environment + torch.jit.save(script, "model.pt") + +It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. diff --git a/docs/source/single_gpu.rst b/docs/source/common/single_gpu.rst similarity index 83% rename from docs/source/single_gpu.rst rename to docs/source/common/single_gpu.rst index c6fa1b9af9bbc7..14e0486fa7e21e 100644 --- a/docs/source/single_gpu.rst +++ b/docs/source/common/single_gpu.rst @@ -2,8 +2,10 @@ from pytorch_lightning.trainer.trainer import Trainer +.. _single_gpu: + Single GPU Training -==================== +=================== Make sure you are running on a machine that has at least one GPU. Lightning handles all the NVIDIA flags for you, there's no need to set them yourself. @@ -11,4 +13,4 @@ there's no need to set them yourself. :skipif: torch.cuda.device_count() < 1 # train on 1 GPU (using dp mode) - trainer = Trainer(gpus=1) \ No newline at end of file + trainer = Trainer(gpus=1) diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst new file mode 100644 index 00000000000000..4c9e9a6061977a --- /dev/null +++ b/docs/source/common/test_set.rst @@ -0,0 +1,104 @@ +.. _test_set: + +Test set +======== +Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake. +Testing is performed using the ``trainer`` object's ``.test()`` method. + +.. automethod:: pytorch_lightning.trainer.Trainer.test + :noindex: + +---------- + +Test after fit +-------------- +To run the test set after training completes, use this method. + +.. code-block:: python + + # run full training + trainer.fit(model) + + # (1) load the best checkpoint automatically (lightning tracks this for you) + trainer.test() + + # (2) don't load a checkpoint, instead use the model with the latest weights + trainer.test(ckpt_path=None) + + # (3) test using a specific checkpoint + trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt') + + # (4) test with an explicit model (will use this model and not load a checkpoint) + trainer.test(model) + +---------- + +Test multiple models +-------------------- +You can run the test set on multiple models using the same trainer instance. + +.. code-block:: python + + model1 = LitModel() + model2 = GANModel() + + trainer = Trainer() + trainer.test(model1) + trainer.test(model2) + +---------- + +Test pre-trained model +---------------------- +To run the test set on a pre-trained model, use this method. + +.. code-block:: python + + model = MyLightningModule.load_from_checkpoint( + checkpoint_path='/path/to/pytorch_checkpoint.ckpt', + hparams_file='/path/to/test_tube/experiment/version/hparams.yaml', + map_location=None + ) + + # init trainer with whatever options + trainer = Trainer(...) + + # test (pass in the model) + trainer.test(model) + +In this case, the options you pass to trainer will be used when +running the test set (ie: 16-bit, dp, ddp, etc...) + +---------- + +Test with additional data loaders +--------------------------------- +You can still run inference on a test set even if the `test_dataloader` method hasn't been +defined within your :doc:`lightning module <../common/lightning_module>` instance. This would be the case when your test data +is not available at the time your model was declared. + +.. code-block:: python + + # setup your data loader + test = DataLoader(...) + + # test (pass in the loader) + trainer.test(test_dataloaders=test) + +You can either pass in a single dataloader or a list of them. This optional named +parameter can be used in conjunction with any of the above use cases. Additionally, +you can also pass in an :doc:`datamodules <../extensions/datamodules>` that have overridden the +:ref:`datamodule-test-dataloader-label` method. + +.. code-block:: python + + class MyDataModule(pl.LightningDataModule): + ... + def test_dataloader(self): + return DataLoader(...) + + # setup your datamodule + dm = MyDataModule(...) + + # test (pass in datamodule) + trainer.test(datamodule=dm) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst new file mode 100644 index 00000000000000..96e19a7be46944 --- /dev/null +++ b/docs/source/common/trainer.rst @@ -0,0 +1,1707 @@ +.. role:: hidden + :class: hidden-section + +.. testsetup:: * + + import os + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.utilities.seed import seed_everything + +.. _trainer: + +Trainer +======= + +Once you've organized your PyTorch code into a LightningModule, +the Trainer automates everything else. + +.. raw:: html + + + +| + +This abstraction achieves the following: + +1. You maintain control over all aspects via PyTorch code without an added abstraction. + +2. The trainer uses best practices embedded by contributors and users + from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc... + +3. The trainer allows overriding any key part that you don't want automated. + +| + +----------- + +Basic use +--------- + +This is the basic use of the trainer: + +.. code-block:: python + + model = MyLightningModule() + + trainer = Trainer() + trainer.fit(model, train_dataloader, val_dataloader) + +-------- + +Under the hood +-------------- +Under the hood, the Lightning Trainer handles the training loop details for you, some examples include: + +- Automatically enabling/disabling grads +- Running the training, validation and test dataloaders +- Calling the Callbacks at the appropriate times +- Putting batches and computations on the correct devices + +Here's the pseudocode for what the trainer does under the hood (showing the train loop only) + +.. code-block:: python + + # put model in train mode + model.train() + torch.set_grad_enabled(True) + + losses = [] + for batch in train_dataloader: + # calls hooks like this one + on_train_batch_start() + + # train step + loss = training_step(batch) + + # clear gradients + optimizer.zero_grad() + + # backward + loss.backward() + + # update parameters + optimizer.step() + + losses.append(loss) + + +-------- + +Trainer in Python scripts +------------------------- +In Python scripts, it's recommended you use a main function to call the Trainer. + +.. code-block:: python + + from argparse import ArgumentParser + + def main(hparams): + model = LightningModule() + trainer = Trainer(gpus=hparams.gpus) + trainer.fit(model) + + if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--gpus', default=None) + args = parser.parse_args() + + main(args) + +So you can run it like so: + +.. code-block:: bash + + python main.py --gpus 2 + +.. note:: + + Pro-tip: You don't need to define all flags manually. Lightning can add them automatically + +.. code-block:: python + + from argparse import ArgumentParser + + def main(args): + model = LightningModule() + trainer = Trainer.from_argparse_args(args) + trainer.fit(model) + + if __name__ == '__main__': + parser = ArgumentParser() + parser = Trainer.add_argparse_args() + args = parser.parse_args() + + main(args) + +So you can run it like so: + +.. code-block:: bash + + python main.py --gpus 2 --max_steps 10 --limit_train_batches 10 --any_trainer_arg x + +.. note:: + If you want to stop a training run early, you can press "Ctrl + C" on your keyboard. + The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown, including + running callbacks such as ``on_train_end``. The trainer object will also set an attribute + ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute + resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs. + +------------ + +Validation +---------- +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or after it has already been trained. + +.. code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + +Testing +------- +Once you're done training, feel free to run the test set! +(Only right before publishing your paper or pushing to production) + +.. code-block:: python + + trainer.test(test_dataloaders=test_dataloaders) + +------------ + +Deployment / prediction +----------------------- +You just trained a LightningModule which is also just a torch.nn.Module. +Use it to do whatever! + +.. code-block:: python + + # load model + pretrained_model = LightningModule.load_from_checkpoint(PATH) + pretrained_model.freeze() + + # use it for finetuning + def forward(self, x): + features = pretrained_model(x) + classes = classifier(features) + + # or for prediction + out = pretrained_model(x) + api_write({'response': out} + + +You may wish to run the model on a variety of devices. Instead of moving the data +manually to the correct device, decorate the forward method (or any other method you use for inference) +with :func:`~pytorch_lightning.core.decorators.auto_move_data` and Lightning will take care of the rest. + +------------ + +Reproducibility +--------------- + +To ensure full reproducibility from run to run you need to set seeds for pseudo-random generators, +and set ``deterministic`` flag in ``Trainer``. + +Example:: + + from pytorch_lightning import Trainer, seed_everything + + seed_everything(42) + # sets seeds for numpy, torch, python.random and PYTHONHASHSEED. + model = Model() + trainer = Trainer(deterministic=True) + + +------- + +Trainer flags +------------- + +accelerator +^^^^^^^^^^^ + +.. raw:: html + + + +| + +The accelerator backend to use (previously known as distributed_backend). + +- (``'dp'``) is DataParallel (split batch among GPUs of same machine) +- (``'ddp'``) is DistributedDataParallel (each gpu on each node trains, and syncs grads) +- (``'ddp_cpu'``) is DistributedDataParallel on CPU (same as ``'ddp'``, but does not use GPUs. + Useful for multi-node CPU training or single-node debugging. Note that this will **not** give + a speedup on a single node, since Torch already makes efficient use of multiple CPUs on a single + machine.) +- (``'ddp2'``) dp on node, ddp across nodes. Useful for things like increasing + the number of negative samples + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(accelerator=None) + +Example:: + + # dp = DataParallel + trainer = Trainer(gpus=2, accelerator='dp') + + # ddp = DistributedDataParallel + trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp') + + # ddp2 = DistributedDataParallel + dp + trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2') + +.. note:: This option does not apply to TPU. TPUs use ``'ddp'`` by default (over each core) + +You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs. + +Example:: + + class MyOwnAcc(Accelerator): + ... + + Trainer(accelerator=MyOwnAcc()) + +.. warning:: Passing in custom accelerators is experimental but work is in progress to enable full compatibility. + +accumulate_grad_batches +^^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Accumulates grads every k batches or as set up in the dict. +Trainer also calls ``optimizer.step()`` for the last indivisible step number. + +.. testcode:: + + # default used by the Trainer (no accumulation) + trainer = Trainer(accumulate_grad_batches=1) + +Example:: + + # accumulate every 4 batches (effective batch size is batch*4) + trainer = Trainer(accumulate_grad_batches=4) + + # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that + trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20}) + +amp_backend +^^^^^^^^^^^ + +.. raw:: html + + + +| + +Use PyTorch AMP ('native') (available PyTorch 1.6+), or NVIDIA apex ('apex'). + +.. testcode:: + + # using PyTorch built-in AMP, default used by the Trainer + trainer = Trainer(amp_backend='native') + + # using NVIDIA Apex + trainer = Trainer(amp_backend='apex') + +amp_level +^^^^^^^^^ + +.. raw:: html + + + +| + +The optimization level to use (O1, O2, etc...) +for 16-bit GPU precision (using NVIDIA apex under the hood). + +Check `NVIDIA apex docs `_ for level + +Example:: + + # default used by the Trainer + trainer = Trainer(amp_level='O2') + +auto_scale_batch_size +^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Automatically tries to find the largest batch size that fits into memory, +before any training. + +.. code-block:: + + # default used by the Trainer (no scaling of batch size) + trainer = Trainer(auto_scale_batch_size=None) + + # run batch size scaling, result overrides hparams.batch_size + trainer = Trainer(auto_scale_batch_size='binsearch') + + # call tune to find the batch size + trainer.tune(model) + +auto_select_gpus +^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +If enabled and `gpus` is an integer, pick available gpus automatically. +This is especially useful when GPUs are configured to be in "exclusive mode", +such that only one process at a time can access them. + +Example:: + + # no auto selection (picks first 2 gpus on system, may fail if other process is occupying) + trainer = Trainer(gpus=2, auto_select_gpus=False) + + # enable auto selection (will find two available gpus on system) + trainer = Trainer(gpus=2, auto_select_gpus=True) + + # specifies all GPUs regardless of its availability + Trainer(gpus=-1, auto_select_gpus=False) + + # specifies all available GPUs (if only one GPU is not occupied, uses one gpu) + Trainer(gpus=-1, auto_select_gpus=True) + +auto_lr_find +^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Runs a learning rate finder algorithm (see this `paper `_) +when calling trainer.tune(), to find optimal initial learning rate. + +.. code-block:: python + + # default used by the Trainer (no learning rate finder) + trainer = Trainer(auto_lr_find=False) + +Example:: + + # run learning rate finder, results override hparams.learning_rate + trainer = Trainer(auto_lr_find=True) + + # call tune to find the lr + trainer.tune(model) + +Example:: + + # run learning rate finder, results override hparams.my_lr_arg + trainer = Trainer(auto_lr_find='my_lr_arg') + + # call tune to find the lr + trainer.tune(model) + +.. note:: + See the :doc:`learning rate finder guide <../advanced/lr_finder>`. + +benchmark +^^^^^^^^^ + +.. raw:: html + + + +| + +If true enables cudnn.benchmark. +This flag is likely to increase the speed of your system if your +input sizes don't change. However, if it does, then it will likely +make your system slower. + +The speedup comes from allowing the cudnn auto-tuner to find the best +algorithm for the hardware `[see discussion here] +`_. + +Example:: + + # default used by the Trainer + trainer = Trainer(benchmark=False) + +deterministic +^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +If true enables cudnn.deterministic. +Might make your system slower, but ensures reproducibility. +Also sets ``$HOROVOD_FUSION_THRESHOLD=0``. + +For more info check `[pytorch docs] +`_. + +Example:: + + # default used by the Trainer + trainer = Trainer(deterministic=False) + +callbacks +^^^^^^^^^ + +.. raw:: html + + + +| + +Add a list of :class:`~pytorch_lightning.callbacks.Callback`. Callbacks run sequentially in the order defined here +with the exception of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks which run +after all others to ensure all states are saved to the checkpoints. + +.. code-block:: python + + # a list of callbacks + callbacks = [PrintCallback()] + trainer = Trainer(callbacks=callbacks) + +Example:: + + from pytorch_lightning.callbacks import Callback + + class PrintCallback(Callback): + def on_train_start(self, trainer, pl_module): + print("Training is started!") + def on_train_end(self, trainer, pl_module): + print("Training is done.") + + +Model-specific callbacks can also be added inside the ``LightningModule`` through +:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`. +Callbacks returned in this hook will extend the list initially given to the ``Trainer`` argument, and replace +the trainer callbacks should there be two or more of the same type. +:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks always run last. + + +check_val_every_n_epoch +^^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Check val every n train epochs. + +Example:: + + # default used by the Trainer + trainer = Trainer(check_val_every_n_epoch=1) + + # run val loop every 10 training epochs + trainer = Trainer(check_val_every_n_epoch=10) + +checkpoint_callback +^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch, +Checkpoints capture the exact value of all parameters used by a model. +To disable automatic checkpointing, set this to `False`. + +.. code-block:: python + + # default used by Trainer + trainer = Trainer(checkpoint_callback=True) + + # turn off automatic checkpointing + trainer = Trainer(checkpoint_callback=False) + + +You can override the default behavior by initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` +callback, and adding it to the :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks` list. +See :doc:`Saving and Loading Weights <../common/weights_loading>` for how to customize checkpointing. + +.. testcode:: + + from pytorch_lightning.callbacks import ModelCheckpoint + # Init ModelCheckpoint callback, monitoring 'val_loss' + checkpoint_callback = ModelCheckpoint(monitor='val_loss') + + # Add your callback to the callbacks list + trainer = Trainer(callbacks=[checkpoint_callback]) + + +.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since + v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead. + + +default_root_dir +^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Default path for logs and weights when no logger or +:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On +certain clusters you might want to separate where logs and checkpoints are +stored. If you don't then use this argument for convenience. Paths can be local +paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials +will need to be set up to use remote filepaths. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(default_root_dir=os.getcwd()) + +distributed_backend +^^^^^^^^^^^^^^^^^^^ +Deprecated: This has been renamed ``accelerator``. + +fast_dev_run +^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test +to find any bugs (ie: a sort of unit test). + +Under the hood the pseudocode looks like this when running *fast_dev_run* with a single batch: + +.. code-block:: python + + # loading + __init__() + prepare_data + + # test training step + training_batch = next(train_dataloader) + training_step(training_batch) + + # test val step + val_batch = next(val_dataloader) + out = validation_step(val_batch) + validation_epoch_end([out]) + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(fast_dev_run=False) + + # runs 1 train, val, test batch and program ends + trainer = Trainer(fast_dev_run=True) + + # runs 7 train, val, test batches and program ends + trainer = Trainer(fast_dev_run=7) + +.. note:: + + This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will + disable tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like + ``LearningRateLogger`` and runs for only 1 epoch. This must be used only for debugging purposes. + ``limit_train/val/test_batches`` only limits the number of batches and won't disable anything. + +flush_logs_every_n_steps +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Writes logs to disk this often. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(flush_logs_every_n_steps=100) + +See Also: + - :doc:`logging <../extensions/logging>` + +gpus +^^^^ + +.. raw:: html + + + +| + +- Number of GPUs to train on (int) +- or which GPUs to train on (list) +- can handle strings + +.. testcode:: + + # default used by the Trainer (ie: train on CPU) + trainer = Trainer(gpus=None) + + # equivalent + trainer = Trainer(gpus=0) + +Example:: + + # int: train on 2 gpus + trainer = Trainer(gpus=2) + + # list: train on GPUs 1, 4 (by bus ordering) + trainer = Trainer(gpus=[1, 4]) + trainer = Trainer(gpus='1, 4') # equivalent + + # -1: train on all gpus + trainer = Trainer(gpus=-1) + trainer = Trainer(gpus='-1') # equivalent + + # combine with num_nodes to train on multiple GPUs across nodes + # uses 8 gpus in total + trainer = Trainer(gpus=2, num_nodes=4) + + # train only on GPUs 1 and 4 across nodes + trainer = Trainer(gpus=[1, 4], num_nodes=4) + +See Also: + - :doc:`Multi-GPU training guide <../advanced/multi_gpu>`. + +gradient_clip_val +^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Gradient clipping value + +- 0 means don't clip. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(gradient_clip_val=0.0) + +limit_train_batches +^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +How much of training dataset to check. +Useful when debugging or testing something that happens at the end of an epoch. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(limit_train_batches=1.0) + +Example:: + + # default used by the Trainer + trainer = Trainer(limit_train_batches=1.0) + + # run through only 25% of the training set each epoch + trainer = Trainer(limit_train_batches=0.25) + + # run through only 10 batches of the training set each epoch + trainer = Trainer(limit_train_batches=10) + +limit_test_batches +^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +How much of test dataset to check. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(limit_test_batches=1.0) + + # run through only 25% of the test set each epoch + trainer = Trainer(limit_test_batches=0.25) + + # run for only 10 batches + trainer = Trainer(limit_test_batches=10) + +In the case of multiple test dataloaders, the limit applies to each dataloader individually. + +limit_val_batches +^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +How much of validation dataset to check. +Useful when debugging or testing something that happens at the end of an epoch. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(limit_val_batches=1.0) + + # run through only 25% of the validation set each epoch + trainer = Trainer(limit_val_batches=0.25) + + # run for only 10 batches + trainer = Trainer(limit_val_batches=10) + +In the case of multiple validation dataloaders, the limit applies to each dataloader individually. + +log_every_n_steps +^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + + +How often to add logging rows (does not write to disk) + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(log_every_n_steps=50) + +See Also: + - :doc:`logging <../extensions/logging>` + +log_gpu_memory +^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Options: + +- None +- 'min_max' +- 'all' + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(log_gpu_memory=None) + + # log all the GPUs (on master node only) + trainer = Trainer(log_gpu_memory='all') + + # log only the min and max memory on the master node + trainer = Trainer(log_gpu_memory='min_max') + +.. note:: Might slow performance because it uses the output of ``nvidia-smi``. + +logger +^^^^^^ + +.. raw:: html + + + +| + +:doc:`Logger <../common/loggers>` (or iterable collection of loggers) for experiment tracking. + +.. testcode:: + + from pytorch_lightning.loggers import TensorBoardLogger + + # default logger used by trainer + logger = TensorBoardLogger( + save_dir=os.getcwd(), + version=1, + name='lightning_logs' + ) + Trainer(logger=logger) + +max_epochs +^^^^^^^^^^ + +.. raw:: html + + + +| + +Stop training once this number of epochs is reached + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(max_epochs=1000) + +min_epochs +^^^^^^^^^^ + +.. raw:: html + + + +| + +Force training for at least these many epochs + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(min_epochs=1) + +max_steps +^^^^^^^^^ + +.. raw:: html + + + +| + +Stop training after this number of steps +Training will stop if max_steps or max_epochs have reached (earliest). + +.. testcode:: + + # Default (disabled) + trainer = Trainer(max_steps=None) + + # Stop after 100 steps + trainer = Trainer(max_steps=100) + +min_steps +^^^^^^^^^ + +.. raw:: html + + + +| + +Force training for at least these number of steps. +Trainer will train model for at least min_steps or min_epochs (latest). + +.. testcode:: + + # Default (disabled) + trainer = Trainer(min_steps=None) + + # Run at least for 100 steps (disable min_epochs) + trainer = Trainer(min_steps=100, min_epochs=0) + +num_nodes +^^^^^^^^^ + +.. raw:: html + + + +| + +Number of GPU nodes for distributed training. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(num_nodes=1) + + # to train on 8 nodes + trainer = Trainer(num_nodes=8) + +num_processes +^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Number of processes to train with. Automatically set to the number of GPUs +when using ``accelerator="ddp"``. Set to a number greater than 1 when +using ``accelerator="ddp_cpu"`` to mimic distributed training on a +machine without GPUs. This is useful for debugging, but **will not** provide +any speedup, since single-process Torch already makes efficient use of multiple +CPUs. + +.. testcode:: + + # Simulate DDP for debugging on your GPU-less laptop + trainer = Trainer(accelerator="ddp_cpu", num_processes=2) + +num_sanity_val_steps +^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Sanity check runs n batches of val before starting the training routine. +This catches any bugs in your validation without having to wait for the first validation check. +The Trainer uses 2 steps by default. Turn it off or modify it here. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(num_sanity_val_steps=2) + + # turn it off + trainer = Trainer(num_sanity_val_steps=0) + + # check all validation data + trainer = Trainer(num_sanity_val_steps=-1) + + +This option will reset the validation dataloader unless ``num_sanity_val_steps=0``. + +overfit_batches +^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. +If the training dataloaders have `shuffle=True`, Lightning will automatically disable it. + +Useful for quickly debugging or trying to overfit on purpose. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(overfit_batches=0.0) + + # use only 1% of the train set (and use the train set for val and test) + trainer = Trainer(overfit_batches=0.01) + + # overfit on 10 of the same batches + trainer = Trainer(overfit_batches=10) + +plugins +^^^^^^^ + +.. raw:: html + + + +| + +Plugins allow you to connect arbitrary backends, precision libraries, SLURM, etc... For example: + +- DDP +- SLURM +- TorchElastic +- Apex + +To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own cluster. + +.. code-block:: python + + from pytorch_lightning.plugins.environments import cluster_environment + + class MyCluster(ClusterEnvironment): + + def master_address(self): + return your_master_address + + def master_port(self): + return your_master_port + + def world_size(self): + return the_world_size + + trainer = Trainer(cluster_environment=cluster_environment()) + +prepare_data_per_node +^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +If True will call `prepare_data()` on LOCAL_RANK=0 for every node. +If False will only call from NODE_RANK=0, LOCAL_RANK=0 + +.. testcode:: + + # default + Trainer(prepare_data_per_node=True) + + # use only NODE_RANK=0, LOCAL_RANK=0 + Trainer(prepare_data_per_node=False) + +precision +^^^^^^^^^ + +.. raw:: html + + + +| + +Double precision (64), full precision (32) or half precision (16). +Can be used on CPU, GPU or TPUs. + +If used on TPU will use torch.bfloat16 but tensor printing +will still show torch.float32. + +.. testcode:: + :skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE or not torch.cuda.is_available() + + # default used by the Trainer + trainer = Trainer(precision=32) + + # 16-bit precision + trainer = Trainer(precision=16, gpus=1) + + # 64-bit precision + trainer = Trainer(precision=64) + +Example:: + + # one day + trainer = Trainer(precision=8|4|2) + +process_position +^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Orders the progress bar. Useful when running multiple trainers on the same node. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(process_position=0) + +.. note:: This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + +profiler +^^^^^^^^ + +.. raw:: html + + + +| + +To profile individual steps during training and assist in identifying bottlenecks. + +See the :doc:`profiler documentation <../advanced/profiler>`. for more details. + +.. testcode:: + + from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler + + # default used by the Trainer + trainer = Trainer(profiler=None) + + # to profile standard training events, equivalent to `profiler=SimpleProfiler()` + trainer = Trainer(profiler="simple") + + # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()` + trainer = Trainer(profiler="advanced") + +progress_bar_refresh_rate +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +How often to refresh progress bar (in steps). + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(progress_bar_refresh_rate=1) + + # disable progress bar + trainer = Trainer(progress_bar_refresh_rate=0) + +Note: + - In Google Colab notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates. + Lightning will set it to 20 in these environments if the user does not provide a value. + - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + +reload_dataloaders_every_epoch +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Set to True to reload dataloaders every epoch. + +.. code-block:: python + + # if False (default) + train_loader = model.train_dataloader() + for epoch in epochs: + for batch in train_loader: + ... + + # if True + for epoch in epochs: + train_loader = model.train_dataloader() + for batch in train_loader: + +replace_sampler_ddp +^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Enables auto adding of distributed sampler. By default it will add ``shuffle=True`` +for train sampler and ``shuffle=False`` for val/test sampler. If you want to customize +it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. +If ``replace_sampler_ddp=True`` and a distributed sampler was already added, +Lightning will not replace the existing one. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(replace_sampler_ddp=True) + +By setting to False, you have to add your own distributed sampler: + +.. code-block:: python + + # default used by the Trainer + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) + dataloader = DataLoader(dataset, batch_size=32, sampler=sampler) + +resume_from_checkpoint +^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +To resume training from a specific checkpoint pass in the path here. If resuming from a mid-epoch +checkpoint, training will start from the beginning of the next epoch. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(resume_from_checkpoint=None) + + # resume from a specific checkpoint + trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') + +sync_batchnorm +^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Enable synchronization between batchnorm layers across all GPUs. + +.. testcode:: + + trainer = Trainer(sync_batchnorm=True) + +track_grad_norm +^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +- no tracking (-1) +- Otherwise tracks that norm (2 for 2-norm) + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(track_grad_norm=-1) + + # track the 2-norm + trainer = Trainer(track_grad_norm=2) + +tpu_cores +^^^^^^^^^ + +.. raw:: html + + + +| + +- How many TPU cores to train on (1 or 8). +- Which TPU core to train on [1-8] + +A single TPU v2 or v3 has 8 cores. A TPU pod has +up to 2048 cores. A slice of a POD means you get as many cores +as you request. + +Your effective batch size is batch_size * total tpu cores. + +.. note:: + No need to add a :class:`~torch.utils.data.distributed.DistributedSampler`, + Lightning automatically does it for you. + +This parameter can be either 1 or 8. + +Example:: + + # your_trainer_file.py + + # default used by the Trainer (ie: train on CPU) + trainer = Trainer(tpu_cores=None) + + # int: train on a single core + trainer = Trainer(tpu_cores=1) + + # list: train on a single selected core + trainer = Trainer(tpu_cores=[2]) + + # int: train on all cores few cores + trainer = Trainer(tpu_cores=8) + + # for 8+ cores must submit via xla script with + # a max of 8 cores specified. The XLA script + # will duplicate script onto each TPU in the POD + trainer = Trainer(tpu_cores=8) + +To train on more than 8 cores (ie: a POD), +submit this script using the xla_dist script. + +Example:: + + python -m torch_xla.distributed.xla_dist + --tpu=$TPU_POD_NAME + --conda-env=torch-xla-nightly + --env=XLA_USE_BF16=1 + -- python your_trainer_file.py + +truncated_bptt_steps +^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Truncated back prop breaks performs backprop every k steps of +a much longer sequence. + +If this is enabled, your batches will automatically get truncated +and the trainer will apply Truncated Backprop to it. + +(`Williams et al. "An efficient gradient-based algorithm for on-line training of +recurrent network trajectories." +`_) + +.. testcode:: + + # default used by the Trainer (ie: disabled) + trainer = Trainer(truncated_bptt_steps=None) + + # backprop every 5 steps in a batch + trainer = Trainer(truncated_bptt_steps=5) + +.. note:: Make sure your batches have a sequence dimension. + +Lightning takes care to split your batch along the time-dimension. + +.. code-block:: python + + # we use the second as the time dimension + # (batch, time, ...) + sub_batch = batch[0, 0:t, ...] + +Using this feature requires updating your LightningModule's +:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg +with the hidden + +.. code-block:: python + + # Truncated back-propagation through time + def training_step(self, batch, batch_idx, hiddens): + # hiddens are the hiddens from the previous truncated backprop step + out, hiddens = self.lstm(data, hiddens) + return { + "loss": ..., + "hiddens": hiddens + } + +To modify how the batch is split, +override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: + +.. testcode:: + + class LitMNIST(LightningModule): + def tbptt_split_batch(self, batch, split_size): + # do your own splitting on the batch + return splits + +val_check_interval +^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +How often within one training epoch to check the validation set. +Can specify as float or int. + +- use (float) to check within a training epoch +- use (int) to check every n steps (batches) + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(val_check_interval=1.0) + + # check validation set 4 times during a training epoch + trainer = Trainer(val_check_interval=0.25) + + # check validation set every 1000 training batches + # use this when using iterableDataset and your dataset has no length + # (ie: production cases with streaming data) + trainer = Trainer(val_check_interval=1000) + + +weights_save_path +^^^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Directory of where to save weights if specified. + +.. testcode:: + + # default used by the Trainer + trainer = Trainer(weights_save_path=os.getcwd()) + + # save to your custom path + trainer = Trainer(weights_save_path='my/path') + +Example:: + + # if checkpoint callback used, then overrides the weights path + # **NOTE: this saves weights to some/path NOT my/path + checkpoint = ModelCheckpoint(dirpath='some/path') + trainer = Trainer( + callbacks=[checkpoint], + weights_save_path='my/path' + ) + +weights_summary +^^^^^^^^^^^^^^^ + +.. raw:: html + + + +| + +Prints a summary of the weights when training begins. +Options: 'full', 'top', None. + +.. testcode:: + + # default used by the Trainer (ie: print summary of top level modules) + trainer = Trainer(weights_summary='top') + + # print full summary of all modules and submodules + trainer = Trainer(weights_summary='full') + + # don't print a summary + trainer = Trainer(weights_summary=None) + +----- + +Trainer class API +----------------- + +Methods +^^^^^^^ + +init +**** + +.. automethod:: pytorch_lightning.trainer.Trainer.__init__ + :noindex: + +fit +**** + +.. automethod:: pytorch_lightning.trainer.Trainer.fit + :noindex: + +test +**** + +.. automethod:: pytorch_lightning.trainer.Trainer.test + :noindex: + +tune +**** + +.. automethod:: pytorch_lightning.trainer.Trainer.tune + :noindex: + +Properties +^^^^^^^^^^ + +callback_metrics +**************** + +The metrics available to callbacks. These are automatically set when you log via `self.log` + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('a_val', 2) + + + callback_metrics = trainer.callback_metrics + assert callback_metrics['a_val'] == 2 + +current_epoch +************* + +The current epoch + +.. code-block:: python + + def training_step(self, batch, batch_idx): + current_epoch = self.trainer.current_epoch + if current_epoch > 100: + # do something + pass + + +logger (p) +********** + +The current logger being used. Here's an example using tensorboard + +.. code-block:: python + + def training_step(self, batch, batch_idx): + logger = self.trainer.logger + tensorboard = logger.experiment + + +logged_metrics +************** + +The metrics sent to the logger (visualizer). + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('a_val', 2, log=True) + + + logged_metrics = trainer.logged_metrics + assert logged_metrics['a_val'] == 2 + +log_dir +******* +The directory for the current experiment. Use this to save images to, etc... + +.. code-block:: python + + def training_step(self, batch, batch_idx): + img = ... + save_img(img, self.trainer.log_dir) + + + +is_global_zero +************** + +Whether this process is the global zero in multi-node training + +.. code-block:: python + + def training_step(self, batch, batch_idx): + if self.trainer.is_global_zero: + print('in node 0, accelerator 0') + +progress_bar_metrics +******************** + +The metrics sent to the progress bar. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('a_val', 2, prog_bar=True) + + + progress_bar_metrics = trainer.progress_bar_metrics + assert progress_bar_metrics['a_val'] == 2 diff --git a/docs/source/common/weights_loading.rst b/docs/source/common/weights_loading.rst new file mode 100644 index 00000000000000..f9a4cbd1323495 --- /dev/null +++ b/docs/source/common/weights_loading.rst @@ -0,0 +1,215 @@ +.. testsetup:: * + + import os + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.core.lightning import LightningModule + +.. _weights_loading: + +########################## +Saving and loading weights +########################## + +Lightning automates saving and loading checkpoints. Checkpoints capture the exact value of all parameters used by a model. + +Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model. + + +***************** +Checkpoint saving +***************** +A Lightning checkpoint has everything needed to restore a training session including: + +- 16-bit scaling factor (apex) +- Current epoch +- Global step +- Model state_dict +- State of all optimizers +- State of all learningRate schedulers +- State of all callbacks +- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace) + +Automatic saving +================ + +Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This makes sure you can resume training in case it was interrupted. + +To change the checkpoint path pass in: + +.. code-block:: python + + # saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end + trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints') + +You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss: + +1. Calculate any metric or other quantity you wish to monitor, such as validation loss. +2. Log the quantity using :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method, with a key such as `val_loss`. +3. Initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, and set `monitor` to be the key of your quantity. +4. Pass the callback to the `callbacks` :class:`~pytorch_lightning.trainer.Trainer` flag. + +.. testcode:: + + from pytorch_lightning.callbacks import ModelCheckpoint + + class LitAutoEncoder(LightningModule): + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.backbone(x) + + # 1. calculate loss + loss = F.cross_entropy(y_hat, y) + + # 2. log `val_loss` + self.log('val_loss', loss) + + # 3. Init ModelCheckpoint callback, monitoring 'val_loss' + checkpoint_callback = ModelCheckpoint(monitor='val_loss') + + # 4. Add your callback to the callbacks list + trainer = Trainer(callbacks=[checkpoint_callback]) + +You can also control more advanced options, like `save_top_k`, to save the best k models and the `mode` of the monitored quantity (min/max), `save_weights_only` or `period` to set the interval of epochs between checkpoints, to avoid slowdowns. + +.. testcode:: + + from pytorch_lightning.callbacks import ModelCheckpoint + + class LitAutoEncoder(LightningModule): + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.backbone(x) + loss = F.cross_entropy(y_hat, y) + self.log('val_loss', loss) + + # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt + checkpoint_callback = ModelCheckpoint( + monitor='val_loss', + dirpath='my/path/', + filename='sample-mnist-{epoch:02d}-{val_loss:.2f}', + save_top_k=3, + mode='min', + ) + + trainer = Trainer(callbacks=[checkpoint_callback]) + +You can retrieve the checkpoint after training by calling + +.. code-block:: python + + checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + trainer = Trainer(callbacks=[checkpoint_callback]) + trainer.fit(model) + checkpoint_callback.best_model_path + +Disabling checkpoints +--------------------- + +You can disable checkpointing by passing + +.. testcode:: + + trainer = Trainer(checkpoint_callback=False) + + +The Lightning checkpoint also saves the arguments passed into the LightningModule init +under the `hyper_parameters` key in the checkpoint. + +.. code-block:: python + + class MyLightningModule(LightningModule): + + def __init__(self, learning_rate, *args, **kwargs): + super().__init__() + self.save_hyperparameters() + + # all init args were saved to the checkpoint + checkpoint = torch.load(CKPT_PATH) + print(checkpoint['hyper_parameters']) + # {'learning_rate': the_value} + +Manual saving +============= +You can manually save checkpoints and restore your model from the checkpointed state. + +.. code-block:: python + + model = MyLightningModule(hparams) + trainer.fit(model) + trainer.save_checkpoint("example.ckpt") + new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt") + +Manual saving with accelerators +=============================== + +Lightning also handles accelerators where multiple processes are running, such as DDP. For example, when using the DDP accelerator our training script is running across multiple devices at the same time. +Lightning automatically ensures that the model is saved only on the main process, whilst other processes do not interfere with saving checkpoints. This requires no code changes as seen below. + +.. code-block:: python + + trainer = Trainer(accelerator="ddp") + model = MyLightningModule(hparams) + trainer.fit(model) + # Saves only on the main process + trainer.save_checkpoint("example.ckpt") + +Not using `trainer.save_checkpoint` can lead to unexpected behaviour and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the trainer's save functionality. +If using custom saving functions cannot be avoided, we recommend using :func:`~pytorch_lightning.loggers.base.rank_zero_only` to ensure saving occurs only on the main process. + +****************** +Checkpoint loading +****************** + +To load a model along with its weights, biases and hyperparameters use the following method: + +.. code-block:: python + + model = MyLightingModule.load_from_checkpoint(PATH) + + print(model.learning_rate) + # prints the learning_rate you used in this checkpoint + + model.eval() + y_hat = model(x) + +But if you don't want to use the values saved in the checkpoint, pass in your own here + +.. testcode:: + + class LitModel(LightningModule): + + def __init__(self, in_dim, out_dim): + super().__init__() + self.save_hyperparameters() + self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim) + +you can restore the model like this + +.. code-block:: python + + # if you train and save the model like this it will use these values when loading + # the weights. But you can overwrite this + LitModel(in_dim=32, out_dim=10) + + # uses in_dim=32, out_dim=10 + model = LitModel.load_from_checkpoint(PATH) + + # uses in_dim=128, out_dim=10 + model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10) + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.load_from_checkpoint + :noindex: + +Restoring Training State +======================== + +If you don't just want to load weights, but instead restore the full training, +do the following: + +.. code-block:: python + + model = LitModel() + trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') + + # automatically restores model, epoch, step, LR schedulers, apex, etc... + trainer.fit(model) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7ca48bd19cbf3d..1c1f3be8a636aa 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,26 +12,27 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -import os -import sys +# import m2r import glob +import os import shutil -import inspect +import sys -# import m2r -import builtins import pt_lightning_sphinx_theme -from sphinx.ext import apidoc PATH_HERE = os.path.abspath(os.path.dirname(__file__)) PATH_ROOT = os.path.join(PATH_HERE, '..', '..') sys.path.insert(0, os.path.abspath(PATH_ROOT)) -builtins.__LIGHTNING_SETUP__ = True +FOLDER_GENERATED = 'generated' +SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) -IS_READTHEDOCS_BUILD = os.environ.get('READTHEDOCS', False) - -import pytorch_lightning # noqa: E402 +try: + from pytorch_lightning import info +except ImportError: + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append(os.path.join(PATH_ROOT, "pytorch_lightning")) + import info # -- Project documents ------------------------------------------------------- @@ -51,25 +52,48 @@ # with open('readme.md', 'w') as fp: # fp.write(readme) + +def _transform_changelog(path_in: str, path_out: str) -> None: + with open(path_in, 'r') as fp: + chlog_lines = fp.readlines() + # enrich short subsub-titles to be unique + chlog_ver = '' + for i, ln in enumerate(chlog_lines): + if ln.startswith('## '): + chlog_ver = ln[2:].split('-')[0].strip() + elif ln.startswith('### '): + ln = ln.replace('###', f'### {chlog_ver} -') + chlog_lines[i] = ln + with open(path_out, 'w') as fp: + fp.writelines(chlog_lines) + + +os.makedirs(os.path.join(PATH_HERE, FOLDER_GENERATED), exist_ok=True) +# copy all documents from GH templates like contribution guide for md in glob.glob(os.path.join(PATH_ROOT, '.github', '*.md')): - shutil.copy(md, os.path.join(PATH_HERE, os.path.basename(md))) + shutil.copy(md, os.path.join(PATH_HERE, FOLDER_GENERATED, os.path.basename(md))) +# copy also the changelog +_transform_changelog( + os.path.join(PATH_ROOT, 'CHANGELOG.md'), + os.path.join(PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'), +) # -- Project information ----------------------------------------------------- -project = 'PyTorch-Lightning' -copyright = pytorch_lightning.__copyright__ -author = pytorch_lightning.__author__ +project = 'PyTorch Lightning' +copyright = info.__copyright__ +author = info.__author__ # The short X.Y version -version = pytorch_lightning.__version__ +version = info.__version__ # The full version, including alpha/beta/rc tags -release = pytorch_lightning.__version__ +release = info.__version__ # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '2.0' +needs_sphinx = '3.4' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom @@ -82,15 +106,18 @@ 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 'sphinx.ext.coverage', - 'sphinx.ext.linkcode', + 'sphinx.ext.viewcode', 'sphinx.ext.autosummary', 'sphinx.ext.napoleon', + 'sphinx.ext.imgmath', 'recommonmark', 'sphinx.ext.autosectionlabel', # 'm2r', - 'nbsphinx', + # 'nbsphinx', # it seems some sphinx issue 'sphinx_autodoc_typehints', + 'sphinx_copybutton', 'sphinx_paramlinks', + 'sphinx_togglebutton', ] # Add any paths that contain templates here, relative to this directory. @@ -130,18 +157,7 @@ # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [ - 'api/pytorch_lightning.rst', - 'api/pl_examples.*', - 'api/modules.rst', - - # deprecated/renamed: - 'api/pytorch_lightning.loggers.comet_logger.rst', # TODO: remove in v0.8.0 - 'api/pytorch_lightning.loggers.mlflow_logger.rst', # TODO: remove in v0.8.0 - 'api/pytorch_lightning.loggers.test_tube_logger.rst', # TODO: remove in v0.8.0 - 'api/pytorch_lightning.callbacks.pt_callbacks.*', # TODO: remove in v0.8.0 - 'api/pytorch_lightning.pt_overrides.*', # TODO: remove in v0.8.0 - 'api/pytorch_lightning.root_module.*', # TODO: remove in v0.8.0 - 'api/pytorch_lightning.logging.*', # TODO: remove in v0.8.0 + f'{FOLDER_GENERATED}/PULL_REQUEST_TEMPLATE.md', ] # The name of the Pygments (syntax highlighting) style to use. @@ -162,19 +178,21 @@ # documentation. html_theme_options = { - 'pytorch_project': pytorch_lightning.__homepage__, - 'canonical_url': pytorch_lightning.__homepage__, + 'pytorch_project': 'https://pytorchlightning.ai', + 'canonical_url': info.__docs_url__, 'collapse_navigation': False, 'display_version': True, 'logo_only': False, } -html_logo = '_images/logos/lightning_logo-name.svg' +html_logo = '_static/images/logo.svg' + +html_favicon = '_static/images/icon.svg' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_images', '_templates'] +html_static_path = ['_templates', '_static'] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -186,7 +204,6 @@ # # html_sidebars = {} - # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. @@ -219,9 +236,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, project, project + ' Documentation', [author], 1) -] +man_pages = [(master_doc, project, project + ' Documentation', [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -229,8 +244,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, project, project + ' Documentation', author, project, - 'One line description of project.', 'Miscellaneous'), + ( + master_doc, + project, + project + ' Documentation', + author, + project, + 'One line description of project.', + 'Miscellaneous', + ), ] # -- Options for Epub output ------------------------------------------------- @@ -257,8 +279,9 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/3', None), 'torch': ('https://pytorch.org/docs/stable/', None), - 'numpy': ('https://docs.scipy.org/doc/numpy/', None), + 'numpy': ('https://numpy.org/doc/stable/', None), 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), + 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/stable/', None), } # -- Options for todo extension ---------------------------------------------- @@ -267,131 +290,72 @@ todo_include_todos = True -# packages for which sphinx-apidoc should generate the docs (.rst files) -PACKAGES = [ - pytorch_lightning.__name__, - 'pl_examples', -] - -apidoc_output_folder = os.path.join(PATH_HERE, 'api') - - -def run_apidoc(_): - sys.path.insert(0, apidoc_output_folder) - - # delete api-doc files before generating them - if os.path.exists(apidoc_output_folder): - shutil.rmtree(apidoc_output_folder) - - for pkg in PACKAGES: - argv = ['-e', - '-o', apidoc_output_folder, - os.path.join(PATH_ROOT, pkg), - '**/test_*', - '--force', - '--private', - '--module-first'] - - apidoc.main(argv) - - def setup(app): - app.connect('builder-inited', run_apidoc) + # this is for hiding doctest decoration, + # see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/ + app.add_js_file('copybutton.js') + app.add_css_file('main.css') # copy all notebooks to local folder -path_nbs = os.path.join(PATH_HERE, 'notebooks') -if not os.path.isdir(path_nbs): - os.mkdir(path_nbs) -for path_ipynb in glob.glob(os.path.join(PATH_ROOT, 'notebooks', '*.ipynb')): - path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb)) - shutil.copy(path_ipynb, path_ipynb2) +# path_nbs = os.path.join(PATH_HERE, 'notebooks') +# if not os.path.isdir(path_nbs): +# os.mkdir(path_nbs) +# for path_ipynb in glob.glob(os.path.join(PATH_ROOT, 'notebooks', '*.ipynb')): +# path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb)) +# shutil.copy(path_ipynb, path_ipynb2) # Ignoring Third-party packages # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule def package_list_from_file(file): + """List up package name (not containing version and extras) from a package list file + """ mocked_packages = [] with open(file, 'r') as fp: for ln in fp.readlines(): - found = [ln.index(ch) for ch in list(',=<>#') if ch in ln] + # Example: `tqdm>=4.41.0` => `tqdm` + # `[` is for package with extras + found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln] pkg = ln[:min(found)] if found else ln if pkg.rstrip(): mocked_packages.append(pkg.rstrip()) return mocked_packages -MOCK_PACKAGES = package_list_from_file(os.path.join(PATH_ROOT, 'requirements-extra.txt')) -if IS_READTHEDOCS_BUILD: +# define mapping from PyPI names to python imports +PACKAGE_MAPPING = { + 'Pillow': 'PIL', + 'opencv-python': 'cv2', + 'PyYAML': 'yaml', + 'comet-ml': 'comet_ml', + 'neptune-client': 'neptune', + 'hydra-core': 'hydra', + 'pyDeprecate': 'deprecate', +} +MOCK_PACKAGES = [] +if SPHINX_MOCK_REQUIREMENTS: + MOCK_PACKAGES += ['fairscale'] # mock also base packages when we are on RTD since we don't install them there - base_packages = package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) - MOCK_PACKAGES.extend(base_packages) + MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) + MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt')) + MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'loggers.txt')) +MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES] -MOCK_MANUAL_PACKAGES = [ - 'torchvision', - 'PIL', -] -autodoc_mock_imports = MOCK_PACKAGES + MOCK_MANUAL_PACKAGES - - -# Options for the linkcode extension -# ---------------------------------- -github_user = 'PyTorchLightning' -github_repo = project - - -# Resolve function -# This function is used to populate the (source) links in the API -def linkcode_resolve(domain, info): - def find_source(): - # try to find the file and line number, based on code from numpy: - # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286 - obj = sys.modules[info['module']] - for part in info['fullname'].split('.'): - obj = getattr(obj, part) - fname = inspect.getsourcefile(obj) - # https://github.com/rtfd/readthedocs.org/issues/5735 - if any([s in fname for s in ('readthedocs', 'rtfd', 'checkouts')]): - # /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/ - # devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176 - path_top = os.path.abspath(os.path.join('..', '..', '..')) - fname = os.path.relpath(fname, start=path_top) - else: - # Local build, imitate master - fname = 'master/' + os.path.relpath(fname, start=os.path.abspath('..')) - source, lineno = inspect.getsourcelines(obj) - return fname, lineno, lineno + len(source) - 1 - - if domain != 'py' or not info['module']: - return None - try: - filename = '%s#L%d-L%d' % find_source() - except Exception: - filename = info['module'].replace('.', '/') + '.py' - # import subprocess - # tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE, - # universal_newlines=True).communicate()[0][:-1] - branch = filename.split('/')[0] - # do mapping from latest tags to master - branch = {'latest': 'master', 'stable': 'master'}.get(branch, branch) - filename = '/'.join([branch] + filename.split('/')[1:]) - return "https://github.com/%s/%s/blob/%s" \ - % (github_user, github_repo, filename) +autodoc_mock_imports = MOCK_PACKAGES +autosummary_generate = True autodoc_member_order = 'groupwise' + autoclass_content = 'both' -# the options are fixed and will be soon in release, -# see https://github.com/sphinx-doc/sphinx/issues/5459 + autodoc_default_options = { - 'members': None, - 'methods': None, - # 'attributes': None, + 'members': True, + 'methods': True, 'special-members': '__call__', 'exclude-members': '_abc_impl', 'show-inheritance': True, - 'private-members': True, - 'noindex': True, } # Sphinx will add “permalinks” for each heading and description environment as paragraph signs that @@ -410,12 +374,21 @@ def find_source(): # only run doctests marked with a ".. doctest::" directive doctest_test_doctest_blocks = '' doctest_global_setup = """ - import importlib import os +from typing import Optional import torch - -TORCHVISION_AVAILABLE = importlib.util.find_spec('torchvision') - +from torch import nn +import pytorch_lightning as pl +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.utilities import ( + _NATIVE_AMP_AVAILABLE, + _APEX_AVAILABLE, + _XLA_AVAILABLE, + _TPU_AVAILABLE, + _TORCHVISION_AVAILABLE, + _module_available, +) +TORCHVISION_AVAILABLE = _module_available("torchvision") """ coverage_skip_undoc_in_source = True diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst deleted file mode 100644 index 412b6d613ecc6d..00000000000000 --- a/docs/source/debugging.rst +++ /dev/null @@ -1,82 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - -Debugging -========= -The following are flags that make debugging much easier. - -Fast dev run ------------- -This flag runs a "unit test" by running 1 training batch and 1 validation batch. -The point is to detect any bugs in the training/validation loop without having to wait for -a full epoch to crash. - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) - -.. testcode:: - - trainer = Trainer(fast_dev_run=True) - -Inspect gradient norms ----------------------- -Logs (to a logger), the norm of each weight matrix. - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.track_grad_norm` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) - -.. testcode:: - - # the 2-norm - trainer = Trainer(track_grad_norm=2) - -Log GPU usage -------------- -Logs (to a logger) the GPU usage for each GPU on the master machine. - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.log_gpu_memory` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) - -.. testcode:: - - trainer = Trainer(log_gpu_memory=True) - -Make model overfit on subset of data ------------------------------------- - -A good debugging technique is to take a tiny portion of your data (say 2 samples per class), -and try to get your model to overfit. If it can't, it's a sign it won't work with large datasets. - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_pct` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) - -.. testcode:: - - trainer = Trainer(overfit_pct=0.01) - -Print the parameter count by layer ----------------------------------- -Whenever the .fit() function gets called, the Trainer will print the weights summary for the lightningModule. -To disable this behavior, turn off this flag: - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) - -.. testcode:: - - trainer = Trainer(weights_summary=None) - - -Set the number of validation sanity steps ------------------------------------------ -Lightning runs a few steps of validation in the beginning of training. -This avoids crashing in the validation loop sometime deep into a lengthy training loop. - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.num_sanity_val_steps` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) - -.. testcode:: - - # DEFAULT - trainer = Trainer(num_sanity_val_steps=5) \ No newline at end of file diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst deleted file mode 100644 index a0bfc83ec27d9c..00000000000000 --- a/docs/source/early_stopping.rst +++ /dev/null @@ -1,66 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.callbacks.early_stopping import EarlyStopping - - -Early stopping -============== - -Stopping an epoch early ------------------------ -You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.on_batch_start` to return `-1` when some condition is met. - -If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire run. - -Default Epoch End Callback Behavior ------------------------------------ -By default early stopping will be enabled if `'val_loss'` -is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s -return dict. Otherwise training will proceed with early stopping disabled. - -Enable Early Stopping using Callbacks on epoch end --------------------------------------------------- -There are two ways to enable early stopping using callbacks on epoch end. - -- Set early_stop_callback to True. Will look for 'val_loss' in validation_epoch_end() return dict. - If it is not found an error is raised. - - .. testcode:: - - trainer = Trainer(early_stop_callback=True) - -- Or configure your own callback - - .. testcode:: - - early_stop_callback = EarlyStopping( - monitor='val_loss', - min_delta=0.00, - patience=3, - verbose=False, - mode='min' - ) - trainer = Trainer(early_stop_callback=early_stop_callback) - -In any case, the callback will fall back to the training metrics (returned in -:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`, -:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`) -looking for a key to monitor if validation is disabled or -:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` -is not defined. - -.. seealso:: - - :class:`~pytorch_lightning.trainer.trainer.Trainer` - - :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` - -Disable Early Stopping with callbacks on epoch end --------------------------------------------------- -To disable early stopping pass ``False`` to the -:paramref:`~pytorch_lightning.trainer.trainer.Trainer.early_stop_callback`. -Note that ``None`` will not disable early stopping but will lead to the -default behaviour. - -.. seealso:: - - :class:`~pytorch_lightning.trainer.trainer.Trainer` - - :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` diff --git a/docs/source/ecosystem/asr_nlp_tts.rst b/docs/source/ecosystem/asr_nlp_tts.rst new file mode 100644 index 00000000000000..af9a7084583f2f --- /dev/null +++ b/docs/source/ecosystem/asr_nlp_tts.rst @@ -0,0 +1,802 @@ +################# +Conversational AI +################# + +These are amazing ecosystems to help with Automatic Speech Recognition (ASR), Natural Language Processing (NLP), and Text to speech (TTS). + +---- + +**** +NeMo +**** + +`NVIDIA NeMo `_ is a toolkit for building new State-of-the-Art +Conversational AI models. NeMo has separate collections for Automatic Speech Recognition (ASR), +Natural Language Processing (NLP), and Text-to-Speech (TTS) models. Each collection consists of +prebuilt modules that include everything needed to train on your data. +Every module can easily be customized, extended, and composed to create new Conversational AI +model architectures. + +Conversational AI architectures are typically very large and require a lot of data and compute +for training. NeMo uses PyTorch Lightning for easy and performant multi-GPU/multi-node +mixed-precision training. + +.. note:: Every NeMo model is a LightningModule that comes equipped with all supporting infrastructure for training and reproducibility. + +---------- + +NeMo Models +=========== + +NeMo Models contain everything needed to train and reproduce state of the art Conversational AI +research and applications, including: + +- neural network architectures +- datasets/data loaders +- data preprocessing/postprocessing +- data augmentors +- optimizers and schedulers +- tokenizers +- language models + +NeMo uses `Hydra `_ for configuring both NeMo models and the PyTorch Lightning Trainer. +Depending on the domain and application, many different AI libraries will have to be configured +to build the application. Hydra makes it easy to bring all of these libraries together +so that each can be configured from .yaml or the Hydra CLI. + +.. note:: Every NeMo model has an example configuration file and a corresponding script that contains all configurations needed for training. + +The end result of using NeMo, Pytorch Lightning, and Hydra is that +NeMo models all have the same look and feel. This makes it easy to do Conversational AI research +across multiple domains. NeMo models are also fully compatible with the PyTorch ecosystem. + +Installing NeMo +--------------- + +Before installing NeMo, please install Cython first. + +.. code-block:: bash + + pip install Cython + +For ASR and TTS models, also install these linux utilities. + +.. code-block:: bash + + apt-get update && apt-get install -y libsndfile1 ffmpeg + +Then installing the latest NeMo release is a simple pip install. + +.. code-block:: bash + + pip install nemo_toolkit[all]==1.0.0b1 + +To install the main branch from GitHub: + +.. code-block:: bash + + python -m pip install git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[all] + +To install from a local clone of NeMo: + +.. code-block:: bash + + ./reinstall.sh # from cloned NeMo's git root + +For Docker users, the NeMo container is available on +`NGC `_. + +.. code-block:: bash + + docker pull nvcr.io/nvidia/nemo:v1.0.0b1 + +.. code-block:: bash + + docker run --runtime=nvidia -it --rm -v --shm-size=8g -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/nemo:v1.0.0b1 + +Experiment Manager +------------------ + +NeMo's Experiment Manager leverages PyTorch Lightning for model checkpointing, +TensorBoard Logging, and Weights and Biases logging. The Experiment Manager is included by default +in all NeMo example scripts. + +.. code-block:: python + + exp_manager(trainer, cfg.get("exp_manager", None)) + +And is configurable via .yaml with Hydra. + +.. code-block:: bash + + exp_manager: + exp_dir: null + name: *name + create_tensorboard_logger: True + create_checkpoint_callback: True + +Optionally launch Tensorboard to view training results in ./nemo_experiments (by default). + +.. code-block:: bash + + tensorboard --bind_all --logdir nemo_experiments + +-------- + +Automatic Speech Recognition (ASR) +================================== + +Everything needed to train Convolutional ASR models is included with NeMo. +NeMo supports multiple Speech Recognition architectures, including Jasper and QuartzNet. +`NeMo Speech Models `_ +can be trained from scratch on custom datasets or +fine-tuned using pre-trained checkpoints trained on thousands of hours of audio +that can be restored for immediate use. + +Some typical ASR tasks are included with NeMo: + +- `Audio transcription `_ +- `Byte Pair/Word Piece Training `_ +- `Speech Commands `_ +- `Voice Activity Detection `_ +- `Speaker Recognition `_ + +See this `asr notebook `_ +for a full tutorial on doing ASR with NeMo, PyTorch Lightning, and Hydra. + +Specify ASR Model Configurations with YAML File +----------------------------------------------- + +NeMo Models and the PyTorch Lightning Trainer can be fully configured from .yaml files using Hydra. + +See this `asr config `_ +for the entire speech to text .yaml file. + +.. code-block:: yaml + + # configure the PyTorch Lightning Trainer + trainer: + gpus: 0 # number of gpus + max_epochs: 5 + max_steps: null # computed at runtime if not set + num_nodes: 1 + distributed_backend: ddp + ... + # configure the ASR model + model: + ... + encoder: + cls: nemo.collections.asr.modules.ConvASREncoder + params: + feat_in: *n_mels + activation: relu + conv_mask: true + + jasper: + - filters: 128 + repeat: 1 + kernel: [11] + stride: [1] + dilation: [1] + dropout: *dropout + ... + # all other configuration, data, optimizer, preprocessor, etc + ... + +Developing ASR Model From Scratch +--------------------------------- + +`speech_to_text.py `_ + +.. code-block:: python + + # hydra_runner calls hydra.main and is useful for multi-node experiments + @hydra_runner(config_path="conf", config_name="config") + def main(cfg): + trainer = Trainer(**cfg.trainer) + asr_model = EncDecCTCModel(cfg.model, trainer) + trainer.fit(asr_model) + + +Hydra makes every aspect of the NeMo model, +including the PyTorch Lightning Trainer, customizable from the command line. + +.. code-block:: bash + + python NeMo/examples/asr/speech_to_text.py --config-name=quartznet_15x5 \ + trainer.gpus=4 \ + trainer.max_epochs=128 \ + +trainer.precision=16 \ + model.train_ds.manifest_filepath=/librispeech-train-all.json \ + model.validation_ds.manifest_filepath=/librispeech-dev-other.json \ + model.train_ds.batch_size=64 \ + +model.validation_ds.num_workers=16 \ + +model.train_ds.num_workers=16 + +.. note:: Training NeMo ASR models can take days/weeks so it is highly recommended to use multiple GPUs and multiple nodes with the PyTorch Lightning Trainer. + + +Using State-Of-The-Art Pre-trained ASR Model +-------------------------------------------- + +Transcribe audio with QuartzNet model pretrained on ~3300 hours of audio. + +.. code-block:: python + + quartznet = EncDecCTCModel.from_pretrained('QuartzNet15x5Base-En') + + files = ['path/to/my.wav'] # file duration should be less than 25 seconds + + for fname, transcription in zip(files, quartznet.transcribe(paths2audio_files=files)): + print(f"Audio in {fname} was recognized as: {transcription}") + +To see the available pretrained checkpoints: + +.. code-block:: python + + EncDecCTCModel.list_available_models() + +NeMo ASR Model Under the Hood +----------------------------- + +Any aspect of ASR training or model architecture design can easily be customized +with PyTorch Lightning since every NeMo model is a Lightning Module. + +.. code-block:: python + + class EncDecCTCModel(ASRModel): + """Base class for encoder decoder CTC-based models.""" + ... + @typecheck() + def forward(self, input_signal, input_signal_length): + processed_signal, processed_signal_len = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal) + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + log_probs = self.decoder(encoder_output=encoded) + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + return log_probs, encoded_len, greedy_predictions + + # PTL-specific methods + def training_step(self, batch, batch_nb): + audio_signal, audio_signal_len, transcript, transcript_len = batch + log_probs, encoded_len, predictions = self.forward( + input_signal=audio_signal, input_signal_length=audio_signal_len + ) + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + wer_num, wer_denom = self._wer(predictions, transcript, transcript_len) + self.log_dict({ + 'train_loss': loss_value, + 'training_batch_wer': wer_num / wer_denom, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + }) + return loss_value + +Neural Types in NeMo ASR +------------------------ + +NeMo Models and Neural Modules come with Neural Type checking. +Neural type checking is extremely useful when combining many different neural +network architectures for a production-grade application. + +.. code-block:: python + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), audio_eltype), + "input_signal_length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "outputs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "greedy_predictions": NeuralType(('B', 'T'), LabelsType()), + } + +-------- + +Natural Language Processing (NLP) +================================= + +Everything needed to finetune BERT-like language models for NLP tasks is included with NeMo. +`NeMo NLP Models `_ +include `HuggingFace Transformers `_ +and `NVIDIA Megatron-LM `_ BERT and Bio-Megatron models. +NeMo can also be used for pretraining BERT-based language models from HuggingFace. + +Any of the HuggingFace encoders or Megatron-LM encoders can easily be used for the NLP tasks +that are included with NeMo: + +- `Glue Benchmark (All tasks) `_ +- `Intent Slot Classification `_ +- `Language Modeling (BERT Pretraining) `_ +- `Question Answering `_ +- `Text Classification `_ (including Sentiment Analysis) +- `Token Classification `_ (including Named Entity Recognition) +- `Punctuation and Capitalization `_ + +Named Entity Recognition (NER) +------------------------------ + +NER (or more generally token classification) is the NLP task of detecting and classifying key information (entities) in text. +This task is very popular in Healthcare and Finance. In finance, for example, it can be important to identify +geographical, geopolitical, organizational, persons, events, and natural phenomenon entities. +See this `NER notebook `_ +for a full tutorial on doing NER with NeMo, PyTorch Lightning, and Hydra. + +Specify NER Model Configurations with YAML File +----------------------------------------------- + +.. note:: NeMo Models and the PyTorch Lightning Trainer can be fully configured from .yaml files using Hydra. + +See this `token classification config `_ +for the entire NER (token classification) .yaml file. + +.. code-block:: yaml + + # configure any argument of the PyTorch Lightning Trainer + trainer: + gpus: 1 # the number of gpus, 0 for CPU + num_nodes: 1 + max_epochs: 5 + ... + # configure any aspect of the token classification model here + model: + dataset: + data_dir: ??? # /path/to/data + class_balancing: null # choose from [null, weighted_loss]. Weighted_loss enables the weighted class balancing of the loss, may be used for handling unbalanced classes + max_seq_length: 128 + ... + tokenizer: + tokenizer_name: ${model.language_model.pretrained_model_name} # or sentencepiece + vocab_file: null # path to vocab file + ... + # the language model can be from HuggingFace or Megatron-LM + language_model: + pretrained_model_name: bert-base-uncased + lm_checkpoint: null + ... + # the classifier for the downstream task + head: + num_fc_layers: 2 + fc_dropout: 0.5 + activation: 'relu' + ... + # all other configuration: train/val/test/ data, optimizer, experiment manager, etc + ... + +Developing NER Model From Scratch +--------------------------------- + +`token_classification.py `_ + +.. code-block:: python + + # hydra_runner calls hydra.main and is useful for multi-node experiments + @hydra_runner(config_path="conf", config_name="token_classification_config") + def main(cfg: DictConfig) -> None: + trainer = pl.Trainer(**cfg.trainer) + model = TokenClassificationModel(cfg.model, trainer=trainer) + trainer.fit(model) + +After training, we can do inference with the saved NER model using PyTorch Lightning. + +Inference from file: + +.. code-block:: python + + gpu = 1 if cfg.trainer.gpus != 0 else 0 + trainer = pl.Trainer(gpus=gpu) + model.set_trainer(trainer) + model.evaluate_from_file( + text_file=os.path.join(cfg.model.dataset.data_dir, cfg.model.validation_ds.text_file), + labels_file=os.path.join(cfg.model.dataset.data_dir, cfg.model.validation_ds.labels_file), + output_dir=exp_dir, + add_confusion_matrix=True, + normalize_confusion_matrix=True, + ) + +Or we can run inference on a few examples: + +.. code-block:: python + + queries = ['we bought four shirts from the nvidia gear store in santa clara.', 'Nvidia is a company in Santa Clara.'] + results = model.add_predictions(queries) + + for query, result in zip(queries, results): + logging.info(f'Query : {query}') + logging.info(f'Result: {result.strip()}\n') + +Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trainer, customizable from the command line. + +.. code-block:: bash + + python token_classification.py \ + model.language_model.pretrained_model_name=bert-base-cased \ + model.head.num_fc_layers=2 \ + model.dataset.data_dir=/path/to/my/data \ + trainer.max_epochs=5 \ + trainer.gpus=[0,1] + +----------- + +Tokenizers +---------- + +Tokenization is the process of converting natural language text into integer arrays +which can be used for machine learning. +For NLP tasks, tokenization is an essential part of data preprocessing. +NeMo supports all BERT-like model tokenizers from +`HuggingFace's AutoTokenizer `_ +and also supports `Google's SentencePieceTokenizer `_ +which can be trained on custom data. + +To see the list of supported tokenizers: + +.. code-block:: python + + from nemo.collections import nlp as nemo_nlp + + nemo_nlp.modules.get_tokenizer_list() + +See this `tokenizer notebook `_ +for a full tutorial on using tokenizers in NeMo. + +Language Models +--------------- + +Language models are used to extract information from (tokenized) text. +Much of the state-of-the-art in natural language processing is achieved +by fine-tuning pretrained language models on the downstream task. + +With NeMo, you can either `pretrain `_ +a BERT model on your data or use a pretrained language model from `HuggingFace Transformers `_ +or `NVIDIA Megatron-LM `_. + +To see the list of language models available in NeMo: + +.. code-block:: python + + nemo_nlp.modules.get_pretrained_lm_models_list(include_external=True) + +Easily switch between any language model in the above list by using `.get_lm_model`. + +.. code-block:: python + + nemo_nlp.modules.get_lm_model(pretrained_model_name='distilbert-base-uncased') + +See this `language model notebook `_ +for a full tutorial on using pretrained language models in NeMo. + +Using a Pre-trained NER Model +----------------------------- + +NeMo has pre-trained NER models that can be used +to get started with Token Classification right away. +Models are automatically downloaded from NGC, +cached locally to disk, +and loaded into GPU memory using the `.from_pretrained` method. + +.. code-block:: python + + # load pre-trained NER model + pretrained_ner_model = TokenClassificationModel.from_pretrained(model_name="NERModel") + + # define the list of queries for inference + queries = [ + 'we bought four shirts from the nvidia gear store in santa clara.', + 'Nvidia is a company.', + 'The Adventures of Tom Sawyer by Mark Twain is an 1876 novel about a young boy growing ' + + 'up along the Mississippi River.', + ] + results = pretrained_ner_model.add_predictions(queries) + + for query, result in zip(queries, results): + print() + print(f'Query : {query}') + print(f'Result: {result.strip()}\n') + +NeMo NER Model Under the Hood +----------------------------- + +Any aspect of NLP training or model architecture design can easily be customized with PyTorch Lightning +since every NeMo model is a Lightning Module. + +.. code-block:: python + + class TokenClassificationModel(ModelPT): + """ + Token Classification Model with BERT, applicable for tasks such as Named Entity Recognition + """ + ... + @typecheck() + def forward(self, input_ids, token_type_ids, attention_mask): + hidden_states = self.bert_model( + input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask + ) + logits = self.classifier(hidden_states=hidden_states) + return logits + + # PTL-specfic methods + def training_step(self, batch, batch_idx): + """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ + input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch + logits = self(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask) + + loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask) + self.log_dict({'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']}) + return loss + ... + +Neural Types in NeMo NLP +------------------------ + +NeMo Models and Neural Modules come with Neural Type checking. +Neural type checking is extremely useful when combining many different neural network architectures +for a production-grade application. + +.. code-block:: python + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + return self.bert_model.input_types + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return self.classifier.output_types + +-------- + +Text-To-Speech (TTS) +==================== + +Everything needed to train TTS models and generate audio is included with NeMo. +`NeMo TTS Models `_ +can be trained from scratch on your own data or pretrained models can be downloaded +automatically. NeMo currently supports a two step inference procedure. +First, a model is used to generate a mel spectrogram from text. +Second, a model is used to generate audio from a mel spectrogram. + +Mel Spectrogram Generators: + +- `Tacotron 2 `_ +- `Glow-TTS `_ + +Audio Generators: + +- Griffin-Lim +- `WaveGlow `_ +- `SqueezeWave `_ + + +Specify TTS Model Configurations with YAML File +----------------------------------------------- + +.. note:: NeMo Models and PyTorch Lightning Trainer can be fully configured from .yaml files using Hydra. + +`tts/conf/glow_tts.yaml `_ + +.. code-block:: yaml + + # configure the PyTorch Lightning Trainer + trainer: + gpus: -1 # number of gpus + max_epochs: 350 + num_nodes: 1 + distributed_backend: ddp + ... + + # configure the TTS model + model: + ... + encoder: + cls: nemo.collections.tts.modules.glow_tts.TextEncoder + params: + n_vocab: 148 + out_channels: *n_mels + hidden_channels: 192 + filter_channels: 768 + filter_channels_dp: 256 + ... + # all other configuration, data, optimizer, parser, preprocessor, etc + ... + +Developing TTS Model From Scratch +--------------------------------- + +`tts/glow_tts.py `_ + +.. code-block:: python + + # hydra_runner calls hydra.main and is useful for multi-node experiments + @hydra_runner(config_path="conf", config_name="glow_tts") + def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + model = GlowTTSModel(cfg=cfg.model, trainer=trainer) + trainer.fit(model) + +Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trainer, customizable from the command line. + +.. code-block:: bash + + python NeMo/examples/tts/glow_tts.py \ + trainer.gpus=4 \ + trainer.max_epochs=400 \ + ... + train_dataset=/path/to/train/data \ + validation_datasets=/path/to/val/data \ + model.train_ds.batch_size = 64 \ + +.. note:: Training NeMo TTS models from scratch can take days or weeks so it is highly recommended to use multiple GPUs and multiple nodes with the PyTorch Lightning Trainer. + +Using State-Of-The-Art Pre-trained TTS Model +-------------------------------------------- + +Generate speech using models trained on `LJSpeech `, +around 24 hours of single speaker data. + +See this `TTS notebook `_ +for a full tutorial on generating speech with NeMo, PyTorch Lightning, and Hydra. + +.. code-block:: python + + # load pretrained spectrogram model + spec_gen = SpecModel.from_pretrained('GlowTTS-22050Hz').cuda() + + # load pretrained Generators + vocoder = WaveGlowModel.from_pretrained('WaveGlow-22050Hz').cuda() + + def infer(spec_gen_model, vocder_model, str_input): + with torch.no_grad(): + parsed = spec_gen.parse(text_to_generate) + spectrogram = spec_gen.generate_spectrogram(tokens=parsed) + audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram) + if isinstance(spectrogram, torch.Tensor): + spectrogram = spectrogram.to('cpu').numpy() + if len(spectrogram.shape) == 3: + spectrogram = spectrogram[0] + if isinstance(audio, torch.Tensor): + audio = audio.to('cpu').numpy() + return spectrogram, audio + + text_to_generate = input("Input what you want the model to say: ") + spec, audio = infer(spec_gen, vocoder, text_to_generate) + +To see the available pretrained checkpoints: + +.. code-block:: python + + # spec generator + GlowTTSModel.list_available_models() + + # vocoder + WaveGlowModel.list_available_models() + +NeMo TTS Model Under the Hood +----------------------------- + +Any aspect of TTS training or model architecture design can easily +be customized with PyTorch Lightning since every NeMo model is a LightningModule. + +`glow_tts.py `_ + +.. code-block:: python + + class GlowTTSModel(SpectrogramGenerator): + """ + GlowTTS model used to generate spectrograms from text + Consists of a text encoder and an invertible spectrogram decoder + """ + ... + # NeMo models come with neural type checking + @typecheck( + input_types={ + "x": NeuralType(('B', 'T'), TokenIndex()), + "x_lengths": NeuralType(('B'), LengthsType()), + "y": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True), + "y_lengths": NeuralType(('B'), LengthsType(), optional=True), + "gen": NeuralType(optional=True), + "noise_scale": NeuralType(optional=True), + "length_scale": NeuralType(optional=True), + } + ) + def forward(self, *, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=0.3, length_scale=1.0): + if gen: + return self.glow_tts.generate_spect( + text=x, text_lengths=x_lengths, noise_scale=noise_scale, length_scale=length_scale + ) + else: + return self.glow_tts(text=x, text_lengths=x_lengths, spect=y, spect_lengths=y_lengths) + ... + def step(self, y, y_lengths, x, x_lengths): + z, y_m, y_logs, logdet, logw, logw_, y_lengths, attn = self( + x=x, x_lengths=x_lengths, y=y, y_lengths=y_lengths, gen=False + ) + + l_mle, l_length, logdet = self.loss( + z=z, + y_m=y_m, + y_logs=y_logs, + logdet=logdet, + logw=logw, + logw_=logw_, + x_lengths=x_lengths, + y_lengths=y_lengths, + ) + + loss = sum([l_mle, l_length]) + + return l_mle, l_length, logdet, loss, attn + + # PTL-specfic methods + def training_step(self, batch, batch_idx): + y, y_lengths, x, x_lengths = batch + + y, y_lengths = self.preprocessor(input_signal=y, length=y_lengths) + + l_mle, l_length, logdet, loss, _ = self.step(y, y_lengths, x, x_lengths) + + output = { + "loss": loss, # required + "progress_bar": {"l_mle": l_mle, "l_length": l_length, "logdet": logdet}, + "log": {"loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet}, + } + + return output + ... + +Neural Types in NeMo TTS +------------------------ + +NeMo Models and Neural Modules come with Neural Type checking. +Neural type checking is extremely useful when combining many different neural network architectures +for a production-grade application. + +.. code-block:: python + + @typecheck( + input_types={ + "x": NeuralType(('B', 'T'), TokenIndex()), + "x_lengths": NeuralType(('B'), LengthsType()), + "y": NeuralType(('B', 'D', 'T'), MelSpectrogramType(), optional=True), + "y_lengths": NeuralType(('B'), LengthsType(), optional=True), + "gen": NeuralType(optional=True), + "noise_scale": NeuralType(optional=True), + "length_scale": NeuralType(optional=True), + } + ) + def forward(self, *, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=0.3, length_scale=1.0): + ... + +-------- + +Learn More +========== + +- Watch the `NVIDIA NeMo Intro Video `_ +- Watch the `PyTorch Lightning and NVIDIA NeMo Discussion Video `_ +- Visit the `NVIDIA NeMo Developer Website `_ +- Read the `NVIDIA NeMo PyTorch Blog `_ +- Download pre-trained `ASR `_, `NLP `_, and `TTS `_ models on `NVIDIA NGC `_ to quickly get started with NeMo. +- Become an expert on Building Conversational AI applications with our `tutorials `_, and `example scripts `_, +- See our `developer guide `_ for more information on core NeMo concepts, ASR/NLP/TTS collections, and the NeMo API. + +.. note:: NeMo tutorial notebooks can be run on `Google Colab `_. + +NVIDIA `NeMo `_ is actively being developed on GitHub. +`Contributions `_ are welcome! diff --git a/docs/source/ecosystem/bolts.rst b/docs/source/ecosystem/bolts.rst new file mode 100644 index 00000000000000..f3a4ab9c858be1 --- /dev/null +++ b/docs/source/ecosystem/bolts.rst @@ -0,0 +1,89 @@ +Bolts +===== +`PyTorch Lightning Bolts `_, is our official collection +of prebuilt models across many research domains. + +.. code-block:: bash + + pip install pytorch-lightning-bolts + +In bolts we have: + +- A collection of pretrained state-of-the-art models. +- A collection of models designed to bootstrap your research. +- A collection of callbacks, transforms, full datasets. +- All models work on CPUs, TPUs, GPUs and 16-bit precision. + +----------------- + +Quality control +--------------- +The Lightning community builds bolts and contributes them to Bolts. +The lightning team guarantees that contributions are: + +- Rigorously Tested (CPUs, GPUs, TPUs). +- Rigorously Documented. +- Standardized via PyTorch Lightning. +- Optimized for speed. +- Checked for correctness. + +--------- + +Example 1: Pretrained, prebuilt models +-------------------------------------- + +.. code-block:: python + + from pl_bolts.models import VAE, GPT2, ImageGPT, PixelCNN + from pl_bolts.models.self_supervised import AMDIM, CPCV2, SimCLR, MocoV2 + from pl_bolts.models import LinearRegression, LogisticRegression + from pl_bolts.models.gans import GAN + from pl_bolts.callbacks import PrintTableMetricsCallback + from pl_bolts.datamodules import FashionMNISTDataModule, CIFAR10DataModule, ImagenetDataModule + +------------ + +Example 2: Extend for faster research +------------------------------------- +Bolts are contributed with benchmarks and continuous-integration tests. This means +you can trust the implementations and use them to bootstrap your research much faster. + +.. code-block:: python + + from pl_bolts.models import ImageGPT + from pl_bolts.self_supervised import SimCLR + + class VideoGPT(ImageGPT): + + def training_step(self, batch, batch_idx): + x, y = batch + x = _shape_input(x) + + logits = self.gpt(x) + simclr_features = self.simclr(x) + + # ----------------- + # do something new with GPT logits + simclr_features + # ----------------- + + loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long()) + + self.log("loss", loss) + return loss + +---------- + +Example 3: Callbacks +-------------------- +We also have a collection of callbacks. + +.. code-block:: python + + from pl_bolts.callbacks import PrintTableMetricsCallback + import pytorch_lightning as pl + + trainer = pl.Trainer(callbacks=[PrintTableMetricsCallback()]) + + # loss│train_loss│val_loss│epoch + # ────────────────────────────── + # 2.2541470527648926│2.2541470527648926│2.2158432006835938│0 diff --git a/docs/source/ecosystem/community_examples.rst b/docs/source/ecosystem/community_examples.rst new file mode 100644 index 00000000000000..9b89b9bb4564fb --- /dev/null +++ b/docs/source/ecosystem/community_examples.rst @@ -0,0 +1,22 @@ +Community Examples +================== + +- `Contextual Emotion Detection (DoubleDistilBert) `_. +- `Cotatron: Transcription-Guided Speech Encoder `_. +- `FasterRCNN object detection + Hydra `_. +- `Image Inpainting using Partial Convolutions `_. +- `MNIST on TPU `_. +- `NER (transformers, TPU) `_. +- `NeuralTexture (CVPR) `_. +- `Recurrent Attentive Neural Process `_. +- `Siamese Nets for One-shot Image Recognition `_. +- `Speech Transformers `_. +- `Transformers transfer learning (Huggingface) `_. +- `Transformers text classification `_. +- `VAE Library of over 18+ VAE flavors `_. +- `Transformers Question Answering (SQuAD) `_. +- `Atlas: End-to-End 3D Scene Reconstruction from Posed Images `_. +- `Self-Supervised Representation Learning (MoCo and BYOL) `_. +- `Pytorch-Forecasting: Time series forecasting package `_. +- `Transformers masked language modeling `_. +- `Pytorch Geometric Examples with Pytorch Lightning and Hydra `_. diff --git a/docs/source/ecosystem/pytorch_ecoystem.rst b/docs/source/ecosystem/pytorch_ecoystem.rst new file mode 100644 index 00000000000000..1be7bb53f8b81c --- /dev/null +++ b/docs/source/ecosystem/pytorch_ecoystem.rst @@ -0,0 +1,4 @@ +Pytorch Ecosystem Examples +========================== + +- `Pytorch Geometric: Deep learning on Graphs and other irregular structures `_. diff --git a/docs/source/experiment_reporting.rst b/docs/source/experiment_reporting.rst deleted file mode 100644 index 8e534f4cc6d265..00000000000000 --- a/docs/source/experiment_reporting.rst +++ /dev/null @@ -1,127 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - - -Experiment Reporting -===================== - -Lightning supports many different experiment loggers. These loggers allow you to monitor losses, images, text, etc... -as training progresses. They usually provide a GUI to visualize and can sometimes even snapshot hyperparameters -used in each experiment. - - -Control logging frequency -^^^^^^^^^^^^^^^^^^^^^^^^^ - -It may slow training down to log every single batch. Trainer has an option to log every k batches instead. - -.. testcode:: - - k = 10 - trainer = Trainer(row_log_interval=k) - -Control log writing frequency -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Writing to a logger can be expensive. In Lightning you can set the interval at which you -want to log using this trainer flag. - -.. seealso:: - :class:`~pytorch_lightning.trainer.trainer.Trainer` - -.. testcode:: - - k = 100 - trainer = Trainer(log_save_interval=k) - -Log metrics -^^^^^^^^^^^ - -To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, TRAINS, etc...) - -1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict. - -.. testcode:: - - def training_epoch_end(self, outputs): - loss = some_loss() - ... - - logs = {'train_loss': loss} - results = {'log': logs} - return results - - def validation_epoch_end(self, outputs): - loss = some_loss() - ... - - logs = {'val_loss': loss} - results = {'log': logs} - return results - - def test_epoch_end(self, outputs): - loss = some_loss() - ... - - logs = {'test_loss': loss} - results = {'log': logs} - return results - -2. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule. -For instance, here we log images using tensorboard. - -.. testcode:: - :skipif: not TORCHVISION_AVAILABLE - - def training_step(self, batch, batch_idx): - self.generated_imgs = self.decoder.generate() - - sample_imgs = self.generated_imgs[:6] - grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image('generated_images', grid, 0) - - ... - return results - -Modify progress bar -^^^^^^^^^^^^^^^^^^^ - -Each return dict from the training_end, validation_end, testing_end and training_step also has -a key called "progress_bar". - -Here we show the validation loss in the progress bar - -.. testcode:: - - def validation_epoch_end(self, outputs): - loss = some_loss() - ... - - logs = {'val_loss': loss} - results = {'progress_bar': logs} - return results - -Snapshot hyperparameters -^^^^^^^^^^^^^^^^^^^^^^^^ -When training a model, it's useful to know what hyperparams went into that model. -When Lightning creates a checkpoint, it stores a key "hparams" with the hyperparams. - -.. code-block:: python - - lightning_checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage) - hyperparams = lightning_checkpoint['hparams'] - -Some loggers also allow logging the hyperparams used in the experiment. For instance, -when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show -in the `hparams tab `_. - -Snapshot code -^^^^^^^^^^^^^ -Loggers also allow you to snapshot a copy of the code used in this experiment. -For example, TestTubeLogger does this with a flag: - -.. testcode:: - - from pytorch_lightning.loggers import TestTubeLogger - logger = TestTubeLogger('.', create_git_tag=True) diff --git a/docs/source/extensions/accelerators.rst b/docs/source/extensions/accelerators.rst new file mode 100644 index 00000000000000..f88dc3f2992d6e --- /dev/null +++ b/docs/source/extensions/accelerators.rst @@ -0,0 +1,10 @@ +############ +Accelerators +############ +Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc). Accelerators +also manage distributed accelerators (like DP, DDP, HPC cluster). + +Accelerators can also be configured to run on arbitrary clusters using Plugins or to link up to arbitrary +computational strategies like 16-bit precision via AMP and Apex. + +**For help setting up custom plugin/accelerator please reach out to us at support@pytorchlightning.ai** diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst new file mode 100644 index 00000000000000..73691c6dd76f5c --- /dev/null +++ b/docs/source/extensions/callbacks.rst @@ -0,0 +1,363 @@ +.. testsetup:: * + + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning.callbacks.base import Callback + +.. role:: hidden + :class: hidden-section + +.. _callbacks: + +Callback +======== + +.. raw:: html + + + +| + +A callback is a self-contained program that can be reused across projects. + +Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL +logic that is NOT required for your :doc:`lightning module <../common/lightning_module>` to run. + +Here's the flow of how the callback hooks are executed: + +.. raw:: html + + + +An overall Lightning system should have: + +1. Trainer for all engineering +2. LightningModule for all research code. +3. Callbacks for non-essential code. + +| + +Example: + +.. testcode:: + + from pytorch_lightning.callbacks import Callback + + class MyPrintingCallback(Callback): + + def on_init_start(self, trainer): + print('Starting to init trainer!') + + def on_init_end(self, trainer): + print('trainer is init now') + + def on_train_end(self, trainer, pl_module): + print('do something when training ends') + + trainer = Trainer(callbacks=[MyPrintingCallback()]) + +.. testoutput:: + + Starting to init trainer! + trainer is init now + +We successfully extended functionality without polluting our super clean +:doc:`lightning module <../common/lightning_module>` research code. + +----------- + +Examples +-------- +You can do pretty much anything with callbacks. + +- `Add a MLP to fine-tune self-supervised networks `_. +- `Find how to modify an image input to trick the classification result `_. +- `Interpolate the latent space of any variational model `_. +- `Log images to Tensorboard for any model `_. + + +-------------- + +Built-in Callbacks +------------------ +Lightning has a few built-in callbacks. + +.. note:: + For a richer collection of callbacks, check out our + `bolts library `_. + +.. currentmodule:: pytorch_lightning.callbacks + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + BackboneFinetuning + BaseFinetuning + Callback + EarlyStopping + GPUStatsMonitor + GradientAccumulationScheduler + LambdaCallback + LearningRateMonitor + ModelCheckpoint + ModelPruning + ProgressBar + ProgressBarBase + QuantizationAwareTraining + StochasticWeightAveraging + +---------- + +Persisting State +---------------- + +Some callbacks require internal state in order to function properly. You can optionally +choose to persist your callback's state as part of model checkpoint files using the callback hooks +:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`. +However, you must follow two constraints: + +1. Your returned state must be able to be pickled. +2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class. + + +Best Practices +-------------- +The following are best practices when using/designing callbacks. + +1. Callbacks should be isolated in their functionality. +2. Your callback should not rely on the behavior of other callbacks in order to work properly. +3. Do not manually call methods from the callback. +4. Directly calling methods (eg. `on_validation_end`) is strongly discouraged. +5. Whenever possible, your callbacks should not depend on the order in which they are executed. + +----------- + +.. _hooks: + +Available Callback hooks +------------------------ + +setup +^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.setup + :noindex: + +teardown +^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.teardown + :noindex: + +on_init_start +^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_init_start + :noindex: + +on_init_end +^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_init_end + :noindex: + +on_fit_start +^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint + :noindex: + +on_fit_end +^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_fit_end + :noindex: + +on_sanity_check_start +^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_start + :noindex: + +on_sanity_check_end +^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_sanity_check_end + :noindex: + +on_train_batch_start +^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_start + :noindex: + +on_train_batch_end +^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_train_batch_end + :noindex: + +on_train_epoch_start +^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_start + :noindex: + +on_train_epoch_end +^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_train_epoch_end + :noindex: + +on_validation_epoch_start +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_start + :noindex: + +on_validation_epoch_end +^^^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_epoch_end + :noindex: + +on_test_epoch_start +^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_start + :noindex: + +on_test_epoch_end +^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_test_epoch_end + :noindex: + +on_epoch_start +^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start + :noindex: + +on_epoch_end +^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end + :noindex: + +on_batch_start +^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_start + :noindex: + +on_validation_batch_start +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_start + :noindex: + +on_validation_batch_end +^^^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_batch_end + :noindex: + +on_test_batch_start +^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_start + :noindex: + +on_test_batch_end +^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_test_batch_end + :noindex: + +on_batch_end +^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end + :noindex: + +on_train_start +^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_train_start + :noindex: + +on_train_end +^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_train_end + :noindex: + +on_pretrain_routine_start +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_start + :noindex: + +on_pretrain_routine_end +^^^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_pretrain_routine_end + :noindex: + +on_validation_start +^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_start + :noindex: + +on_validation_end +^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_validation_end + :noindex: + +on_test_start +^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_test_start + :noindex: + +on_test_end +^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_test_end + :noindex: + +on_keyboard_interrupt +^^^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_keyboard_interrupt + :noindex: + +on_save_checkpoint +^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint + :noindex: + +on_load_checkpoint +^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint + :noindex: + +on_after_backward +^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward + :noindex: + +on_before_zero_grad +^^^^^^^^^^^^^^^^^^^ + +.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad + :noindex: diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst new file mode 100644 index 00000000000000..881febe21316dc --- /dev/null +++ b/docs/source/extensions/datamodules.rst @@ -0,0 +1,431 @@ +.. _datamodules: + +LightningDataModule +=================== +A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data: + +.. raw:: html + + + +| + +A datamodule encapsulates the five steps involved in data processing in PyTorch: + +1. Download / tokenize / process. +2. Clean and (maybe) save to disk. +3. Load inside :class:`~torch.utils.data.Dataset`. +4. Apply transforms (rotate, tokenize, etc...). +5. Wrap inside a :class:`~torch.utils.data.DataLoader`. + +| + +This class can then be shared and used anywhere: + +.. code-block:: python + + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule + + model = LitClassifier() + trainer = Trainer() + + imagenet = ImagenetDataModule() + trainer.fit(model, imagenet) + + cifar10 = CIFAR10DataModule() + trainer.fit(model, cifar10) + +--------------- + +Why do I need a DataModule? +--------------------------- +In normal PyTorch code, the data cleaning/preparation is usually scattered across many files. This makes +sharing and reusing the exact splits and transforms across projects impossible. + +Datamodules are for you if you ever asked the questions: + +- what splits did you use? +- what transforms did you use? +- what normalization did you use? +- how did you prepare/tokenize the data? + +-------------- + +What is a DataModule +-------------------- +A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the +matching transforms and data processing/downloads steps required. + +Here's a simple PyTorch example: + +.. code-block:: python + + # regular PyTorch + test_data = MNIST(my_path, train=False, download=True) + train_data = MNIST(my_path, train=True, download=True) + train_data, val_data = random_split(train_data, [55000, 5000]) + + train_loader = DataLoader(train_data, batch_size=32) + val_loader = DataLoader(val_data, batch_size=32) + test_loader = DataLoader(test_data, batch_size=32) + +The equivalent DataModule just organizes the same exact code, but makes it reusable across projects. + +.. code-block:: python + + class MNISTDataModule(pl.LightningDataModule): + + def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32): + super().__init__() + self.data_dir = data_dir + self.batch_size = batch_size + + def setup(self, stage: Optional[str] = None): + self.mnist_test = MNIST(self.data_dir, train=False) + mnist_full = MNIST(self.data_dir, train=True) + self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=self.batch_size) + + def teardown(self, stage: Optional[str] = None): + # Used to clean-up when the run is finished + ... + +But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can +let Lightning handle those details for you while making this dataset reusable so you can share with +colleagues or use in different projects. + +.. code-block:: python + + mnist = MNISTDataModule(my_path) + model = LitClassifier() + + trainer = Trainer() + trainer.fit(model, mnist) + +Here's a more realistic, complex DataModule that shows how much more reusable the datamodule is. + +.. code-block:: python + + import pytorch_lightning as pl + from torch.utils.data import random_split, DataLoader + + # Note - you must have torchvision installed for this example + from torchvision.datasets import MNIST + from torchvision import transforms + + + class MNISTDataModule(pl.LightningDataModule): + + def __init__(self, data_dir: str = './'): + super().__init__() + self.data_dir = data_dir + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + # self.dims is returned when you call dm.size() + # Setting default dims here because we know them. + # Could optionally be assigned dynamically in dm.setup() + self.dims = (1, 28, 28) + + def prepare_data(self): + # download + MNIST(self.data_dir, train=True, download=True) + MNIST(self.data_dir, train=False, download=True) + + def setup(self, stage: Optional[str] = None): + + # Assign train/val datasets for use in dataloaders + if stage == 'fit' or stage is None: + mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) + self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + + # Optionally... + # self.dims = tuple(self.mnist_train[0][0].shape) + + # Assign test dataset for use in dataloader(s) + if stage == 'test' or stage is None: + self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) + + # Optionally... + # self.dims = tuple(self.mnist_test[0][0].shape) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=32) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=32) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=32) + + +.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``. + + +--------------- + +LightningDataModule API +----------------------- +To define a DataModule define 5 methods: + +- prepare_data (how to download(), tokenize, etc...) +- setup (how to split, etc...) +- train_dataloader +- val_dataloader(s) +- test_dataloader(s) + + +prepare_data +^^^^^^^^^^^^ +Use this method to do things that might write to disk or that need to be done only from a single process in distributed +settings. + +- download +- tokenize +- etc... + +.. code-block:: python + + class MNISTDataModule(pl.LightningDataModule): + def prepare_data(self): + # download + MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) + MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) + + +.. warning:: ``prepare_data`` is called from a single process (e.g. GPU 0). Do not use it to assign state (`self.x = y`). + + +setup +^^^^^ +There are also data operations you might want to perform on every GPU. Use setup to do things like: + +- count number of classes +- build vocabulary +- perform train/val/test splits +- apply transforms (defined explicitly in your datamodule or assigned in init) +- etc... + +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + + def setup(self, stage: Optional[str] = None): + + # Assign Train/val split(s) for use in Dataloaders + if stage == 'fit' or stage is None: + mnist_full = MNIST( + self.data_dir, + train=True, + download=True, + transform=self.transform + ) + self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + self.dims = self.mnist_train[0][0].shape + + # Assign Test split(s) for use in Dataloaders + if stage == 'test' or stage is None: + self.mnist_test = MNIST( + self.data_dir, + train=False, + download=True, + transform=self.transform + ) + self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) + + +.. warning:: ``setup`` is called from every process. Setting state here is okay. + + +.. note:: ``teardown`` can be used to clean up the state. It is also called from every process + + +train_dataloader +^^^^^^^^^^^^^^^^ +Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in ``setup``. + +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=64) + + +val_dataloader +^^^^^^^^^^^^^^ +Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in ``setup``. + +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=64) + + +.. _datamodule-test-dataloader-label: + +test_dataloader +^^^^^^^^^^^^^^^ +Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``. + +.. code-block:: python + + import pytorch_lightning as pl + + + class MNISTDataModule(pl.LightningDataModule): + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=64) + + +transfer_batch_to_device +^^^^^^^^^^^^^^^^^^^^^^^^ +Override to define how you want to move an arbitrary batch to a device. + +.. testcode:: + + class MNISTDataModule(LightningDataModule): + def transfer_batch_to_device(self, batch, device): + x = batch['x'] + x = CustomDataWrapper(x) + batch['x'] = x.to(device) + return batch + + +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). + + +on_before_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply augmentations to your batch before it is transferred to the device. + +.. testcode:: + + class MNISTDataModule(LightningDataModule): + def on_before_batch_transfer(self, batch, dataloader_idx): + batch['x'] = transforms(batch['x']) + return batch + + +.. warning:: + Currently dataloader_idx always returns 0 and will be updated to support the true idx in the future. + +.. note:: This hook only runs on single GPU training and DDP (no data-parallel). + + +on_after_batch_transfer +^^^^^^^^^^^^^^^^^^^^^^^ +Override to alter or apply augmentations to your batch after it is transferred to the device. + +.. testcode:: + + class MNISTDataModule(LightningDataModule): + def on_after_batch_transfer(self, batch, dataloader_idx): + batch['x'] = gpu_transforms(batch['x']) + return batch + + +.. warning:: + + Currently ``dataloader_idx`` always returns 0 and will be updated to support the true ``idx`` in the future. + +.. note:: + This hook only runs on single GPU training and DDP (no data-parallel). This hook + will also be called when using CPU device, so adding augmentations here or in + ``on_before_batch_transfer`` means the same thing. + + + +.. note:: To decouple your data from transforms you can parametrize them via ``__init__``. + +.. code-block:: python + + class MNISTDataModule(pl.LightningDataModule): + def __init__(self, train_transforms, val_transforms, test_transforms): + super().__init__() + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.test_transforms = test_transforms + + +------------------ + +Using a DataModule +------------------ + +The recommended way to use a DataModule is simply: + +.. code-block:: python + + dm = MNISTDataModule() + model = Model() + trainer.fit(model, dm) + + trainer.test(datamodule=dm) + +If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning +still ensures the method runs on the correct devices) + +.. code-block:: python + + dm = MNISTDataModule() + dm.prepare_data() + dm.setup(stage='fit') + + model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab) + trainer.fit(model, dm) + + dm.setup(stage='test') + trainer.test(datamodule=dm) + +---------------- + +Datamodules without Lightning +----------------------------- +You can of course use DataModules in plain PyTorch code as well. + +.. code-block:: python + + # download, etc... + dm = MNISTDataModule() + dm.prepare_data() + + # splits/transforms + dm.setup(stage='fit') + + # use data + for batch in dm.train_dataloader(): + ... + for batch in dm.val_dataloader(): + ... + + dm.teardown(stage='fit') + + # lazy load test data + dm.setup(stage='test') + for batch in dm.test_dataloader(): + ... + + dm.teardown(stage='test') + +But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified +structure. diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst new file mode 100644 index 00000000000000..8d782a9c478c92 --- /dev/null +++ b/docs/source/extensions/logging.rst @@ -0,0 +1,361 @@ +.. testsetup:: * + + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.trainer.trainer import Trainer + from pytorch_lightning import loggers as pl_loggers + +.. role:: hidden + :class: hidden-section + +.. _logging: + + +####### +Logging +####### + +Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...). +To use a logger, simply pass it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`. +Lightning uses TensorBoard by default. + +.. testcode:: + + from pytorch_lightning import loggers as pl_loggers + + tb_logger = pl_loggers.TensorBoardLogger('logs/') + trainer = Trainer(logger=tb_logger) + +Choose from any of the others such as MLflow, Comet, Neptune, WandB, ... + +.. testcode:: + + comet_logger = pl_loggers.CometLogger(save_dir='logs/') + trainer = Trainer(logger=comet_logger) + +To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers ... + +.. testcode:: + + tb_logger = pl_loggers.TensorBoardLogger('logs/') + comet_logger = pl_loggers.CometLogger(save_dir='logs/') + trainer = Trainer(logger=[tb_logger, comet_logger]) + +.. note:: + + By default, lightning logs every 50 steps. Use Trainer flags to :ref:`logging_frequency`. + +.. note:: + + All loggers log by default to `os.getcwd()`. To change the path without creating a logger set + `Trainer(default_root_dir='/your/path/to/save/checkpoints')` + +---------- + +****************************** +Logging from a LightningModule +****************************** + +Lightning offers automatic log functionalities for logging scalars, or manual logging for anything else. + +Automatic Logging +================= +Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` +method to log from anywhere in a :doc:`lightning module <../common/lightning_module>` and :doc:`callbacks <../extensions/callbacks>` +except functions with `batch_start` in their names. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_metric', x) + +Depending on where log is called from, Lightning auto-determines the correct logging mode for you. \ +But of course you can override the default behavior by manually setting the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` parameters. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + +The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a few options: + +* `on_step`: Logs the metric at the current step. Defaults to `True` in :func:`~~pytorch_lightning.core.lightning.LightningModule.training_step`, and :func:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`. + +* `on_epoch`: Automatically accumulates and logs at the end of the epoch. Defaults to True anywhere in validation or test loops, and in :func:`~~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`. + +* `prog_bar`: Logs to the progress bar. + +* `logger`: Logs to the logger like Tensorboard, or any other custom logger passed to the :class:`~pytorch_lightning.trainer.trainer.Trainer`. + + +.. note:: + + - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a + reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. + + - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with + suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor` + argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice. + + +If your work requires to log in an unsupported function, please open an issue with a clear description of why it is blocking you. + + +Manual logging +============== +If you want to log anything that is not a scalar, like histograms, text, images, etc... you may need to use the logger object directly. + +.. code-block:: python + + def training_step(...): + ... + # the logger you used (in this case tensorboard) + tensorboard = self.logger.experiment + tensorboard.add_image() + tensorboard.add_histogram(...) + tensorboard.add_figure(...) + + +Access your logs +================ +Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs: + +.. code-block:: bash + + tensorboard --logdir ./lightning_logs + +---------- + +******************** +Make a custom logger +******************** + +You can implement your own logger by writing a class that inherits from :class:`~pytorch_lightning.loggers.base.LightningLoggerBase`. +Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively. + +.. testcode:: + + from pytorch_lightning.utilities import rank_zero_only + from pytorch_lightning.loggers import LightningLoggerBase + from pytorch_lightning.loggers.base import rank_zero_experiment + + class MyLogger(LightningLoggerBase): + + @property + def name(self): + return 'MyLogger' + + @property + @rank_zero_experiment + def experiment(self): + # Return the experiment object associated with this logger. + pass + + @property + def version(self): + # Return the experiment version, int or str. + return '0.1' + + @rank_zero_only + def log_hyperparams(self, params): + # params is an argparse.Namespace + # your code to record hyperparameters goes here + pass + + @rank_zero_only + def log_metrics(self, metrics, step): + # metrics is a dictionary of metric names and values + # your code to record metrics goes here + pass + + @rank_zero_only + def save(self): + # Optional. Any code necessary to save logger data goes here + # If you implement this, remember to call `super().save()` + # at the start of the method (important for aggregation of metrics) + super().save() + + @rank_zero_only + def finalize(self, status): + # Optional. Any code that needs to be run after training + # finishes goes here + pass + +If you write a logger that may be useful to others, please send +a pull request to add it to Lightning! + +---------- + +.. _logging_frequency: + + +************************* +Control logging frequency +************************* + +Logging frequency +================= + +It may slow training down to log every single batch. By default, Lightning logs every 50 rows, or 50 training steps. +To change this behaviour, set the `log_every_n_steps` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag. + +.. testcode:: + + k = 10 + trainer = Trainer(log_every_n_steps=k) + + + +Log writing frequency +===================== + +Writing to a logger can be expensive, so by default Lightning write logs to disc or to the given logger every 100 training steps. +To change this behaviour, set the interval at which you wish to flush logs to the filesystem using `log_every_n_steps` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag. + +.. testcode:: + + k = 100 + trainer = Trainer(flush_logs_every_n_steps=k) + +Unlike the `log_every_n_steps`, this argument does not apply to all loggers. +The example shown here works with :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`, +which is the default logger in Lightning. + +---------- + +************ +Progress Bar +************ +You can add any metric to the progress bar using :func:`~~pytorch_lightning.core.lightning.LightningModule.log` +method, setting `prog_bar=True`. + + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_loss', loss, prog_bar=True) + + +Modifying the progress bar +========================== + +The progress bar by default already includes the training loss and version number of the experiment +if you are using a logger. These defaults can be customized by overriding the +:func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module. + +.. code-block:: python + + def get_progress_bar_dict(self): + # don't show the version number + items = super().get_progress_bar_dict() + items.pop("v_num", None) + return items + + +---------- + + +************************* +Configure console logging +************************* + +Lightning logs useful information about the training process and user warnings to the console. +You can retrieve the Lightning logger and change it to your liking. For example, adjust the logging level +or redirect output for certain modules to log files: + +.. testcode:: + + import logging + + # configure logging at the root level of lightning + logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) + + # configure logging on module level, redirect to file + logger = logging.getLogger("pytorch_lightning.core") + logger.addHandler(logging.FileHandler("core.log")) + +Read more about custom Python logging `here `_. + + +---------- + +*********************** +Logging hyperparameters +*********************** + +When training a model, it's useful to know what hyperparams went into that model. +When Lightning creates a checkpoint, it stores a key "hyper_parameters" with the hyperparams. + +.. code-block:: python + + lightning_checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage) + hyperparams = lightning_checkpoint['hyper_parameters'] + +Some loggers also allow logging the hyperparams used in the experiment. For instance, +when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show +in the `hparams tab `_. + +.. note:: + If you want to track a metric in the tensorboard hparams tab, log scalars to the key ``hp_metric``. If tracking multiple metrics, initialize ``TensorBoardLogger`` with ``default_hp_metric=False`` and call ``log_hyperparams`` only once with your metric keys and initial values. Subsequent updates can simply be logged to the metric keys. Refer to the following for examples on how to setup proper hyperparams metrics tracking within :doc:`LightningModule <../common/lightning_module>`. + + .. code-block:: python + + # Using default_hp_metric + def validation_step(self, batch, batch_idx): + self.log("hp_metric", some_scalar) + + # Using custom or multiple metrics (default_hp_metric=False) + def on_train_start(self): + self.logger.log_hyperparams(self.hparams, {"hp/metric_1": 0, "hp/metric_2": 0}) + + def validation_step(self, batch, batch_idx): + self.log("hp/metric_1", some_scalar_1) + self.log("hp/metric_2", some_scalar_2) + + In the example, using `hp/` as a prefix allows for the metrics to be grouped under "hp" in the tensorboard scalar tab where you can collapse them. + +---------- + +************* +Snapshot code +************* + +Loggers also allow you to snapshot a copy of the code used in this experiment. +For example, TestTubeLogger does this with a flag: + +.. code-block:: python + + from pytorch_lightning.loggers import TestTubeLogger + logger = TestTubeLogger('.', create_git_tag=True) + +---------- + +***************** +Supported Loggers +***************** + +The following are loggers we support + +.. note:: + The following loggers will normally plot an additional chart (**global_step VS epoch**). + +.. note:: + postfix ``_step`` and ``_epoch`` will be appended to the name you logged + if ``on_step`` and ``on_epoch`` are set to ``True`` in ``self.log()``. + +.. note:: + Depending on the loggers you use, there might be some additional charts. + +.. currentmodule:: pytorch_lightning.loggers + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + CometLogger + CSVLogger + MLFlowLogger + NeptuneLogger + TensorBoardLogger + TestTubeLogger + WandbLogger diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst new file mode 100644 index 00000000000000..74a4a15deb2be2 --- /dev/null +++ b/docs/source/extensions/metrics.rst @@ -0,0 +1,9 @@ +####### +Metrics +####### + +``pytorch_lightning.metrics`` has been moved to a separate package `TorchMetrics `_. +We will preserve compatibility for the next few releases, nevertheless, we encourage users to update to use this stand-alone package. + +.. warning:: + ``pytorch_lightning.metrics`` is deprecated from v1.3 and will be removed in v1.5. diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst new file mode 100644 index 00000000000000..7f2c904e6c59c0 --- /dev/null +++ b/docs/source/extensions/plugins.rst @@ -0,0 +1,7 @@ +####### +Plugins +####### + +Plugins allow custom integrations to the internals of the Trainer such as a custom amp or ddp implementation. + +**For help setting up custom plugin/accelerator please reach out to us at support@pytorchlightning.ai** diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 82909d14c4d740..873e7ea0fd486b 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -1,4 +1,6 @@ -Pytorch Lightning Governance | Persons of interest +.. _governance: + +PyTorch Lightning Governance | Persons of interest ================================================== Leads @@ -16,3 +18,12 @@ Core Maintainers - Jeremy Jordan (`jeremyjordan `_) - Tullie Murrell (`tullie `_) - Adrian Wälchli (`awaelchli `_) +- Nicki Skafte (`skaftenicki `_) +- Peter Yu (`yukw777 `_) +- Rohit Gupta (`rohitgr7 `_) +- Jeff Yang (`ydcjeff `_) +- Roger Shieh (`s-rog `_) +- Carlos Mocholí (`carmocca `_) +- Ananth Subramaniam (`ananthsub `_) +- Thomas Chaton (`tchaton `_) +- Sean Narenthiran (`SeanNaren `_) diff --git a/docs/source/hooks.rst b/docs/source/hooks.rst deleted file mode 100644 index 18bfb028d44067..00000000000000 --- a/docs/source/hooks.rst +++ /dev/null @@ -1,77 +0,0 @@ -Model Hooks -=========== - -There are cases when you might want to do something different at different parts of the training/validation loop. -To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time. - -**Contributing** If there's a hook you'd like to add, simply: - -1. Fork `PyTorchLightning `_. - -2. Add the hook to :class:`pytorch_lightning.core.hooks.ModelHooks`. - -3. Add it in the correct place in :mod:`pytorch_lightning.trainer` where it should be called. - - -Hooks lifecycle ---------------- - -Training set-up -^^^^^^^^^^^^^^^ - -- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection` -- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_dataloader` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.val_dataloader` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.summarize` -- :meth:`~pytorch_lightning.trainer.training_io.TrainerIOMixin.restore_weights` - -Training loop -^^^^^^^^^^^^^ - -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_start` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_start` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.tbptt_split_batch` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end` (optional) -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.backward` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_after_backward` -- ``optimizer.step()`` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_batch_end` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_epoch_end` - -Validation loop -^^^^^^^^^^^^^^^ - -- ``model.zero_grad()`` -- ``model.eval()`` -- ``torch.set_grad_enabled(False)`` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step_end` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` -- ``model.train()`` -- ``torch.set_grad_enabled(True)`` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check` - -Test loop -^^^^^^^^^ - -- ``model.zero_grad()`` -- ``model.eval()`` -- ``torch.set_grad_enabled(False)`` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step_end` -- :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end` -- ``model.train()`` -- ``torch.set_grad_enabled(True)`` -- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check` - - - -.. automodule:: pytorch_lightning.core.hooks - :noindex: \ No newline at end of file diff --git a/docs/source/hyperparameters.rst b/docs/source/hyperparameters.rst deleted file mode 100644 index 5b2dd343fb6225..00000000000000 --- a/docs/source/hyperparameters.rst +++ /dev/null @@ -1,249 +0,0 @@ -.. testsetup:: * - - import torch - from argparse import ArgumentParser, Namespace - from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.core.lightning import LightningModule - import sys - sys.argv = ['foo'] - - -Hyperparameters ---------------- -Lightning has utilities to interact seamlessly with the command line ArgumentParser -and plays well with the hyperparameter optimization framework of your choice. - -ArgumentParser -^^^^^^^^^^^^^^ -Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser - -.. testcode:: - - from argparse import ArgumentParser - parser = ArgumentParser() - parser.add_argument('--layer_1_dim', type=int, default=128) - args = parser.parse_args() - -This allows you to call your program like so: - -.. code-block:: bash - - python trainer.py --layer_1_dim 64 - - -Argparser Best Practices -^^^^^^^^^^^^^^^^^^^^^^^^ -It is best practice to layer your arguments in three sections. - -1. Trainer args (gpus, num_nodes, etc...) -2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...) -3. Program arguments (data_path, cluster_email, etc...) - -We can do this as follows. First, in your LightningModule, define the arguments -specific to that module. Remember that data splits or data paths may also be specific to -a module (ie: if your project has a model that trains on Imagenet and another on CIFAR-10). - -.. testcode:: - - class LitModel(LightningModule): - - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument('--encoder_layers', type=int, default=12) - parser.add_argument('--data_path', type=str, default='/some/path') - return parser - -Now in your main trainer file, add the Trainer args, the program args, and add the model args - -.. testcode:: - - # ---------------- - # trainer_main.py - # ---------------- - from argparse import ArgumentParser - parser = ArgumentParser() - - # add PROGRAM level args - parser.add_argument('--conda_env', type=str, default='some_name') - parser.add_argument('--notification_email', type=str, default='will@email.com') - - # add model specific args - parser = LitModel.add_model_specific_args(parser) - - # add all the available trainer options to argparse - # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli - parser = Trainer.add_argparse_args(parser) - - hparams = parser.parse_args() - -Now you can call run your program like so - -.. code-block:: bash - - python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12 - -Finally, make sure to start the training like so: - -.. code-block:: python - - # YES - model = LitModel(hparams) - trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...) - - # NO - # model = LitModel(learning_rate=hparams.learning_rate, ...) - # trainer = Trainer(gpus=hparams.gpus, ...) - -LightningModule hparams -^^^^^^^^^^^^^^^^^^^^^^^ - -Normally, we don't hard-code the values to a model. We usually use the command line to -modify the network and read those values in the LightningModule - -.. testcode:: - - class LitMNIST(LightningModule): - - def __init__(self, hparams): - super().__init__() - - # do this to save all arguments in any logger (tensorboard) - self.hparams = hparams - - self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim) - self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim) - self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10) - - def train_dataloader(self): - return DataLoader(mnist_train, batch_size=self.hparams.batch_size) - - def configure_optimizers(self): - return Adam(self.parameters(), lr=self.hparams.learning_rate) - - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument('--layer_1_dim', type=int, default=128) - parser.add_argument('--layer_2_dim', type=int, default=256) - parser.add_argument('--batch_size', type=int, default=64) - parser.add_argument('--learning_rate', type=float, default=0.002) - return parser - -Now pass in the params when you init your model - -.. code-block:: python - - parser = ArgumentParser() - parser = LitMNIST.add_model_specific_args(parser) - hparams = parser.parse_args() - model = LitMNIST(hparams) - -The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule. -This does two things: - -1. It adds them automatically to TensorBoard logs under the hparams tab. -2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly. - -Trainer args -^^^^^^^^^^^^ -To recap, add ALL possible trainer flags to the argparser and init the Trainer this way - -.. code-block:: python - - parser = ArgumentParser() - parser = Trainer.add_argparse_args(parser) - hparams = parser.parse_args() - - trainer = Trainer.from_argparse_args(hparams) - - # or if you need to pass in callbacks - trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...]) - - -Multiple Lightning Modules -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We often have multiple Lightning Modules where each one has different arguments. Instead of -polluting the main.py file, the LightningModule lets you define arguments for each one. - -.. testcode:: - - class LitMNIST(LightningModule): - - def __init__(self, hparams): - super().__init__() - self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim) - - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser]) - parser.add_argument('--layer_1_dim', type=int, default=128) - return parser - -.. testcode:: - - class GoodGAN(LightningModule): - - def __init__(self, hparams): - super().__init__() - self.encoder = Encoder(layers=hparams.encoder_layers) - - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser]) - parser.add_argument('--encoder_layers', type=int, default=12) - return parser - - -Now we can allow each model to inject the arguments it needs in the ``main.py`` - -.. code-block:: python - - def main(args): - - # pick model - if args.model_name == 'gan': - model = GoodGAN(hparams=args) - elif args.model_name == 'mnist': - model = LitMNIST(hparams=args) - - model = LitMNIST(hparams=args) - trainer = Trainer.from_argparse_args(args) - trainer.fit(model) - - if __name__ == '__main__': - parser = ArgumentParser() - parser = Trainer.add_argparse_args(parser) - - # figure out which model to use - parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist') - - # THIS LINE IS KEY TO PULL THE MODEL NAME - temp_args, _ = parser.parse_known_args() - - # let the model add what it wants - if temp_args.model_name == 'gan': - parser = GoodGAN.add_model_specific_args(parser) - elif temp_args.model_name == 'mnist': - parser = LitMNIST.add_model_specific_args(parser) - - args = parser.parse_args() - - # train - main(args) - -and now we can train MNIST or the GAN using the command line interface! - -.. code-block:: bash - - $ python main.py --model_name gan --encoder_layers 24 - $ python main.py --model_name mnist --layer_1_dim 128 - -Hyperparameter Optimization -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Lightning is fully compatible with the hyperparameter optimization libraries! -Here are some useful ones: - -- `Hydra `_ -- `Optuna `_ diff --git a/docs/source/index.rst b/docs/source/index.rst index b74a9490af4e0e..81011cbf14724b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -6,78 +6,124 @@ PyTorch Lightning Documentation =============================== + .. toctree:: :maxdepth: 1 :name: start - :caption: Start Here + :caption: Getting started + + starter/new-project + starter/converting + starter/rapid_prototyping_templates + +.. toctree:: + :maxdepth: 1 + :name: guides + :caption: Best practices + + starter/style_guide + benchmarking/performance + Lightning project template + benchmarking/benchmarks - new-project - introduction_guide .. toctree:: :maxdepth: 2 - :name: docs - :caption: Python API + :name: pl_docs + :caption: Lightning API - callbacks - hooks - lightning-module - loggers - trainer + common/lightning_module + common/trainer .. toctree:: - :maxdepth: 1 - :name: Community Examples - :caption: Community Examples - - Contextual Emotion Detection (DoubleDistilBert) - Generative Adversarial Network - Hyperparameter optimization with Optuna - Image Inpainting using Partial Convolutions - MNIST on TPU - NER (transformers, TPU) - NeuralTexture (CVPR) - Recurrent Attentive Neural Process - Siamese Nets for One-shot Image Recognition - Speech Transformers - Transformers transfer learning (Huggingface) - Transformers text classification - VAE Library of over 18+ VAE flavors + :maxdepth: 2 + :name: docs + :caption: Optional extensions + + extensions/accelerators + extensions/callbacks + extensions/datamodules + extensions/logging + extensions/metrics + extensions/plugins + .. toctree:: :maxdepth: 1 :name: Tutorials :caption: Tutorials - From PyTorch to PyTorch Lightning + starter/introduction_guide + PyTorch Lightning 101 class + From PyTorch to PyTorch Lightning [Blog] + From PyTorch to PyTorch Lightning [Video] + +.. toctree:: + :maxdepth: 2 + :name: api + :caption: API References + + api_references + +.. toctree:: + :maxdepth: 1 + :name: Bolts + :caption: Bolts + + ecosystem/bolts + +.. toctree:: + :maxdepth: 1 + :name: Examples + :caption: Examples + + ecosystem/pytorch_ecoystem + ecosystem/community_examples + Autoencoder + BYOL + DQN + GAN + GPT-2 + Image-GPT + SimCLR + VAE .. toctree:: :maxdepth: 1 :name: Common Use Cases :caption: Common Use Cases - apex - slurm - child_modules - debugging - experiment_logging - experiment_reporting - early_stopping - fast_training - hooks - hyperparameters - lr_finder - multi_gpu - multiple_loaders - weights_loading - optimizers - profiler - single_gpu - sequences - training_tricks - transfer_learning - tpu - test_set + clouds/cloud_training + advanced/amp + clouds/slurm + common/child_modules + common/debugging + common/loggers + common/early_stopping + common/fast_training + common/hyperparameters + advanced/lr_finder + advanced/multi_gpu + advanced/multiple_loaders + common/weights_loading + common/optimizers + advanced/profiler + common/single_gpu + advanced/sequences + advanced/training_tricks + advanced/pruning_quantization + advanced/transfer_learning + advanced/tpu + advanced/cluster + common/test_set + common/production_inference + +.. toctree:: + :maxdepth: 1 + :name: Partner Domain Frameworks + :caption: Partner Domain Frameworks + + ecosystem/asr_nlp_tts .. toctree:: :maxdepth: 1 @@ -85,29 +131,14 @@ PyTorch Lightning Documentation :caption: Community - CODE_OF_CONDUCT.md - CONTRIBUTING.md - BECOMING_A_CORE_CONTRIBUTOR.md - PULL_REQUEST_TEMPLATE.md + generated/CODE_OF_CONDUCT.md + generated/CONTRIBUTING.md + generated/BECOMING_A_CORE_CONTRIBUTOR.md governance.md + generated/CHANGELOG.md Indices and tables ------------------ * :ref:`genindex` -* :ref:`modindex` * :ref:`search` - - - -.. This is here to make sphinx aware of the modules but not throw an error/warning -.. toctree:: - :hidden: - - api/pytorch_lightning.core - api/pytorch_lightning.callbacks - api/pytorch_lightning.loggers - api/pytorch_lightning.overrides - api/pytorch_lightning.profiler - api/pytorch_lightning.trainer - api/pytorch_lightning.utilities \ No newline at end of file diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst deleted file mode 100644 index 5d26278483c398..00000000000000 --- a/docs/source/introduction_guide.rst +++ /dev/null @@ -1,984 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.core.lightning import LightningModule - from pytorch_lightning.trainer.trainer import Trainer - - -Introduction Guide -================== -PyTorch Lightning provides a very simple template for organizing your PyTorch code. Once -you've organized it into a LightningModule, it automates most of the training for you. - -To illustrate, here's the typical PyTorch project structure organized in a LightningModule. - -.. figure:: /_images/mnist_imgs/pt_to_pl.jpg - :alt: Convert from PyTorch to Lightning - -As your project grows in complexity with things like 16-bit precision, distributed training, etc... the part in blue -quickly becomes onerous and starts distracting from the core research code. - ---------- - -Goal of this guide ------------------- -This guide walks through the major parts of the library to help you understand -what each parts does. But at the end of the day, you write the same PyTorch code... just organize it -into the LightningModule template which means you keep ALL the flexibility without having to deal with -any of the boilerplate code - -To show how Lightning works, we'll start with an MNIST classifier. We'll end showing how -to use inheritance to very quickly create an AutoEncoder. - -.. note:: Any DL/ML PyTorch project fits into the Lightning structure. Here we just focus on 3 types - of research to illustrate. - ---------- - -Installing Lightning --------------------- -Lightning is trivial to install. - -.. code-block:: bash - - conda activate my_env - pip install pytorch-lightning - -Or without conda environments, anywhere you can use pip. - -.. code-block:: bash - - pip install pytorch-lightning - ---------- - -Lightning Philosophy --------------------- -Lightning factors DL/ML code into three types: - -- Research code -- Engineering code -- Non-essential code - -Research code -^^^^^^^^^^^^^ -In the MNIST generation example, the research code would be the particular system and how it's trained (ie: A GAN or VAE). -In Lightning, this code is abstracted out by the `LightningModule`. - -.. code-block:: python - - l1 = nn.Linear(...) - l2 = nn.Linear(...) - decoder = Decoder() - - x1 = l1(x) - x2 = l2(x2) - out = decoder(features, x) - - loss = perceptual_loss(x1, x2, x) + CE(out, x) - -Engineering code -^^^^^^^^^^^^^^^^ - -The Engineering code is all the code related to training this system. Things such as early stopping, distribution -over GPUs, 16-bit precision, etc. This is normally code that is THE SAME across most projects. - -In Lightning, this code is abstracted out by the `Trainer`. - -.. code-block:: python - - model.cuda(0) - x = x.cuda(0) - - distributed = DistributedParallel(model) - - with gpu_zero: - download_data() - - dist.barrier() - -Non-essential code -^^^^^^^^^^^^^^^^^^ -This is code that helps the research but isn't relevant to the research code. Some examples might be: -1. Inspect gradients -2. Log to tensorboard. - -In Lightning this code is abstracted out by `Callbacks`. - -.. code-block:: python - - # log samples - z = Q.rsample() - generated = decoder(z) - self.experiment.log('images', generated) - ---------- - -Elements of a research project ------------------------------- -Every research project requires the same core ingredients: - -1. A model -2. Train/val/test data -3. Optimizer(s) -4. Training step computations -5. Validation step computations -6. Test step computations - - -The Model -^^^^^^^^^ -The LightningModule provides the structure on how to organize these 5 ingredients. - -Let's first start with the model. In this case we'll design -a 3-layer neural network. - -.. testcode:: - - import torch - from torch.nn import functional as F - from torch import nn - from pytorch_lightning.core.lightning import LightningModule - - class LitMNIST(LightningModule): - - def __init__(self): - super().__init__() - - # mnist images are (1, 28, 28) (channels, width, height) - self.layer_1 = torch.nn.Linear(28 * 28, 128) - self.layer_2 = torch.nn.Linear(128, 256) - self.layer_3 = torch.nn.Linear(256, 10) - - def forward(self, x): - batch_size, channels, width, height = x.size() - - # (b, 1, 28, 28) -> (b, 1*28*28) - x = x.view(batch_size, -1) - - # layer 1 - x = self.layer_1(x) - x = torch.relu(x) - - # layer 2 - x = self.layer_2(x) - x = torch.relu(x) - - # layer 3 - x = self.layer_3(x) - - # probability distribution over labels - x = torch.log_softmax(x, dim=1) - - return x - -Notice this is a `LightningModule` instead of a `torch.nn.Module`. A LightningModule is -equivalent to a PyTorch Module except it has added functionality. However, you can use it -EXACTLY the same as you would a PyTorch Module. - -.. testcode:: - - net = LitMNIST() - x = torch.Tensor(1, 1, 28, 28) - out = net(x) - -.. rst-class:: sphx-glr-script-out - - Out: - - .. code-block:: none - - torch.Size([1, 10]) - -Data -^^^^ - -The Lightning Module organizes your dataloaders and data processing as well. -Here's the PyTorch code for loading MNIST - -.. testcode:: - :skipif: not TORCHVISION_AVAILABLE - - from torch.utils.data import DataLoader, random_split - from torchvision.datasets import MNIST - import os - from torchvision import datasets, transforms - - # transforms - # prepare transforms standard to MNIST - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - - # data - mnist_train = MNIST(os.getcwd(), train=True, download=True) - mnist_train = DataLoader(mnist_train, batch_size=64) - -.. testoutput:: - :hide: - :skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not TORCHVISION_AVAILABLE - - Downloading ... - Extracting ... - Downloading ... - Extracting ... - Downloading ... - Extracting ... - Processing... - Done! - -When using PyTorch Lightning, we use the exact same code except we organize it into -the LightningModule - -.. testcode:: - :skipif: not TORCHVISION_AVAILABLE - - from torch.utils.data import DataLoader, random_split - from torchvision.datasets import MNIST - import os - from torchvision import datasets, transforms - - class LitMNIST(LightningModule): - - def train_dataloader(self): - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - mnist_train = MNIST(os.getcwd(), train=True, download=False, - transform=transform) - return DataLoader(mnist_train, batch_size=64) - -Notice the code is exactly the same, except now the training dataloading has been organized by the LightningModule -under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want -to figure out how they prepare their training data you can just look in the `train_dataloader` method. - -Usually though, we want to separate the things that write to disk in data-processing from -things like transforms which happen in memory. - -.. testcode:: - - class LitMNIST(LightningModule): - - def prepare_data(self): - # download only - MNIST(os.getcwd(), train=True, download=True) - - def train_dataloader(self): - # no download, just transform - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - mnist_train = MNIST(os.getcwd(), train=True, download=False, - transform=transform) - return DataLoader(mnist_train, batch_size=64) - -Doing it in the `prepare_data` method ensures that when you have -multiple GPUs you won't overwrite the data. This is a contrived example -but it gets more complicated with things like NLP or Imagenet. - -In general fill these methods with the following: - -.. testcode:: - - class LitMNIST(LightningModule): - - def prepare_data(self): - # stuff here is done once at the very beginning of training - # before any distributed training starts - - # download stuff - # save to disk - # etc... - ... - - def train_dataloader(self): - # data transforms - # dataset creation - # return a DataLoader - ... - -Optimizer -^^^^^^^^^ - -Next we choose what optimizer to use for training our system. -In PyTorch we do it as follows: - -.. code-block:: python - - from torch.optim import Adam - optimizer = Adam(LitMNIST().parameters(), lr=1e-3) - - -In Lightning we do the same but organize it under the configure_optimizers method. - -.. testcode:: - - class LitMNIST(LightningModule): - - def configure_optimizers(self): - return Adam(self.parameters(), lr=1e-3) - -.. note:: The LightningModule itself has the parameters, so pass in self.parameters() - -However, if you have multiple optimizers use the matching parameters - -.. testcode:: - - class LitMNIST(LightningModule): - - def configure_optimizers(self): - return Adam(self.generator(), lr=1e-3), Adam(self.discriminator(), lr=1e-3) - -Training step -^^^^^^^^^^^^^ - -The training step is what happens inside the training loop. - -.. code-block:: python - - for epoch in epochs: - for batch in data: - # TRAINING STEP - # .... - # TRAINING STEP - loss.backward() - optimizer.step() - optimizer.zero_grad() - -In the case of MNIST we do the following - -.. code-block:: python - - for epoch in epochs: - for batch in data: - # TRAINING STEP START - x, y = batch - logits = model(x) - loss = F.nll_loss(logits, y) - # TRAINING STEP END - - loss.backward() - optimizer.step() - optimizer.zero_grad() - -In Lightning, everything that is in the training step gets organized under the `training_step` function -in the LightningModule - -.. testcode:: - - class LitMNIST(LightningModule): - - def training_step(self, batch, batch_idx): - x, y = batch - logits = self(x) - loss = F.nll_loss(logits, y) - return {'loss': loss} - # return loss (also works) - -Again, this is the same PyTorch code except that it has been organized by the LightningModule. -This code is not restricted which means it can be as complicated as a full seq-2-seq, RL loop, GAN, etc... - ---------- - -Training --------- -So far we defined 4 key ingredients in pure PyTorch but organized the code inside the LightningModule. - -1. Model. -2. Training data. -3. Optimizer. -4. What happens in the training loop. - -For clarity, we'll recall that the full LightningModule now looks like this. - -.. testcode:: - - class LitMNIST(LightningModule): - def __init__(self): - super().__init__() - self.layer_1 = torch.nn.Linear(28 * 28, 128) - self.layer_2 = torch.nn.Linear(128, 256) - self.layer_3 = torch.nn.Linear(256, 10) - - def forward(self, x): - batch_size, channels, width, height = x.size() - x = x.view(batch_size, -1) - x = self.layer_1(x) - x = torch.relu(x) - x = self.layer_2(x) - x = torch.relu(x) - x = self.layer_3(x) - x = torch.log_softmax(x, dim=1) - return x - - def train_dataloader(self): - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform) - return DataLoader(mnist_train, batch_size=64) - - def configure_optimizers(self): - return Adam(self.parameters(), lr=1e-3) - - def training_step(self, batch, batch_idx): - x, y = batch - logits = self(x) - loss = F.nll_loss(logits, y) - - # add logging - logs = {'loss': loss} - return {'loss': loss, 'log': logs} - -Again, this is the same PyTorch code, except that it's organized -by the LightningModule. This organization now lets us train this model - -Train on CPU -^^^^^^^^^^^^ - -.. code-block:: python - - from pytorch_lightning import Trainer - - model = LitMNIST() - trainer = Trainer() - trainer.fit(model) - -You should see the following weights summary and progress bar - -.. figure:: /_images/mnist_imgs/mnist_cpu_bar.png - :alt: mnist CPU bar - -Logging -^^^^^^^ - -When we added the `log` key in the return dictionary it went into the built in tensorboard logger. -But you could have also logged by calling: - -.. code-block:: python - - def training_step(self, batch, batch_idx): - # ... - loss = ... - self.logger.summary.scalar('loss', loss) - -Which will generate automatic tensorboard logs. - -.. figure:: /_images/mnist_imgs/mnist_tb.png - :alt: mnist CPU bar - -But you can also use any of the `number of other loggers `_ we support. - -GPU training -^^^^^^^^^^^^ - -But the beauty is all the magic you can do with the trainer flags. For instance, to run this model on a GPU: - -.. code-block:: python - - model = LitMNIST() - trainer = Trainer(gpus=1) - trainer.fit(model) - - -.. figure:: /_images/mnist_imgs/mnist_gpu.png - :alt: mnist GPU bar - -Multi-GPU training -^^^^^^^^^^^^^^^^^^ - -Or you can also train on multiple GPUs. - -.. code-block:: python - - model = LitMNIST() - trainer = Trainer(gpus=8) - trainer.fit(model) - -Or multiple nodes - -.. code-block:: python - - # (32 GPUs) - model = LitMNIST() - trainer = Trainer(gpus=8, num_nodes=4, distributed_backend='ddp') - trainer.fit(model) - -Refer to the `distributed computing guide for more details `_. - -TPUs -^^^^ -Did you know you can use PyTorch on TPUs? It's very hard to do, but we've -worked with the xla team to use their awesome library to get this to work -out of the box! - -Let's train on Colab (`full demo available here `_) - -First, change the runtime to TPU (and reinstall lightning). - -.. figure:: /_images/mnist_imgs/runtime_tpu.png - :alt: mnist GPU bar - -.. figure:: /_images/mnist_imgs/restart_runtime.png - :alt: mnist GPU bar - -Next, install the required xla library (adds support for PyTorch on TPUs) - -.. code-block:: python - - import collections - from datetime import datetime, timedelta - import os - import requests - import threading - - _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server') - VERSION = "torch_xla==nightly" #@param ["xrt==1.15.0", "torch_xla==nightly"] - CONFIG = { - 'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'), - 'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format( - (datetime.today() - timedelta(1)).strftime('%Y%m%d'))), - }[VERSION] - DIST_BUCKET = 'gs://tpu-pytorch/wheels' - TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels) - TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels) - TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels) - - # Update TPU XRT version - def update_server_xrt(): - print('Updating server-side XRT to {} ...'.format(CONFIG.server)) - url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format( - TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0], - XRT_VERSION=CONFIG.server, - ) - print('Done updating server-side XRT: {}'.format(requests.post(url))) - - update = threading.Thread(target=update_server_xrt) - update.start() - -.. code-block:: - - # Install Colab TPU compat PyTorch/TPU wheels and dependencies - !pip uninstall -y torch torchvision - !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" . - !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" . - !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" . - !pip install "$TORCH_WHEEL" - !pip install "$TORCH_XLA_WHEEL" - !pip install "$TORCHVISION_WHEEL" - !sudo apt-get install libomp5 - update.join() - -In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy -of this program. This means that without taking any care you will download the dataset N times which -will cause all sorts of issues. - -To solve this problem, move the download code to the `prepare_data` method in the LightningModule. -In this method we do all the preparation we need to do once (instead of on every gpu). - -.. testcode:: - - class LitMNIST(LightningModule): - def prepare_data(self): - # transform - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) - - # download - mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform) - mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform) - - # train/val split - mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) - - # assign to use in dataloaders - self.train_dataset = mnist_train - self.val_dataset = mnist_val - self.test_dataset = mnist_test - - def train_dataloader(self): - return DataLoader(self.train_dataset, batch_size=64) - - def val_dataloader(self): - return DataLoader(self.val_dataset, batch_size=64) - - def test_dataloader(self): - return DataLoader(self.test_dataset, batch_size=64) - -The `prepare_data` method is also a good place to do any data processing that needs to be done only -once (ie: download or tokenize, etc...). - -.. note:: Lightning inserts the correct DistributedSampler for distributed training. No need to add yourself! - -Now we can train the LightningModule on a TPU without doing anything else! - -.. code-block:: python - - model = LitMNIST() - trainer = Trainer(num_tpu_cores=8) - trainer.fit(model) - -You'll now see the TPU cores booting up. - -.. figure:: /_images/mnist_imgs/tpu_start.png - :alt: TPU start - -Notice the epoch is MUCH faster! - -.. figure:: /_images/mnist_imgs/tpu_fast.png - :alt: TPU speed - ---------- - -.. include:: hyperparameters.rst - ---------- - -Validating ----------- - -For most cases, we stop training the model when the performance on a validation -split of the data reaches a minimum. - -Just like the `training_step`, we can define a `validation_step` to check whatever -metrics we care about, generate samples or add more to our logs. - -.. code-block:: python - - for epoch in epochs: - for batch in data: - # ... - # train - - # validate - outputs = [] - for batch in val_data: - x, y = batch # validation_step - y_hat = model(x) # validation_step - loss = loss(y_hat, x) # validation_step - outputs.append({'val_loss': loss}) # validation_step - - full_loss = outputs.mean() # validation_epoch_end - -Since the `validation_step` processes a single batch, -in Lightning we also have a `validation_epoch_end` method which allows you to compute -statistics on the full dataset after an epoch of validation data and not just the batch. - -In addition, we define a `val_dataloader` method which tells the trainer what data to use for validation. -Notice we split the train split of MNIST into train, validation. We also have to make sure to do the -sample split in the `train_dataloader` method. - -.. testcode:: - - class LitMNIST(LightningModule): - def validation_step(self, batch, batch_idx): - x, y = batch - logits = self(x) - loss = F.nll_loss(logits, y) - return {'val_loss': loss} - - def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - tensorboard_logs = {'val_loss': avg_loss} - return {'val_loss': avg_loss, 'log': tensorboard_logs} - - def val_dataloader(self): - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) - mnist_train = MNIST(os.getcwd(), train=True, download=False, - transform=transform) - _, mnist_val = random_split(mnist_train, [55000, 5000]) - mnist_val = DataLoader(mnist_val, batch_size=64) - return mnist_val - -Again, we've just organized the regular PyTorch code into two steps, the `validation_step` method which -operates on a single batch and the `validation_epoch_end` method to compute statistics on all batches. - -If you have these methods defined, Lightning will call them automatically. Now we can train -while checking the validation set. - -.. code-block:: python - - from pytorch_lightning import Trainer - - model = LitMNIST() - trainer = Trainer(num_tpu_cores=8) - trainer.fit(model) - -You may have noticed the words `Validation sanity check` logged. This is because Lightning runs 5 batches -of validation before starting to train. This is a kind of unit test to make sure that if you have a bug -in the validation loop, you won't need to potentially wait a full epoch to find out. - -.. note:: Lightning disables gradients, puts model in eval mode and does everything needed for validation. - ---------- - -Testing -------- -Once our research is done and we're about to publish or deploy a model, we normally want to figure out -how it will generalize in the "real world." For this, we use a held-out split of the data for testing. - -Just like the validation loop, we define exactly the same steps for testing: - -- test_step -- test_epoch_end -- test_dataloader - -.. testcode:: - - class LitMNIST(LightningModule): - def test_step(self, batch, batch_idx): - x, y = batch - logits = self(x) - loss = F.nll_loss(logits, y) - return {'val_loss': loss} - - def test_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - tensorboard_logs = {'val_loss': avg_loss} - return {'val_loss': avg_loss, 'log': tensorboard_logs} - - def test_dataloader(self): - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) - mnist_train = MNIST(os.getcwd(), train=False, download=False, transform=transform) - _, mnist_val = random_split(mnist_train, [55000, 5000]) - mnist_val = DataLoader(mnist_val, batch_size=64) - return mnist_val - -However, to make sure the test set isn't used inadvertently, Lightning has a separate API to run tests. -Once you train your model simply call `.test()`. - -.. code-block:: python - - from pytorch_lightning import Trainer - - model = LitMNIST() - trainer = Trainer(num_tpu_cores=8) - trainer.fit(model) - - # run test set - trainer.test() - -.. rst-class:: sphx-glr-script-out - - Out: - - .. code-block:: none - - -------------------------------------------------------------- - TEST RESULTS - {'test_loss': tensor(1.1703, device='cuda:0')} - -------------------------------------------------------------- - -You can also run the test from a saved lightning model - -.. code-block:: python - - model = LitMNIST.load_from_checkpoint(PATH) - trainer = Trainer(num_tpu_cores=8) - trainer.test(model) - -.. note:: Lightning disables gradients, puts model in eval mode and does everything needed for testing. - -.. warning:: .test() is not stable yet on TPUs. We're working on getting around the multiprocessing challenges. - ---------- - -Predicting ----------- -Again, a LightningModule is exactly the same as a PyTorch module. This means you can load it -and use it for prediction. - -.. code-block:: python - - model = LitMNIST.load_from_checkpoint(PATH) - x = torch.Tensor(1, 1, 28, 28) - out = model(x) - -On the surface, it looks like `forward` and `training_step` are similar. Generally, we want to make sure that -what we want the model to do is what happens in the `forward`. whereas the `training_step` likely calls forward from -within it. - -.. testcode:: - - class MNISTClassifier(LightningModule): - - def forward(self, x): - batch_size, channels, width, height = x.size() - x = x.view(batch_size, -1) - x = self.layer_1(x) - x = torch.relu(x) - x = self.layer_2(x) - x = torch.relu(x) - x = self.layer_3(x) - x = torch.log_softmax(x, dim=1) - return x - - def training_step(self, batch, batch_idx): - x, y = batch - logits = self(x) - loss = F.nll_loss(logits, y) - return loss - -.. code-block:: python - - model = MNISTClassifier() - x = mnist_image() - logits = model(x) - -In this case, we've set this LightningModel to predict logits. But we could also have it predict feature maps: - -.. testcode:: - - class MNISTRepresentator(LightningModule): - - def forward(self, x): - batch_size, channels, width, height = x.size() - x = x.view(batch_size, -1) - x = self.layer_1(x) - x1 = torch.relu(x) - x = self.layer_2(x1) - x2 = torch.relu(x) - x3 = self.layer_3(x2) - return [x, x1, x2, x3] - - def training_step(self, batch, batch_idx): - x, y = batch - out, l1_feats, l2_feats, l3_feats = self(x) - logits = torch.log_softmax(out, dim=1) - ce_loss = F.nll_loss(logits, y) - loss = perceptual_loss(l1_feats, l2_feats, l3_feats) + ce_loss - return loss - -.. code-block:: python - - model = MNISTRepresentator.load_from_checkpoint(PATH) - x = mnist_image() - feature_maps = model(x) - -Or maybe we have a model that we use to do generation - -.. testcode:: - - class LitMNISTDreamer(LightningModule): - - def forward(self, z): - imgs = self.decoder(z) - return imgs - - def training_step(self, batch, batch_idx): - x, y = batch - representation = self.encoder(x) - imgs = self(representation) - - loss = perceptual_loss(imgs, x) - return loss - -.. code-block:: python - - model = LitMNISTDreamer.load_from_checkpoint(PATH) - z = sample_noise() - generated_imgs = model(z) - -How you split up what goes in `forward` vs `training_step` depends on how you want to use this model for -prediction. - ---------- - -Extensibility -------------- -Although lightning makes everything super simple, it doesn't sacrifice any flexibility or control. -Lightning offers multiple ways of managing the training state. - -Training overrides -^^^^^^^^^^^^^^^^^^ - -Any part of the training, validation and testing loop can be modified. -For instance, if you wanted to do your own backward pass, you would override the -default implementation - -.. testcode:: - - def backward(self, use_amp, loss, optimizer): - if use_amp: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() - -With your own - -.. testcode:: - - class LitMNIST(LightningModule): - - def backward(self, use_amp, loss, optimizer): - # do a custom way of backward - loss.backward(retain_graph=True) - -Or if you wanted to initialize ddp in a different way than the default one - -.. testcode:: - - def configure_ddp(self, model, device_ids): - # Lightning DDP simply routes to test_step, val_step, etc... - model = LightningDistributedDataParallel( - model, - device_ids=device_ids, - find_unused_parameters=True - ) - return model - -you could do your own: - -.. testcode:: - - class LitMNIST(LightningModule): - - def configure_ddp(self, model, device_ids): - - model = Horovod(model) - # model = Ray(model) - return model - -Every single part of training is configurable this way. -For a full list look at `LightningModule `_. - ---------- - -Callbacks ---------- -Another way to add arbitrary functionality is to add a custom callback -for hooks that you might care about - -.. testcode:: - - from pytorch_lightning.callbacks import Callback - - class MyPrintingCallback(Callback): - - def on_init_start(self, trainer): - print('Starting to init trainer!') - - def on_init_end(self, trainer): - print('Trainer is init now') - - def on_train_end(self, trainer, pl_module): - print('do something when training ends') - -And pass the callbacks into the trainer - -.. testcode:: - - trainer = Trainer(callbacks=[MyPrintingCallback()]) - -.. testoutput:: - :hide: - - Starting to init trainer! - Trainer is init now - -.. note:: - See full list of 12+ hooks in the :ref:`callbacks`. - ---------- - -.. include:: child_modules.rst - ---------- - -.. include:: transfer_learning.rst diff --git a/docs/source/lightning-module.rst b/docs/source/lightning-module.rst deleted file mode 100644 index 3e329bec3a4e82..00000000000000 --- a/docs/source/lightning-module.rst +++ /dev/null @@ -1,12 +0,0 @@ -.. role:: hidden - :class: hidden-section - -LightningModule -=============== - -.. automodule:: pytorch_lightning.core - :noindex: - :exclude-members: - _abc_impl, - summarize, - diff --git a/docs/source/loggers.rst b/docs/source/loggers.rst deleted file mode 100644 index 67cfdcf8b428d1..00000000000000 --- a/docs/source/loggers.rst +++ /dev/null @@ -1,13 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Loggers -=========== -.. automodule:: pytorch_lightning.loggers - :noindex: - :exclude-members: - _abc_impl, - _save_model, - on_epoch_end, - on_train_end, - on_epoch_start, diff --git a/docs/source/lr_finder.rst b/docs/source/lr_finder.rst deleted file mode 100755 index 3da5456b6de8b0..00000000000000 --- a/docs/source/lr_finder.rst +++ /dev/null @@ -1,114 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.core.lightning import LightningModule - -Learning Rate Finder --------------------- - -For training deep neural networks, selecting a good learning rate is essential -for both better performance and faster convergence. Even optimizers such as -`Adam` that are self-adjusting the learning rate can benefit from more optimal -choices. - -To reduce the amount of guesswork concerning choosing a good initial learning -rate, a `learning rate finder` can be used. As described in this `paper `_ -a learning rate finder does a small run where the learning rate is increased -after each processed batch and the corresponding loss is logged. The result of -this is a `lr` vs. `loss` plot that can be used as guidence for choosing a optimal -initial lr. - -.. warning:: For the moment, this feature only works with models having a single optimizer. - -Using Lightnings build-in LR finder -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In the most basic use case, this feature can be enabled during trainer construction -with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the lr finder -will automatically be run before any training is done. The ``lr`` that is found -and used will be written to the console and logged together with all other -hyperparameters of the model. - -.. testcode:: - - # default, no automatic learning rate finder - trainer = Trainer(auto_lr_find=True) - -When the ``lr`` or ``learning_rate`` key in hparams exists, this flag sets your learning_rate. -In both cases, if the respective fields are not found, an error will be thrown. - -.. testcode:: - - class LitModel(LightningModule): - - def __init__(self, hparams): - self.hparams = hparams - - def configure_optimizers(self): - return Adam(self.parameters(), lr=self.hparams.lr|self.hparams.learning_rate) - - # finds learning rate automatically - # sets hparams.lr or hparams.learning_rate to that learning rate - trainer = Trainer(auto_lr_find=True) - -To use an arbitrary value set it in the parameter. - -.. testcode:: - - # to set to your own hparams.my_value - trainer = Trainer(auto_lr_find='my_value') - -Under the hood, when you call fit, this is what happens. - -1. Run learning rate finder. -2. Run actual fit. - -.. code-block:: python - - # when you call .fit() this happens - # 1. find learning rate - # 2. actually run fit - trainer.fit(model) - -If you want to inspect the results of the learning rate finder before doing any -actual training or just play around with the parameters of the algorithm, this -can be done by invoking the ``lr_find`` method of the trainer. A typical example -of this would look like - -.. code-block:: python - - model = MyModelClass(hparams) - trainer = Trainer() - - # Run learning rate finder - lr_finder = trainer.lr_find(model) - - # Results can be found in - lr_finder.results - - # Plot with - fig = lr_finder.plot(suggest=True) - fig.show() - - # Pick point based on plot, or get suggestion - new_lr = lr_finder.suggestion() - - # update hparams of the model - model.hparams.lr = new_lr - - # Fit model - trainer.fit(model) - -The figure produced by ``lr_finder.plot()`` should look something like the figure -below. It is recommended to not pick the learning rate that achives the lowest -loss, but instead something in the middle of the sharpest downward slope (red point). -This is the point returned py ``lr_finder.suggestion()``. - -.. figure:: /_images/trainer/lr_finder.png - -The parameters of the algorithm can be seen below. - -.. autoclass:: pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin - :members: lr_find - :noindex: - :exclude-members: _run_lr_finder_internally, save_checkpoint, restore diff --git a/docs/source/multi_gpu.rst b/docs/source/multi_gpu.rst deleted file mode 100644 index 8688cd338bc1b4..00000000000000 --- a/docs/source/multi_gpu.rst +++ /dev/null @@ -1,376 +0,0 @@ -.. testsetup:: * - - import torch - from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.core.lightning import LightningModule - -.. _multi-gpu-training: - -Multi-GPU training -================== -Lightning supports multiple ways of doing distributed training. - -Preparing your code -------------------- -To train on CPU/GPU/TPU without changing your code, we need to build a few good habits :) - -Delete .cuda() or .to() calls -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Delete any calls to .cuda() or .to(device). - -.. testcode:: - - # before lightning - def forward(self, x): - x = x.cuda(0) - layer_1.cuda(0) - x_hat = layer_1(x) - - # after lightning - def forward(self, x): - x_hat = layer_1(x) - -Init using type_as -^^^^^^^^^^^^^^^^^^ -When you need to create a new tensor, use `type_as`. -This will make your code scale to any arbitrary number of GPUs or TPUs with Lightning - -.. testcode:: - - # before lightning - def forward(self, x): - z = torch.Tensor(2, 3) - z = z.cuda(0) - - # with lightning - def forward(self, x): - z = torch.Tensor(2, 3) - z = z.type_as(x) - -Remove samplers -^^^^^^^^^^^^^^^ -For multi-node or TPU training, in PyTorch we must use `torch.nn.DistributedSampler`. The -sampler makes sure each GPU sees the appropriate part of your data. - -.. testcode:: - - # without lightning - def train_dataloader(self): - dataset = MNIST(...) - sampler = None - - if self.on_tpu: - sampler = DistributedSampler(dataset) - - return DataLoader(dataset, sampler=sampler) - -With Lightning, you don't need to do this because it takes care of adding the correct samplers -when needed. - -.. testcode:: - - # with lightning - def train_dataloader(self): - dataset = MNIST(...) - return DataLoader(dataset) - -.. note:: If you don't want this behavior, disable it with `Trainer(replace_sampler_ddp=False)` - -.. note:: For iterable datasets, we don't do this automatically. - -Make Model Picklable -^^^^^^^^^^^^^^^^^^^^ -It's very likely your code is already `picklable `_, -so you don't have to do anything to make this change. -However, if you run distributed and see an error like this: - -.. code-block:: - - self._launch(process_obj) - File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47, - in _launch reduction.dump(process_obj, fp) - File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/reduction.py", line 60, in dump - ForkingPickler(file, protocol).dump(obj) - _pickle.PicklingError: Can't pickle at 0x2b599e088ae8>: - attribute lookup on __main__ failed - -This means you have something in your model definition, transforms, optimizer, dataloader or callbacks -that is cannot be pickled. By pickled we mean the following would fail. - -.. code-block:: python - - import pickle - pickle.dump(some_object) - -This is a limitation of using multiple processes for distributed training within PyTorch. -To fix this issue, find your piece of code that cannot be pickled. The end of the stacktrace -is usually helpful. - -.. code-block:: - - self._launch(process_obj) - File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47, - in _launch reduction.dump(process_obj, fp) - File "/net/software/local/python/3.6.5/lib/python3.6/multiprocessing/reduction.py", line 60, in dump - ForkingPickler(file, protocol).dump(obj) - _pickle.PicklingError: Can't pickle [THIS IS THE THING TO FIND AND DELETE]: - attribute lookup on __main__ failed - -ie: in the stacktrace example here, there seems to be a lambda function somewhere in the user code -which cannot be pickled. - -Distributed modes ------------------ -Lightning allows multiple ways of training - -- Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine) -- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines). -- DistributedDataParallel2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines). -- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime) -- TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod) - -Data Parallel (dp) -^^^^^^^^^^^^^^^^^^ -`DataParallel `_ splits a batch across k GPUs. That is, if you have a batch of 32 and use dp with 2 gpus, -each GPU will process 16 samples, after which the root node will aggregate the results. - -.. warning:: DP use is discouraged by PyTorch and Lightning. Use ddp which is more stable and at least 3x faster - -.. testcode:: - :skipif: torch.cuda.device_count() < 2 - - # train on 2 GPUs (using dp mode) - trainer = Trainer(gpus=2, distributed_backend='dp') - -Distributed Data Parallel -^^^^^^^^^^^^^^^^^^^^^^^^^ -`DistributedDataParallel `_ works as follows. - -1. Each GPU across every node gets its own process. - -2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset. - -3. Each process inits the model. - -.. note:: Make sure to set the random seed so that each model inits with the same weights - -4. Each process performs a full forward and backward pass in parallel. - -5. The gradients are synced and averaged across all processes. - -6. Each process updates its optimizer. - -.. code-block:: python - - # train on 8 GPUs (same machine (ie: node)) - trainer = Trainer(gpus=8, distributed_backend='ddp') - - # train on 32 GPUs (4 nodes) - trainer = Trainer(gpus=8, distributed_backend='ddp', num_nodes=4) - -Distributed Data Parallel 2 -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In certain cases, it's advantageous to use all batches on the same machine instead of a subset. -For instance you might want to compute a NCE loss where it pays to have more negative samples. - -In this case, we can use ddp2 which behaves like dp in a machine and ddp across nodes. DDP2 does the following: - -1. Copies a subset of the data to each node. - -2. Inits a model on each node. - -3. Runs a forward and backward pass using DP. - -4. Syncs gradients across nodes. - -5. Applies the optimizer updates. - -.. code-block:: python - - # train on 32 GPUs (4 nodes) - trainer = Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4) - -Horovod -^^^^^^^ -`Horovod `_ allows the same training script to be used for single-GPU, -multi-GPU, and multi-node training. - -Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed -subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass, -then synchronously applied before beginning the next step. - -The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In -the training script, Horovod will detect the number of workers from the environment, and automatically -scale the learning rate to compensate for the increased total batch size. - -Horovod can be configured in the training script to run with any number of GPUs / processes as follows: - -.. code-block:: python - - # train Horovod on GPU (number of GPUs / machines provided on command-line) - trainer = Trainer(distributed_backend='horovod', gpus=1) - - # train Horovod on CPU (number of processes / machines provided on command-line) - trainer = Trainer(distributed_backend='horovod') - -When starting the training job, the driver application will then be used to specify the total -number of worker processes: - -.. code-block:: bash - - # run training with 4 GPUs on a single machine - horovodrun -np 4 python train.py - - # run training with 8 GPUs on two machines (4 GPUs each) - horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py - -See the official `Horovod documentation `_ for details -on installation and performance tuning. - -DP/DDP2 caveats -^^^^^^^^^^^^^^^ -In DP and DDP2 each GPU within a machine sees a portion of a batch. -DP and ddp2 roughly do the following: - -.. testcode:: - - def distributed_forward(batch, model): - batch = torch.Tensor(32, 8) - gpu_0_batch = batch[:8] - gpu_1_batch = batch[8:16] - gpu_2_batch = batch[16:24] - gpu_3_batch = batch[24:] - - y_0 = model_copy_gpu_0(gpu_0_batch) - y_1 = model_copy_gpu_1(gpu_1_batch) - y_2 = model_copy_gpu_2(gpu_2_batch) - y_3 = model_copy_gpu_3(gpu_3_batch) - - return [y_0, y_1, y_2, y_3] - -So, when Lightning calls any of the `training_step`, `validation_step`, `test_step` -you will only be operating on one of those pieces. - -.. testcode:: - - # the batch here is a portion of the FULL batch - def training_step(self, batch, batch_idx): - y_0 = batch - -For most metrics, this doesn't really matter. However, if you want -to add something to your computational graph (like softmax) -using all batch parts you can use the `training_step_end` step. - -.. testcode:: - - def training_step_end(self, outputs): - # only use when on dp - outputs = torch.cat(outputs, dim=1) - softmax = softmax(outputs, dim=1) - out = softmax.mean() - return out - -In pseudocode, the full sequence is: - -.. code-block:: python - - # get data - batch = next(dataloader) - - # copy model and data to each gpu - batch_splits = split_batch(batch, num_gpus) - models = copy_model_to_gpus(model) - - # in parallel, operate on each batch chunk - all_results = [] - for gpu_num in gpus: - batch_split = batch_splits[gpu_num] - gpu_model = models[gpu_num] - out = gpu_model(batch_split) - all_results.append(out) - - # use the full batch for something like softmax - full out = model.training_step_end(all_results) - -to illustrate why this is needed, let's look at dataparallel - -.. testcode:: - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(batch) - - # on dp or ddp2 if we did softmax now it would be wrong - # because batch is actually a piece of the full batch - return y_hat - - def training_step_end(self, batch_parts_outputs): - # batch_parts_outputs has outputs of each part of the batch - - # do softmax here - outputs = torch.cat(outputs, dim=1) - softmax = softmax(outputs, dim=1) - out = softmax.mean() - - return out - -If `training_step_end` is defined it will be called regardless of tpu, dp, ddp, etc... which means -it will behave the same no matter the backend. - -Validation and test step also have the same option when using dp - -.. testcode:: - - def validation_step_end(self, batch_parts_outputs): - ... - - def test_step_end(self, batch_parts_outputs): - ... - -Implement Your Own Distributed (DDP) training -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you need your own way to init PyTorch DDP you can override :meth:`pytorch_lightning.core.LightningModule.`. - -If you also need to use your own DDP implementation, override: :meth:`pytorch_lightning.core.LightningModule.configure_ddp`. - - -Batch size ----------- -When using distributed training make sure to modify your learning rate according to your effective -batch size. - -Let's say you have a batch size of 7 in your dataloader. - -.. testcode:: - - class LitModel(LightningModule): - - def train_dataloader(self): - return Dataset(..., batch_size=7) - -In (DDP, Horovod) your effective batch size will be 7 * gpus * num_nodes. - -.. code-block:: python - - # effective batch size = 7 * 8 - Trainer(gpus=8, distributed_backend='ddp|horovod') - - # effective batch size = 7 * 8 * 10 - Trainer(gpus=8, num_nodes=10, distributed_backend='ddp|horovod') - - -In DDP2, your effective batch size will be 7 * num_nodes. -The reason is that the full batch is visible to all GPUs on the node when using DDP2. - -.. code-block:: python - - # effective batch size = 7 - Trainer(gpus=8, distributed_backend='ddp2') - - # effective batch size = 7 * 10 - Trainer(gpus=8, num_nodes=10, distributed_backend='ddp2') - - -.. note:: Huge batch sizes are actually really bad for convergence. Check out: - `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour `_ diff --git a/docs/source/multiple_loaders.rst b/docs/source/multiple_loaders.rst deleted file mode 100644 index dca339f9b99ad2..00000000000000 --- a/docs/source/multiple_loaders.rst +++ /dev/null @@ -1,73 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.core.lightning import LightningModule - -Multiple Datasets -================= -Lightning supports multiple dataloaders in a few ways. - -1. Create a dataloader that iterates both datasets under the hood. -2. In the validation and test loop you also have the option to return multiple dataloaders - which lightning will call sequentially. - -Multiple training dataloaders ------------------------------ -For training, the best way to use multiple-dataloaders is to create a Dataloader class -which wraps both your dataloaders. (This of course also works for testing and validation -dataloaders). - -(`reference `_) - -.. testcode:: - - class ConcatDataset(torch.utils.data.Dataset): - def __init__(self, *datasets): - self.datasets = datasets - - def __getitem__(self, i): - return tuple(d[i] for d in self.datasets) - - def __len__(self): - return min(len(d) for d in self.datasets) - - class LitModel(LightningModule): - - def train_dataloader(self): - concat_dataset = ConcatDataset( - datasets.ImageFolder(traindir_A), - datasets.ImageFolder(traindir_B) - ) - - loader = torch.utils.data.DataLoader( - concat_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.workers, - pin_memory=True - ) - return loader - - def val_dataloader(self): - # SAME - ... - - def test_dataloader(self): - # SAME - ... - -Test/Val dataloaders --------------------- -For validation, test dataloaders lightning also gives you the additional -option of passing in multiple dataloaders back from each call. - -See the following for more details: - -- :meth:`~pytorch_lightning.core.LightningModule.val_dataloader` -- :meth:`~pytorch_lightning.core.LightningModule.test_dataloader` - -.. testcode:: - - def val_dataloader(self): - loader_1 = Dataloader() - loader_2 = Dataloader() - return [loader_1, loader_2] diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst deleted file mode 100644 index 24b11412e5c7d9..00000000000000 --- a/docs/source/new-project.rst +++ /dev/null @@ -1,279 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.core.lightning import LightningModule - from pytorch_lightning.trainer.trainer import Trainer - - - -Quick Start -=========== - -PyTorch Lightning is nothing more than organized PyTorch code. -Once you've organized it into a LightningModule, it automates most of the training for you. - -To illustrate, here's the typical PyTorch project structure organized in a LightningModule. - -.. figure:: /_images/mnist_imgs/pt_to_pl.jpg - :alt: Convert from PyTorch to Lightning - - -Step 1: Define a LightningModule ---------------------------------- - -.. testcode:: - :skipif: not TORCHVISION_AVAILABLE - - import os - - import torch - from torch.nn import functional as F - from torch.utils.data import DataLoader - from torchvision.datasets import MNIST - from torchvision import transforms - from pytorch_lightning.core.lightning import LightningModule - - class LitModel(LightningModule): - - def __init__(self): - super().__init__() - self.l1 = torch.nn.Linear(28 * 28, 10) - - def forward(self, x): - return torch.relu(self.l1(x.view(x.size(0), -1))) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - tensorboard_logs = {'train_loss': loss} - return {'loss': loss, 'log': tensorboard_logs} - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.001) - - def train_dataloader(self): - dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) - loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True) - return loader - - -Step 2: Fit with a Trainer --------------------------- - -.. testcode:: - :skipif: torch.cuda.device_count() < 8 - - from pytorch_lightning import Trainer - - model = LitModel() - - # most basic trainer, uses good defaults - trainer = Trainer(gpus=8, num_nodes=1) - trainer.fit(model) - -Under the hood, lightning does (in high-level pseudocode): - -.. code-block:: python - - model = LitModel() - train_dataloader = model.train_dataloader() - optimizer = model.configure_optimizers() - - for epoch in epochs: - train_outs = [] - for batch in train_dataloader: - loss = model.training_step(batch) - loss.backward() - train_outs.append(loss.detach()) - - optimizer.step() - optimizer.zero_grad() - - # optional for logging, etc... - model.training_epoch_end(train_outs) - -Validation loop ---------------- -To also add a validation loop add the following functions - -.. testcode:: - - class LitModel(LightningModule): - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - return {'val_loss': F.cross_entropy(y_hat, y)} - - def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - tensorboard_logs = {'val_loss': avg_loss} - return {'val_loss': avg_loss, 'log': tensorboard_logs} - - def val_dataloader(self): - # TODO: do a real train/val split - dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - loader = DataLoader(dataset, batch_size=32, num_workers=4) - return loader - -And now the trainer will call the validation loop automatically - -.. code-block:: python - - # most basic trainer, uses good defaults - trainer = Trainer(gpus=8, num_nodes=1) - trainer.fit(model) - -Under the hood in pseudocode, lightning does the following: - -.. testsetup:: * - - train_dataloader = [] - -.. testcode:: - - # ... - for batch in train_dataloader: - loss = model.training_step() - loss.backward() - # ... - - if validate_at_some_point: - model.eval() - val_outs = [] - for val_batch in model.val_dataloader: - val_out = model.validation_step(val_batch) - val_outs.append(val_out) - - model.validation_epoch_end(val_outs) - model.train() - -The beauty of Lightning is that it handles the details of when to validate, when to call .eval(), -turning off gradients, detaching graphs, making sure you don't enable shuffle for val, etc... - -.. note:: Lightning removes all the million details you need to remember during research - -Test loop ---------- -You might also need a test loop - -.. testcode:: - - class LitModel(LightningModule): - - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - return {'test_loss': F.cross_entropy(y_hat, y)} - - def test_epoch_end(self, outputs): - avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() - tensorboard_logs = {'test_loss': avg_loss} - return {'avg_test_loss': avg_loss, 'log': tensorboard_logs} - - def test_dataloader(self): - # TODO: do a real train/val split - dataset = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - loader = DataLoader(dataset, batch_size=32, num_workers=4) - return loader - -However, this time you need to specifically call test (this is done so you don't use the test set by mistake) - -.. code-block:: python - - # OPTION 1: - # test after fit - trainer.fit(model) - trainer.test() - - # OPTION 2: - # test after loading weights - model = LitModel.load_from_checkpoint(PATH) - trainer = Trainer(num_tpu_cores=1) - trainer.test() - -Again, under the hood, lightning does the following in (pseudocode): - -.. code-block:: python - - model.eval() - test_outs = [] - for test_batch in model.test_dataloader: - test_out = model.test_step(val_batch) - test_outs.append(test_out) - - model.test_epoch_end(test_outs) - -Datasets --------- -If you don't want to define the datasets as part of the LightningModule, just pass them into fit instead. - -.. code-block:: python - - # pass in datasets if you want. - train_dataloader = DataLoader(dataset, batch_size=32, num_workers=4) - val_dataloader, test_dataloader = ... - - trainer = Trainer(gpus=8, num_nodes=1) - trainer.fit(model, train_dataloader, val_dataloader) - - trainer.test(test_dataloader=test_dataloader) - -The advantage of this method is the ability to reuse models for different datasets. The disadvantage -is that for research it makes readability and reproducibility more difficult. This is why we recommend -to define the datasets in the LightningModule if you're doing research, but use the method above for -production models or for prediction tasks. - -Why do you need Lightning? --------------------------- -Notice the code above has nothing about .cuda() or 16-bit or early stopping or logging, etc... -This is where Lightning adds a ton of value. - -Without changing a SINGLE line of your code, you can now do the following with the above code - -.. code-block:: python - - # train on TPUs using 16 bit precision with early stopping - # using only half the training data and checking validation every quarter of a training epoch - trainer = Trainer( - nb_tpu_cores=8, - precision=16, - early_stop_checkpoint=True, - train_percent_check=0.5, - val_check_interval=0.25 - ) - - # train on 256 GPUs - trainer = Trainer( - gpus=8, - num_nodes=32 - ) - - # train on 1024 CPUs across 128 machines - trainer = Trainer( - num_processes=8, - num_nodes=128 - ) - -And the best part is that your code is STILL just PyTorch... meaning you can do anything you -would normally do. - -.. code-block:: python - - model = LitModel() - model.eval() - - y_hat = model(x) - - model.anything_you_can_do_with_pytorch() - -Summary -------- -In short, by refactoring your PyTorch code: - -1. You STILL keep pure PyTorch. -2. You DON't lose any flexibility. -3. You can get rid of all of your boilerplate. -4. You make your code generalizable to any hardware. -5. Your code is now readable and easier to reproduce (ie: you help with the reproducibility crisis). -6. Your LightningModule is still just a pure PyTorch module. diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst deleted file mode 100644 index 8f8715a09e7b3a..00000000000000 --- a/docs/source/optimizers.rst +++ /dev/null @@ -1,119 +0,0 @@ -Optimization -=============== - -Learning rate scheduling -------------------------------------- -Every optimizer you use can be paired with any `LearningRateScheduler `_. - -.. testcode:: - - # no LR scheduler - def configure_optimizers(self): - return Adam(...) - - # Adam + LR scheduler - def configure_optimizers(self): - optimizer = Adam(...) - scheduler = ReduceLROnPlateau(optimizer, ...) - return [optimizer], [scheduler] - - # Two optimziers each with a scheduler - def configure_optimizers(self): - optimizer1 = Adam(...) - optimizer2 = SGD(...) - scheduler1 = ReduceLROnPlateau(optimizer1, ...) - scheduler2 = LambdaLR(optimizer2, ...) - return [optimizer1, optimizer2], [scheduler1, scheduler2] - - # Same as above with additional params passed to the first scheduler - def configure_optimizers(self): - optimizers = [Adam(...), SGD(...)] - schedulers = [ - { - 'scheduler': ReduceLROnPlateau(optimizers[0], ...), - 'monitor': 'val_recall', # Default: val_loss - 'interval': 'epoch', - 'frequency': 1 - }, - LambdaLR(optimizers[1], ...) - ] - return optimizers, schedulers - - -Use multiple optimizers (like GANs) -------------------------------------- -To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers` - -.. testcode:: - - # one optimizer - def configure_optimizers(self): - return Adam(...) - - # two optimizers, no schedulers - def configure_optimizers(self): - return Adam(...), SGD(...) - - # Two optimizers, one scheduler for adam only - def configure_optimizers(self): - return [Adam(...), SGD(...)], [ReduceLROnPlateau()] - -Lightning will call each optimizer sequentially: - -.. code-block:: python - - for epoch in epochs: - for batch in data: - for opt in optimizers: - train_step(opt) - opt.step() - - for scheduler in scheduler: - scheduler.step() - - -Step optimizers at arbitrary intervals ----------------------------------------- -To do more interesting things with your optimizers such as learning rate warm-up or odd scheduling, -override the :meth:`optimizer_step` function. - -For example, here step optimizer A every 2 batches and optimizer B every 4 batches - -.. testcode:: - - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None): - optimizer.step() - optimizer.zero_grad() - - # Alternating schedule for optimizer steps (ie: GANs) - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None): - # update generator opt every 2 steps - if optimizer_i == 0: - if batch_nb % 2 == 0 : - optimizer.step() - optimizer.zero_grad() - - # update discriminator opt every 4 steps - if optimizer_i == 1: - if batch_nb % 4 == 0 : - optimizer.step() - optimizer.zero_grad() - - # ... - # add as many optimizers as you want - -Here we add a learning-rate warm up - -.. testcode:: - - # learning rate warm-up - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None): - # warm up lr - if self.trainer.global_step < 500: - lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) - for pg in optimizer.param_groups: - pg['lr'] = lr_scale * self.hparams.learning_rate - - # update params - optimizer.step() - optimizer.zero_grad() diff --git a/docs/source/slurm.rst b/docs/source/slurm.rst deleted file mode 100644 index ed09e7509b5712..00000000000000 --- a/docs/source/slurm.rst +++ /dev/null @@ -1,111 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - -Computing cluster (SLURM) -========================= - -Lightning automates job the details behind training on a SLURM powered cluster. - -.. _multi-node: - -Multi-node training -------------------- -To train a model using multiple-nodes do the following: - -1. Design your LightningModule. - -2. Enable ddp in the trainer - - .. code-block:: python - - # train on 32 GPUs across 4 nodes - trainer = Trainer(gpus=8, num_nodes=4, distributed_backend='ddp') - -3. It's a good idea to structure your train.py file like this: - - .. testcode:: - - # train.py - def main(hparams): - model = LightningTemplateModel(hparams) - - trainer = pl.Trainer( - gpus=8, - num_nodes=4, - distributed_backend='ddp' - ) - - trainer.fit(model) - - - if __name__ == '__main__': - root_dir = os.path.dirname(os.path.realpath(__file__)) - parent_parser = ArgumentParser(add_help=False) - hyperparams = parser.parse_args() - - # TRAIN - main(hyperparams) - -4. Create the appropriate SLURM job - - .. code-block:: bash - - # (submit.sh) - #!/bin/bash -l - - # SLURM SUBMIT SCRIPT - #SBATCH --nodes=4 - #SBATCH --gres=gpu:8 - #SBATCH --ntasks-per-node=8 - #SBATCH --mem=0 - #SBATCH --time=0-02:00:00 - - # activate conda env - source activate $1 - - # ------------------------- - # debugging flags (optional) - export NCCL_DEBUG=INFO - export PYTHONFAULTHANDLER=1 - - # on your cluster you might need these: - # set the network interface - # export NCCL_SOCKET_IFNAME=^docker0,lo - - # might need the latest cuda - # module load NCCL/2.4.7-1-cuda.10.0 - # ------------------------- - - # run script from above - srun python3 train.py - -5. If you want auto-resubmit (read below), add this line to the submit.sh script - - .. code-block:: bash - - #SBATCH --signal=SIGUSR1@90 - -6. Submit the SLURM job - - .. code-block:: bash - - sbatch submit.sh - -.. note:: using :class:`~torch.utils.data.distributed.DistributedSampler` is already handled by Lightning. - -Walltime auto-resubmit ----------------------- -When you use Lightning in a SLURM cluster, lightning automatically detects when it is about -to run into the walltime, and it does the following: - -1. Saves a temporary checkpoint. -2. Requeues the job. -3. When the job starts, it loads the temporary checkpoint. - -To get this behavior make sure to add the correct signal to your SLURM script - -.. code-block:: - - # 90 seconds before training ends - #SBATCH --signal=SIGUSR1@90 diff --git a/docs/source/starter/converting.rst b/docs/source/starter/converting.rst new file mode 100644 index 00000000000000..9077ba7f230833 --- /dev/null +++ b/docs/source/starter/converting.rst @@ -0,0 +1,118 @@ +.. testsetup:: * + + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.core.datamodule import LightningDataModule + from pytorch_lightning.trainer.trainer import Trainer + +.. _converting: + +************************************** +How to organize PyTorch into Lightning +************************************** + +To enable your code to work with Lightning, here's how to organize PyTorch into Lightning + +-------- + +1. Move your computational code +=============================== +Move the model architecture and forward pass to your :doc:`lightning module <../common/lightning_module>`. + +.. testcode:: + + class LitModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(28 * 28, 128) + self.layer_2 = nn.Linear(128, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.layer_1(x) + x = F.relu(x) + x = self.layer_2(x) + return x + +-------- + +2. Move the optimizer(s) and schedulers +======================================= +Move your optimizers to the :func:`~pytorch_lightning.core.LightningModule.configure_optimizers` hook. + +.. testcode:: + + class LitModel(LightningModule): + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + +-------- + +3. Find the train loop "meat" +============================= +Lightning automates most of the training for you, the epoch and batch iterations, all you need to keep is the training step logic. +This should go into the :func:`~pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case): + +.. testcode:: + + class LitModel(LightningModule): + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + +-------- + +4. Find the val loop "meat" +=========================== +To add an (optional) validation loop add logic to the +:func:`~pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case). + +.. testcode:: + + class LitModel(LightningModule): + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + val_loss = F.cross_entropy(y_hat, y) + return val_loss + +.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation + +-------- + +5. Find the test loop "meat" +============================ +To add an (optional) test loop add logic to the +:func:`~pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case). + +.. testcode:: + + class LitModel(LightningModule): + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + +.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing. + +The test loop will not be used until you call. + +.. code-block:: + + trainer.test() + +.. tip:: ``.test()`` loads the best checkpoint automatically + +-------- + +6. Remove any .cuda() or to.device() calls +========================================== +Your :doc:`lightning module <../common/lightning_module>` can automatically run on any hardware! diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst new file mode 100644 index 00000000000000..551b8182caa7d1 --- /dev/null +++ b/docs/source/starter/introduction_guide.rst @@ -0,0 +1,1144 @@ +.. testsetup:: * + + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.core.datamodule import LightningDataModule + from pytorch_lightning.trainer.trainer import Trainer + +.. _introduction_guide: + +######################### +Step-by-step walk-through +######################### +This guide will walk you through the core pieces of PyTorch Lightning. + +We'll accomplish the following: + +- Implement an MNIST classifier. +- Use inheritance to implement an AutoEncoder + +.. note:: Any DL/ML PyTorch project fits into the Lightning structure. Here we just focus on 3 types + of research to illustrate. + +-------------- + +************************** +From MNIST to AutoEncoders +************************** + + +Installing Lightning +==================== + + +Lightning is trivial to install. We recommend using conda environments + +.. code-block:: bash + + conda activate my_env + pip install pytorch-lightning + +Or without conda environments, use pip. + +.. code-block:: bash + + pip install pytorch-lightning + +Or conda. + +.. code-block:: bash + + conda install pytorch-lightning -c conda-forge + +------------- + +The research +============ + +The Model +--------- + +The :doc:`lightning module <../common/lightning_module>` holds all the core research ingredients: + +- The model + +- The optimizers + +- The train/ val/ test steps + +Let's first start with the model. In this case, we'll design a 3-layer neural network. + +.. testcode:: + + import torch + from torch.nn import functional as F + from torch import nn + from pytorch_lightning.core.lightning import LightningModule + + class LitMNIST(LightningModule): + + def __init__(self): + super().__init__() + + # mnist images are (1, 28, 28) (channels, width, height) + self.layer_1 = nn.Linear(28 * 28, 128) + self.layer_2 = nn.Linear(128, 256) + self.layer_3 = nn.Linear(256, 10) + + def forward(self, x): + batch_size, channels, width, height = x.size() + + # (b, 1, 28, 28) -> (b, 1*28*28) + x = x.view(batch_size, -1) + x = self.layer_1(x) + x = F.relu(x) + x = self.layer_2(x) + x = F.relu(x) + x = self.layer_3(x) + + x = F.log_softmax(x, dim=1) + return x + +Notice this is a :doc:`lightning module <../common/lightning_module>` instead of a ``torch.nn.Module``. A LightningModule is +equivalent to a pure PyTorch Module except it has added functionality. However, you can use it **EXACTLY** the same as you would a PyTorch Module. + +.. testcode:: + + net = LitMNIST() + x = torch.randn(1, 1, 28, 28) + out = net(x) + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: python + + torch.Size([1, 10]) + + +Now we add the training_step which has all our training loop logic + +.. testcode:: + + class LitMNIST(LightningModule): + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + return loss + +Data +---- + + +Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST. + +.. testcode:: + :skipif: not _TORCHVISION_AVAILABLE + + from torch.utils.data import DataLoader, random_split + from torchvision.datasets import MNIST + import os + from torchvision import datasets, transforms + + # transforms + # prepare transforms standard to MNIST + transform=transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))]) + + # data + mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform) + mnist_train = DataLoader(mnist_train, batch_size=64) + +.. testoutput:: + :hide: + :skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not _TORCHVISION_AVAILABLE + + Downloading ... + Extracting ... + Downloading ... + Extracting ... + Downloading ... + Extracting ... + Processing... + Done! + +You can use DataLoaders in 3 ways: + +1. Pass DataLoaders to .fit() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Pass in the dataloaders to the `.fit()` function. + +.. code-block:: python + + model = LitMNIST() + trainer = Trainer() + trainer.fit(model, mnist_train) + + +2. LightningModule DataLoaders +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +For fast research prototyping, it might be easier to link the model with the dataloaders. + + +.. code-block:: python + + class LitMNIST(pl.LightningModule): + + def train_dataloader(self): + # transforms + # prepare transforms standard to MNIST + transform=transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))]) + # data + mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform) + return DataLoader(mnist_train, batch_size=64) + + def val_dataloader(self): + transforms = ... + mnist_val = ... + return DataLoader(mnist_val, batch_size=64) + + def test_dataloader(self): + transforms = ... + mnist_test = ... + return DataLoader(mnist_test, batch_size=64) + +DataLoaders are already in the model, no need to specify on .fit(). + +.. code-block:: python + + model = LitMNIST() + trainer = Trainer() + trainer.fit(model) + +3. DataModules (recommended) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Defining free-floating dataloaders, splits, download instructions, and such can get messy. +In this case, it's better to group the full definition of a dataset into a `DataModule` which includes: + +- Download instructions +- Processing instructions +- Split instructions +- Train dataloader +- Val dataloader(s) +- Test dataloader(s) + +.. testcode:: + + class MyDataModule(LightningDataModule): + + def __init__(self): + super().__init__() + self.train_dims = None + self.vocab_size = 0 + + def prepare_data(self): + # called only on 1 GPU + download_dataset() + tokenize() + build_vocab() + + def setup(self, stage: Optional[str] = None): + # called on every GPU + vocab = load_vocab() + self.vocab_size = len(vocab) + + self.train, self.val, self.test = load_datasets() + self.train_dims = self.train.next_batch.size() + + def train_dataloader(self): + transforms = ... + return DataLoader(self.train, batch_size=64) + + def val_dataloader(self): + transforms = ... + return DataLoader(self.val, batch_size=64) + + def test_dataloader(self): + transforms = ... + return DataLoader(self.test, batch_size=64) + +Using DataModules allows easier sharing of full dataset definitions. + +.. code-block:: python + + # use an MNIST dataset + mnist_dm = MNISTDatamodule() + model = LitModel(num_classes=mnist_dm.num_classes) + trainer.fit(model, mnist_dm) + + # or other datasets with the same model + imagenet_dm = ImagenetDatamodule() + model = LitModel(num_classes=imagenet_dm.num_classes) + trainer.fit(model, imagenet_dm) + +.. note:: ``prepare_data()`` is called on only one GPU in distributed training (automatically) +.. note:: ``setup()`` is called on every GPU (automatically) + +Models defined by data +^^^^^^^^^^^^^^^^^^^^^^ +When your models need to know about the data, it's best to process the data before passing it to the model. + +.. code-block:: python + + # init dm AND call the processing manually + dm = ImagenetDataModule() + dm.prepare_data() + dm.setup() + + model = LitModel(out_features=dm.num_classes, img_width=dm.img_width, img_height=dm.img_height) + trainer.fit(model, dm) + + +1. use ``prepare_data()`` to download and process the dataset. +2. use ``setup()`` to do splits, and build your model internals + +| + +An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows: + +.. testcode:: + + class LitMNIST(LightningModule): + + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + def setup(self, stage: Optional[str] = None): + # step is either 'fit', 'validate', 'test', or 'predict'. 90% of the time not relevant + data = load_data() + num_classes = data.classes + self.l1 = nn.Linear(..., num_classes) + +Optimizer +--------- + +Next we choose what optimizer to use for training our system. +In PyTorch we do it as follows: + +.. code-block:: python + + from torch.optim import Adam + optimizer = Adam(LitMNIST().parameters(), lr=1e-3) + + +In Lightning we do the same but organize it under the :func:`~pytorch_lightning.core.LightningModule.configure_optimizers` method. + +.. testcode:: + + class LitMNIST(LightningModule): + + def configure_optimizers(self): + return Adam(self.parameters(), lr=1e-3) + +.. note:: The LightningModule itself has the parameters, so pass in self.parameters() + +However, if you have multiple optimizers use the matching parameters + +.. testcode:: + + class LitMNIST(LightningModule): + + def configure_optimizers(self): + return Adam(self.generator(), lr=1e-3), Adam(self.discriminator(), lr=1e-3) + + +Training step +------------- + +The training step is what happens inside the training loop. + +.. code-block:: python + + for epoch in epochs: + for batch in data: + # TRAINING STEP + # .... + # TRAINING STEP + optimizer.zero_grad() + loss.backward() + optimizer.step() + +In the case of MNIST, we do the following + +.. code-block:: python + + for epoch in epochs: + for batch in data: + # ------ TRAINING STEP START ------ + x, y = batch + logits = model(x) + loss = F.nll_loss(logits, y) + # ------ TRAINING STEP END ------ + + optimizer.zero_grad() + loss.backward() + optimizer.step() + +In Lightning, everything that is in the training step gets organized under the +:func:`~pytorch_lightning.core.LightningModule.training_step` function in the LightningModule. + +.. testcode:: + + class LitMNIST(LightningModule): + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + return loss + +Again, this is the same PyTorch code except that it has been organized by the LightningModule. +This code is not restricted which means it can be as complicated as a full seq-2-seq, RL loop, GAN, etc... + +---------------- + +The engineering +=============== + +Training +-------- +So far we defined 4 key ingredients in pure PyTorch but organized the code with the LightningModule. + +1. Model. +2. Training data. +3. Optimizer. +4. What happens in the training loop. + +| + +For clarity, we'll recall that the full LightningModule now looks like this. + +.. code-block:: python + + class LitMNIST(LightningModule): + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(28 * 28, 128) + self.layer_2 = nn.Linear(128, 256) + self.layer_3 = nn.Linear(256, 10) + + def forward(self, x): + batch_size, channels, width, height = x.size() + x = x.view(batch_size, -1) + x = self.layer_1(x) + x = F.relu(x) + x = self.layer_2(x) + x = F.relu(x) + x = self.layer_3(x) + x = F.log_softmax(x, dim=1) + return x + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + return loss + +Again, this is the same PyTorch code, except that it's organized by the LightningModule. + +Logging +^^^^^^^ +To log to Tensorboard, your favorite logger, and/or the progress bar, use the +:func:`~~pytorch_lightning.core.lightning.LightningModule.log` method which can be called from +any method in the LightningModule. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_metric', x) + +The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a few options: + +- on_step (logs the metric at that step in training) +- on_epoch (automatically accumulates and logs at the end of the epoch) +- prog_bar (logs to the progress bar) +- logger (logs to the logger like Tensorboard) + +Depending on where the log is called from, Lightning auto-determines the correct mode for you. But of course +you can override the default behavior by manually setting the flags. + +.. note:: Setting on_epoch=True will accumulate your logged values over the full training epoch. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + +You can also use any method of your logger directly: + +.. code-block:: python + + def training_step(self, batch, batch_idx): + tensorboard = self.logger.experiment + tensorboard.any_summary_writer_method_you_want()) + +Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs: + +.. code-block:: bash + + tensorboard --logdir ./lightning_logs + + +Which will generate automatic tensorboard logs (or with the logger of your choice). + +.. figure:: ../_static/images/mnist_imgs/mnist_tb.png + :alt: mnist CPU bar + :width: 500 + +| + +But you can also use any of the :doc:`number of other loggers <../common/loggers>` we support. + + +Train on CPU +^^^^^^^^^^^^ +.. code-block:: python + + from pytorch_lightning import Trainer + + model = LitMNIST() + trainer = Trainer() + trainer.fit(model, train_loader) + +You should see the following weights summary and progress bar + +.. figure:: ../_static/images/mnist_imgs/mnist_cpu_bar.png + :alt: mnist CPU bar + + +Train on GPU +^^^^^^^^^^^^ +But the beauty is all the magic you can do with the trainer flags. For instance, to run this model on a GPU: + +.. code-block:: python + + model = LitMNIST() + trainer = Trainer(gpus=1) + trainer.fit(model, train_loader) + + +.. figure:: ../_static/images/mnist_imgs/mnist_gpu.png + :alt: mnist GPU bar + +Train on Multi-GPU +^^^^^^^^^^^^^^^^^^ +Or you can also train on multiple GPUs. + +.. code-block:: python + + model = LitMNIST() + trainer = Trainer(gpus=8) + trainer.fit(model, train_loader) + +Or multiple nodes + +.. code-block:: python + + # (32 GPUs) + model = LitMNIST() + trainer = Trainer(gpus=8, num_nodes=4, accelerator='ddp') + trainer.fit(model, train_loader) + +Refer to the :doc:`distributed computing guide for more details <../advanced/multi_gpu>`. + +Train on TPUs +^^^^^^^^^^^^^ +Did you know you can use PyTorch on TPUs? It's very hard to do, but we've +worked with the xla team to use their awesome library to get this to work +out of the box! + +Let's train on Colab (`full demo available here `_) + +First, change the runtime to TPU (and reinstall lightning). + +.. figure:: ../_static/images/mnist_imgs/runtime_tpu.png + :alt: mnist GPU bar + :width: 400 + +.. figure:: ../_static/images/mnist_imgs/restart_runtime.png + :alt: mnist GPU bar + :width: 400 + +| + +Next, install the required xla library (adds support for PyTorch on TPUs) + +.. code-block:: shell + + !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py + + !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev + +In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy +of this program. This means that without taking any care you will download the dataset N times which +will cause all sorts of issues. + +To solve this problem, make sure your download code is in the ``prepare_data`` method in the DataModule. +In this method we do all the preparation we need to do once (instead of on every GPU). + +``prepare_data`` can be called in two ways, once per node or only on the root node +(``Trainer(prepare_data_per_node=False)``). + +.. code-block:: python + + class MNISTDataModule(LightningDataModule): + def __init__(self, batch_size=64): + super().__init__() + self.batch_size = batch_size + + def prepare_data(self): + # download only + MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) + MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) + + def setup(self, stage: Optional[str] = None): + # transform + transform=transforms.Compose([transforms.ToTensor()]) + mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform) + mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform) + + # train/val split + mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) + + # assign to use in dataloaders + self.train_dataset = mnist_train + self.val_dataset = mnist_val + self.test_dataset = mnist_test + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=self.batch_size) + +The ``prepare_data`` method is also a good place to do any data processing that needs to be done only +once (ie: download or tokenize, etc...). + +.. note:: Lightning inserts the correct DistributedSampler for distributed training. No need to add yourself! + +Now we can train the LightningModule on a TPU without doing anything else! + +.. code-block:: python + + dm = MNISTDataModule() + model = LitMNIST() + trainer = Trainer(tpu_cores=8) + trainer.fit(model, dm) + +You'll now see the TPU cores booting up. + +.. figure:: ../_static/images/mnist_imgs/tpu_start.png + :alt: TPU start + :width: 400 + +Notice the epoch is MUCH faster! + +.. figure:: ../_static/images/mnist_imgs/tpu_fast.png + :alt: TPU speed + :width: 600 + +---------------- + +.. include:: ../common/hyperparameters.rst + +---------------- + +Validating +---------- + +For most cases, we stop training the model when the performance on a validation +split of the data reaches a minimum. + +Just like the ``training_step``, we can define a ``validation_step`` to check whatever +metrics we care about, generate samples, or add more to our logs. + +.. code-block:: python + + def validation_step(self, batch, batch_idx): + loss = MSE_loss(...) + self.log('val_loss', loss) + +Now we can train with a validation loop as well. + +.. code-block:: python + + from pytorch_lightning import Trainer + + model = LitMNIST() + trainer = Trainer(tpu_cores=8) + trainer.fit(model, train_loader, val_loader) + +You may have noticed the words **Validation sanity check** logged. This is because Lightning runs 2 batches +of validation before starting to train. This is a kind of unit test to make sure that if you have a bug +in the validation loop, you won't need to potentially wait for a full epoch to find out. + +.. note:: Lightning disables gradients, puts model in eval mode, and does everything needed for validation. + +Val loop under the hood +^^^^^^^^^^^^^^^^^^^^^^^ +Under the hood, Lightning does the following: + +.. code-block:: python + + model = Model() + model.train() + torch.set_grad_enabled(True) + + for epoch in epochs: + for batch in data: + # ... + # train + + # validate + model.eval() + torch.set_grad_enabled(False) + + outputs = [] + for batch in val_data: + x, y = batch # validation_step + y_hat = model(x) # validation_step + loss = loss(y_hat, x) # validation_step + outputs.append({'val_loss': loss}) # validation_step + + total_loss = outputs.mean() # validation_epoch_end + +Optional methods +^^^^^^^^^^^^^^^^ +If you still need even more fine-grain control, define the other optional methods for the loop. + +.. code-block:: python + + def validation_step(self, batch, batch_idx): + preds = ... + return preds + + def validation_epoch_end(self, val_step_outputs): + for pred in val_step_outputs: + # do something with all the predictions from each validation_step + +---------------- + +Testing +------- +Once our research is done and we're about to publish or deploy a model, we normally want to figure out +how it will generalize in the "real world." For this, we use a held-out split of the data for testing. + +Just like the validation loop, we define a test loop + +.. code-block:: python + + class LitMNIST(LightningModule): + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + self.log('test_loss', loss) + + +However, to make sure the test set isn't used inadvertently, Lightning has a separate API to run tests. +Once you train your model simply call ``.test()``. + +.. code-block:: python + + from pytorch_lightning import Trainer + + model = LitMNIST() + trainer = Trainer(tpu_cores=8) + trainer.fit(model) + + # run test set + result = trainer.test() + print(result) + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + -------------------------------------------------------------- + TEST RESULTS + {'test_loss': 1.1703} + -------------------------------------------------------------- + +You can also run the test from a saved lightning model + +.. code-block:: python + + model = LitMNIST.load_from_checkpoint(PATH) + trainer = Trainer(tpu_cores=8) + trainer.test(model) + +.. note:: Lightning disables gradients, puts model in eval mode, and does everything needed for testing. + +.. warning:: .test() is not stable yet on TPUs. We're working on getting around the multiprocessing challenges. + +---------------- + +Predicting +---------- +Again, a LightningModule is exactly the same as a PyTorch module. This means you can load it +and use it for prediction. + +.. code-block:: python + + model = LitMNIST.load_from_checkpoint(PATH) + x = torch.randn(1, 1, 28, 28) + out = model(x) + +On the surface, it looks like ``forward`` and ``training_step`` are similar. Generally, we want to make sure that +what we want the model to do is what happens in the ``forward``. whereas the ``training_step`` likely calls forward from +within it. + +.. testcode:: + + class MNISTClassifier(LightningModule): + + def forward(self, x): + batch_size, channels, width, height = x.size() + x = x.view(batch_size, -1) + x = self.layer_1(x) + x = F.relu(x) + x = self.layer_2(x) + x = F.relu(x) + x = self.layer_3(x) + x = F.log_softmax(x, dim=1) + return x + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + return loss + +.. code-block:: python + + model = MNISTClassifier() + x = mnist_image() + logits = model(x) + +In this case, we've set this LightningModel to predict logits. But we could also have it predict feature maps: + +.. testcode:: + + class MNISTRepresentator(LightningModule): + + def forward(self, x): + batch_size, channels, width, height = x.size() + x = x.view(batch_size, -1) + x = self.layer_1(x) + x1 = F.relu(x) + x = self.layer_2(x1) + x2 = F.relu(x) + x3 = self.layer_3(x2) + return [x, x1, x2, x3] + + def training_step(self, batch, batch_idx): + x, y = batch + out, l1_feats, l2_feats, l3_feats = self(x) + logits = F.log_softmax(out, dim=1) + ce_loss = F.nll_loss(logits, y) + loss = perceptual_loss(l1_feats, l2_feats, l3_feats) + ce_loss + return loss + +.. code-block:: python + + model = MNISTRepresentator.load_from_checkpoint(PATH) + x = mnist_image() + feature_maps = model(x) + +Or maybe we have a model that we use to do generation + +.. testcode:: + + class LitMNISTDreamer(LightningModule): + + def forward(self, z): + imgs = self.decoder(z) + return imgs + + def training_step(self, batch, batch_idx): + x, y = batch + representation = self.encoder(x) + imgs = self(representation) + + loss = perceptual_loss(imgs, x) + return loss + +.. code-block:: python + + model = LitMNISTDreamer.load_from_checkpoint(PATH) + z = sample_noise() + generated_imgs = model(z) + + +To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function +By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic. + +.. code-block:: python + + class LitMNISTDreamer(LightningModule): + + def forward(self, z): + imgs = self.decoder(z) + return imgs + + def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None): + return self(batch) + + + model = LitMNISTDreamer() + trainer.predict(model, datamodule) + + +How you split up what goes in ``forward`` vs ``training_step`` vs ``predict`` depends on how you want to use this model for +prediction. +However, we recommend ``forward`` to contain only tensor operation with your model, ``training_step`` to encapsulate ``forward`` logic with logging, +metrics and loss computation and ``predict`` to encapsulate ``forward`` with preprocess, postprocess functions. + +---------------- + +The nonessentials +================== + +Extensibility +------------- +Although lightning makes everything super simple, it doesn't sacrifice any flexibility or control. +Lightning offers multiple ways of managing the training state. + +Training overrides +^^^^^^^^^^^^^^^^^^ + +Any part of the training, validation, and testing loop can be modified. +For instance, if you wanted to do your own backward pass, you would override the +default implementation + +.. testcode:: + + def backward(self, use_amp, loss, optimizer): + loss.backward() + +With your own + +.. testcode:: + + class LitMNIST(LightningModule): + + def backward(self, use_amp, loss, optimizer, optimizer_idx): + # do a custom way of backward + loss.backward(retain_graph=True) + +Every single part of training is configurable this way. +For a full list look at :doc:`LightningModule <../common/lightning_module>`. + +---------------- + + +Callbacks +--------- +Another way to add arbitrary functionality is to add a custom callback +for hooks that you might care about + +.. testcode:: + + from pytorch_lightning.callbacks import Callback + + class MyPrintingCallback(Callback): + + def on_init_start(self, trainer): + print('Starting to init trainer!') + + def on_init_end(self, trainer): + print('Trainer is init now') + + def on_train_end(self, trainer, pl_module): + print('do something when training ends') + +And pass the callbacks into the trainer + +.. testcode:: + + trainer = Trainer(callbacks=[MyPrintingCallback()]) + +.. testoutput:: + :hide: + + Starting to init trainer! + Trainer is init now + +.. tip:: + See full list of 12+ hooks in the :doc:`callbacks <../extensions/callbacks>`. + +---------------- + +.. include:: ../common/child_modules.rst + +---------------- + +.. include:: ../advanced/transfer_learning.rst + +---------- + +********************* +Why PyTorch Lightning +********************* + +a. Less boilerplate +=================== + +Research and production code starts with simple code, but quickly grows in complexity +once you add GPU training, 16-bit, checkpointing, logging, etc... + +PyTorch Lightning implements these features for you and tests them rigorously to make sure you can +instead focus on the research idea. + +Writing less engineering/bolierplate code means: + +- fewer bugs +- faster iteration +- faster prototyping + +b. More functionality +===================== + +In PyTorch Lightning you leverage code written by hundreds of AI researchers, +research engs and PhDs from the world's top AI labs, +implementing all the latest best practices and SOTA features such as + +- GPU, Multi GPU, TPU training +- Multi-node training +- Auto logging +- ... +- Gradient accumulation + +c. Less error-prone +=================== + +Why re-invent the wheel? + +Use PyTorch Lightning to enjoy a deep learning structure that is rigorously tested (500+ tests) +across CPUs/multi-GPUs/multi-TPUs on every pull-request. + +We promise our collective team of 20+ from the top labs has thought about training more than you :) + +d. Not a new library +==================== + +PyTorch Lightning is organized PyTorch - no need to learn a new framework. + +Switching your model to Lightning is straight forward - here's a 2-minute video on how to do it. + +.. raw:: html + + + +Your projects WILL grow in complexity and you WILL end up engineering more than trying out new ideas... +Defer the hardest parts to Lightning! + +---------------- + +******************** +Lightning Philosophy +******************** +Lightning structures your deep learning code in 4 parts: + +- Research code +- Engineering code +- Non-essential code +- Data code + +Research code +============= +In the MNIST generation example, the research code +would be the particular system and how it's trained (ie: A GAN or VAE or GPT). + +.. code-block:: python + + l1 = nn.Linear(...) + l2 = nn.Linear(...) + decoder = Decoder() + + x1 = l1(x) + x2 = l2(x2) + out = decoder(features, x) + + loss = perceptual_loss(x1, x2, x) + CE(out, x) + +In Lightning, this code is organized into a :doc:`lightning module <../common/lightning_module>`. + +Engineering code +================ + +The Engineering code is all the code related to training this system. Things such as early stopping, distribution +over GPUs, 16-bit precision, etc. This is normally code that is THE SAME across most projects. + +.. code-block:: python + + model.cuda(0) + x = x.cuda(0) + + distributed = DistributedParallel(model) + + with gpu_zero: + download_data() + + dist.barrier() + +In Lightning, this code is abstracted out by the :doc:`trainer <../common/lightning_module>`. + +Non-essential code +================== + +This is code that helps the research but isn't relevant to the research code. Some examples might be: + +1. Inspect gradients +2. Log to tensorboard. + +| + +.. code-block:: python + + # log samples + z = Q.rsample() + generated = decoder(z) + self.experiment.log('images', generated) + +In Lightning this code is organized into :doc:`callbacks <../extensions/callbacks>`. + +Data code +========= +Lightning uses standard PyTorch DataLoaders or anything that gives a batch of data. +This code tends to end up getting messy with transforms, normalization constants, and data splitting +spread all over files. + +.. code-block:: python + + # data + train = MNIST(...) + train, val = split(train, val) + test = MNIST(...) + + # transforms + train_transforms = ... + val_transforms = ... + test_transforms = ... + + # dataloader ... + # download with dist.barrier() for multi-gpu, etc... + +This code gets especially complicated once you start doing multi-GPU training or needing info about +the data to build your models. + +In Lightning this code is organized inside a :doc:`datamodules <../extensions/datamodules>`. + +.. tip:: DataModules are optional but encouraged, otherwise you can use standard DataLoaders diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst new file mode 100644 index 00000000000000..7a1164b1bdf3a1 --- /dev/null +++ b/docs/source/starter/new-project.rst @@ -0,0 +1,789 @@ +.. testsetup:: * + + import os + import torch + from torch.nn import functional as F + from torch.utils.data import DataLoader + from torch.utils.data import random_split + import pytorch_lightning as pl + from pytorch_lightning.core.datamodule import LightningDataModule + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.trainer.trainer import Trainer + +.. _new_project: + +#################### +Lightning in 2 steps +#################### + +**In this guide we'll show you how to organize your PyTorch code into Lightning in 2 steps.** + +Organizing your code with PyTorch Lightning makes your code: + +* Keep all the flexibility (this is all pure PyTorch), but removes a ton of boilerplate +* More readable by decoupling the research code from the engineering +* Easier to reproduce +* Less error-prone by automating most of the training loop and tricky engineering +* Scalable to any hardware without changing your model + +---------- + +Here's a 3 minute conversion guide for PyTorch projects: + +.. raw:: html + + + +---------- + +********************************* +Step 0: Install PyTorch Lightning +********************************* + + +You can install using `pip `_ + +.. code-block:: bash + + pip install pytorch-lightning + +Or with `conda `_ (see how to install conda `here `_): + +.. code-block:: bash + + conda install pytorch-lightning -c conda-forge + +You could also use conda environments + +.. code-block:: bash + + conda activate my_env + pip install pytorch-lightning + +---------- + +Import the following: + +.. testcode:: + :skipif: not _TORCHVISION_AVAILABLE + + import os + import torch + from torch import nn + import torch.nn.functional as F + from torchvision import transforms + from torchvision.datasets import MNIST + from torch.utils.data import DataLoader, random_split + import pytorch_lightning as pl + +****************************** +Step 1: Define LightningModule +****************************** + +.. testcode:: + + class LitAutoEncoder(pl.LightningModule): + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(28*28, 64), + nn.ReLU(), + nn.Linear(64, 3) + ) + self.decoder = nn.Sequential( + nn.Linear(3, 64), + nn.ReLU(), + nn.Linear(64, 28*28) + ) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding + + def training_step(self, batch, batch_idx): + # training_step defined the train loop. + # It is independent of forward + x, y = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + # Logging to TensorBoard by default + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + +**SYSTEM VS MODEL** + +A :doc:`lightning module <../common/lightning_module>` defines a *system* not a model. + +.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/model_system.png + :width: 400 + +Examples of systems are: + +- `Autoencoder `_ +- `BERT `_ +- `DQN `_ +- `GAN `_ +- `Image classifier `_ +- Seq2seq +- `SimCLR `_ +- `VAE `_ + +Under the hood a LightningModule is still just a :class:`torch.nn.Module` that groups all research code into a single file to make it self-contained: + +- The Train loop +- The Validation loop +- The Test loop +- The Model or system of Models +- The Optimizer + +You can customize any part of training (such as the backward pass) by overriding any +of the 20+ hooks found in :ref:`hooks` + +.. testcode:: + + class LitAutoEncoder(LightningModule): + + def backward(self, loss, optimizer, optimizer_idx): + loss.backward() + +**FORWARD vs TRAINING_STEP** + +In Lightning we separate training from inference. The training_step defines +the full training loop. We encourage users to use the forward to define inference +actions. + +For example, in this case we could define the autoencoder to act as an embedding extractor: + +.. code-block:: python + + def forward(self, x): + embeddings = self.encoder(x) + return embeddings + +Of course, nothing is stopping you from using forward from within the training_step. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + ... + z = self(x) + +It really comes down to your application. We do, however, recommend that you keep both intents separate. + +* Use forward for inference (predicting). +* Use training_step for training. + +More details in :doc:`lightning module <../common/lightning_module>` docs. + +---------- + +********************************** +Step 2: Fit with Lightning Trainer +********************************** + +First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test splits. + +.. code-block:: python + + dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) + train_loader = DataLoader(dataset) + +Next, init the :doc:`lightning module <../common/lightning_module>` and the PyTorch Lightning :class:`~pytorch_lightning.trainer.Trainer`, +then call fit with both the data and model. + +.. code-block:: python + + # init model + autoencoder = LitAutoEncoder() + + # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more) + # trainer = pl.Trainer(gpus=8) (if you have GPUs) + trainer = pl.Trainer() + trainer.fit(autoencoder, train_loader) + +The :class:`~pytorch_lightning.trainer.Trainer` automates: + +* Epoch and batch iteration +* Calling of optimizer.step(), backward, zero_grad() +* Calling of .eval(), enabling/disabling grads +* :doc:`weights loading <../common/weights_loading>` +* Tensorboard (see :doc:`loggers <../common/loggers>` options) +* :doc:`Multi-GPU <../advanced/multi_gpu>` support +* :doc:`TPU <../advanced/tpu>` +* :doc:`AMP <../advanced/amp>` support + +.. tip:: If you prefer to manually manage optimizers you can use the :ref:`manual_opt` mode (ie: RL, GANs, etc...). + + +--------- + +**That's it!** + +These are the main 2 concepts you need to know in Lightning. All the other features of lightning are either +features of the Trainer or LightningModule. + +----------- + +************** +Basic features +************** + +Manual vs automatic optimization +================================ + +Automatic optimization +---------------------- +With Lightning, you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers +as long as you return a loss with an attached graph from the `training_step`, Lightning will automate the optimization. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + loss = self.encoder(batch) + return loss + +.. _manual_opt: + +Manual optimization +------------------- +However, for certain research like GANs, reinforcement learning, or something with multiple optimizers +or an inner loop, you can turn off automatic optimization and fully control the training loop yourself. + +Turn off automatic optimization and you control the train loop! + +.. code-block:: python + + def __init__(self): + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # access your optimizers with use_pl_optimizer=False. Default is True + opt_a, opt_b = self.optimizers(use_pl_optimizer=True) + + loss_a = self.generator(batch) + opt_a.zero_grad() + # use `manual_backward()` instead of `loss.backward` to automate half precision, etc... + self.manual_backward(loss_a) + opt_a.step() + + loss_b = self.discriminator(batch) + opt_b.zero_grad() + self.manual_backward(loss_b) + opt_b.step() + + +Predict or Deploy +================= +When you're done training, you have 3 options to use your LightningModule for predictions. + +Option 1: Sub-models +-------------------- +Pull out any model inside your system for predictions. + +.. code-block:: python + + # ---------------------------------- + # to use as embedding extractor + # ---------------------------------- + autoencoder = LitAutoEncoder.load_from_checkpoint('path/to/checkpoint_file.ckpt') + encoder_model = autoencoder.encoder + encoder_model.eval() + + # ---------------------------------- + # to use as image generator + # ---------------------------------- + decoder_model = autoencoder.decoder + decoder_model.eval() + +Option 2: Forward +----------------- +You can also add a forward method to do predictions however you want. + +.. testcode:: + + # ---------------------------------- + # using the AE to extract embeddings + # ---------------------------------- + class LitAutoEncoder(LightningModule): + def __init__(self): + super().__init__() + self.encoder = nn.Sequential() + + def forward(self, x): + embedding = self.encoder(x) + return embedding + + autoencoder = LitAutoEncoder() + autoencoder = autoencoder(torch.rand(1, 28 * 28)) + + +.. code-block:: python + + # ---------------------------------- + # or using the AE to generate images + # ---------------------------------- + class LitAutoEncoder(LightningModule): + def __init__(self): + super().__init__() + self.decoder = nn.Sequential() + + def forward(self): + z = torch.rand(1, 3) + image = self.decoder(z) + image = image.view(1, 1, 28, 28) + return image + + autoencoder = LitAutoEncoder() + image_sample = autoencoder() + +Option 3: Production +-------------------- +For production systems, onnx or torchscript are much faster. Make sure you have added +a forward method or trace only the sub-models you need. + +.. code-block:: python + + # ---------------------------------- + # torchscript + # ---------------------------------- + autoencoder = LitAutoEncoder() + torch.jit.save(autoencoder.to_torchscript(), "model.pt") + os.path.isfile("model.pt") + +.. code-block:: python + + # ---------------------------------- + # onnx + # ---------------------------------- + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile: + autoencoder = LitAutoEncoder() + input_sample = torch.randn((1, 28 * 28)) + autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True) + os.path.isfile(tmpfile.name) + +-------------------- + +Using CPUs/GPUs/TPUs +==================== +It's trivial to use CPUs, GPUs or TPUs in Lightning. There's **NO NEED** to change your code, simply change the :class:`~pytorch_lightning.trainer.Trainer` options. + +.. testcode:: + + # train on CPU + trainer = Trainer() + +.. testcode:: + + # train on 8 CPUs + trainer = Trainer(num_processes=8) + +.. code-block:: python + + # train on 1024 CPUs across 128 machines + trainer = pl.Trainer( + num_processes=8, + num_nodes=128 + ) + +.. code-block:: python + + # train on 1 GPU + trainer = pl.Trainer(gpus=1) + +.. code-block:: python + + # train on multiple GPUs across nodes (32 gpus here) + trainer = pl.Trainer( + gpus=4, + num_nodes=8 + ) + +.. code-block:: python + + # train on gpu 1, 3, 5 (3 gpus total) + trainer = pl.Trainer(gpus=[1, 3, 5]) + +.. code-block:: python + + # Multi GPU with mixed precision + trainer = pl.Trainer(gpus=2, precision=16) + +.. code-block:: python + + # Train on TPUs + trainer = pl.Trainer(tpu_cores=8) + +Without changing a SINGLE line of your code, you can now do the following with the above code: + +.. code-block:: python + + # train on TPUs using 16 bit precision + # using only half the training data and checking validation every quarter of a training epoch + trainer = pl.Trainer( + tpu_cores=8, + precision=16, + limit_train_batches=0.5, + val_check_interval=0.25 + ) + +----------- + +Checkpoints +=========== +Lightning automatically saves your model. Once you've trained, you can load the checkpoints as follows: + +.. code-block:: python + + model = LitModel.load_from_checkpoint(path) + +The above checkpoint contains all the arguments needed to init the model and set the state dict. +If you prefer to do it manually, here's the equivalent + +.. code-block:: python + + # load the ckpt + ckpt = torch.load('path/to/checkpoint.ckpt') + + # equivalent to the above + model = LitModel() + model.load_state_dict(ckpt['state_dict']) + +--------- + +Data flow +========= +Each loop (training, validation, test) has three hooks you can implement: + +- x_step +- x_step_end +- x_epoch_end + +To illustrate how data flows, we'll use the training loop (ie: x=training) + +.. code-block:: python + + outs = [] + for batch in data: + out = training_step(batch) + outs.append(out) + training_epoch_end(outs) + +The equivalent in Lightning is: + +.. code-block:: python + + def training_step(self, batch, batch_idx): + prediction = ... + return prediction + + def training_epoch_end(self, training_step_outputs): + for prediction in predictions: + # do something with these + +In the event that you use DP or DDP2 distributed modes (ie: split a batch across GPUs), +use the x_step_end to manually aggregate (or don't implement it to let lightning auto-aggregate for you). + +.. code-block:: python + + for batch in data: + model_copies = copy_model_per_gpu(model, num_gpus) + batch_split = split_batch_per_gpu(batch, num_gpus) + + gpu_outs = [] + for model, batch_part in zip(model_copies, batch_split): + # LightningModule hook + gpu_out = model.training_step(batch_part) + gpu_outs.append(gpu_out) + + # LightningModule hook + out = training_step_end(gpu_outs) + +The lightning equivalent is: + +.. code-block:: python + + def training_step(self, batch, batch_idx): + loss = ... + return loss + + def training_step_end(self, losses): + gpu_0_loss = losses[0] + gpu_1_loss = losses[1] + return (gpu_0_loss + gpu_1_loss) * 1/2 + +.. tip:: The validation and test loops have the same structure. + +----------------- + +Logging +======= +To log to Tensorboard, your favorite logger, and/or the progress bar, use the +:func:`~~pytorch_lightning.core.lightning.LightningModule.log` method which can be called from +any method in the LightningModule. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_metric', x) + +The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a few options: + +- on_step (logs the metric at that step in training) +- on_epoch (automatically accumulates and logs at the end of the epoch) +- prog_bar (logs to the progress bar) +- logger (logs to the logger like Tensorboard) + +Depending on where the log is called from, Lightning auto-determines the correct mode for you. But of course +you can override the default behavior by manually setting the flags + +.. note:: Setting on_epoch=True will accumulate your logged values over the full training epoch. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + +.. note:: + The loss value shown in the progress bar is smoothed (averaged) over the last values, + so it differs from the actual loss returned in the train/validation step. + +You can also use any method of your logger directly: + +.. code-block:: python + + def training_step(self, batch, batch_idx): + tensorboard = self.logger.experiment + tensorboard.any_summary_writer_method_you_want()) + +Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs: + +.. code-block:: bash + + tensorboard --logdir ./lightning_logs + +.. note:: + Lightning automatically shows the loss value returned from ``training_step`` in the progress bar. + So, no need to explicitly log like this ``self.log('loss', loss, prog_bar=True)``. + +Read more about :doc:`loggers <../common/loggers>`. + +---------------- + +Optional extensions +=================== + +Callbacks +--------- +A callback is an arbitrary self-contained program that can be executed at arbitrary parts of the training loop. + +Here's an example adding a not-so-fancy learning rate decay rule: + +.. testcode:: + + from pytorch_lightning.callbacks import Callback + + class DecayLearningRate(Callback): + + def __init__(self): + self.old_lrs = [] + + def on_train_start(self, trainer, pl_module): + # track the initial learning rates + for opt_idx, optimizer in enumerate(trainer.optimizers): + group = [param_group['lr'] for param_group in optimizer.param_groups] + self.old_lrs.append(group) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + for opt_idx, optimizer in enumerate(trainer.optimizers): + old_lr_group = self.old_lrs[opt_idx] + new_lr_group = [] + for p_idx, param_group in enumerate(optimizer.param_groups): + old_lr = old_lr_group[p_idx] + new_lr = old_lr * 0.98 + new_lr_group.append(new_lr) + param_group['lr'] = new_lr + self.old_lrs[opt_idx] = new_lr_group + + # And pass the callback to the Trainer + decay_callback = DecayLearningRate() + trainer = Trainer(callbacks=[decay_callback]) + +Things you can do with a callback: + +- Send emails at some point in training +- Grow the model +- Update learning rates +- Visualize gradients +- ... +- You are only limited by your imagination + +:doc:`Learn more about custom callbacks <../extensions/callbacks>`. + + +LightningDataModules +-------------------- +DataLoaders and data processing code tends to end up scattered around. +Make your data code reusable by organizing it into a :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + +.. testcode:: + + class MNISTDataModule(LightningDataModule): + + def __init__(self, batch_size=32): + super().__init__() + self.batch_size = batch_size + + # When doing distributed training, Datamodules have two optional arguments for + # granular control over download/prepare/splitting data: + + # OPTIONAL, called only on 1 GPU/machine + def prepare_data(self): + MNIST(os.getcwd(), train=True, download=True) + MNIST(os.getcwd(), train=False, download=True) + + # OPTIONAL, called for every GPU/machine (assigning state is OK) + def setup(self, stage: Optional[str] = None): + # transforms + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + # split dataset + if stage == 'fit': + mnist_train = MNIST(os.getcwd(), train=True, transform=transform) + self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000]) + if stage == 'test': + self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform) + + # return the dataloader for each split + def train_dataloader(self): + mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size) + return mnist_train + + def val_dataloader(self): + mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size) + return mnist_val + + def test_dataloader(self): + mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size) + return mnist_test + +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` is designed to enable sharing and reusing data splits +and transforms across different projects. It encapsulates all the steps needed to process data: downloading, +tokenizing, processing etc. + +Now you can simply pass your :class:`~pytorch_lightning.core.datamodule.LightningDataModule` to +the :class:`~pytorch_lightning.trainer.Trainer`: + +.. code-block:: python + + # init model + model = LitModel() + + # init data + dm = MNISTDataModule() + + # train + trainer = pl.Trainer() + trainer.fit(model, dm) + + # test + trainer.test(datamodule=dm) + +DataModules are specifically useful for building models based on data. Read more on :doc:`datamodules <../extensions/datamodules>`. + +------ + +Debugging +========= +Lightning has many tools for debugging. Here is an example of just a few of them: + +.. testcode:: + + # use only 10 train batches and 3 val batches + trainer = Trainer(limit_train_batches=10, limit_val_batches=3) + +.. testcode:: + + # Automatically overfit the sane batch of your model for a sanity test + trainer = Trainer(overfit_batches=1) + +.. testcode:: + + # unit test all the code- hits every line of your code once to see if you have bugs, + # instead of waiting hours to crash on validation + trainer = Trainer(fast_dev_run=True) + +.. testcode:: + + # train only 20% of an epoch + trainer = Trainer(limit_train_batches=0.2) + +.. testcode:: + + # run validation every 25% of a training epoch + trainer = Trainer(val_check_interval=0.25) + +.. testcode:: + + # Profile your code to find speed/memory bottlenecks + Trainer(profiler="simple") + +--------------- + +******************** +Other coool features +******************** + +Once you define and train your first Lightning model, you might want to try other cool features like + +- :doc:`Automatic early stopping <../common/early_stopping>` +- :ref:`Automatic truncated-back-propagation-through-time ` +- :ref:`Automatically scale your batch size ` +- :doc:`Automatically find a good learning rate <../advanced/lr_finder>` +- :ref:`Load checkpoints directly from S3 ` +- :doc:`Scale to massive compute clusters <../clouds/slurm>` +- :doc:`Use multiple dataloaders per train/val/test loop <../advanced/multiple_loaders>` +- :ref:`Use multiple optimizers to do reinforcement learning or even GANs ` + +Or read our :doc:`Guide <../starter/introduction_guide>` to learn more! + +------------- + +Grid AI +======= +Grid AI is our native solution for large scale training and tuning on the cloud provider of your choice. + +`Click here to request early-access `_. + +------------ + +********** +Community +********** +Our community of core maintainers and thousands of expert researchers is active on our +`Slack `_ +and `GitHub Discussions `_. Drop by +to hang out, ask Lightning questions or even discuss research! + + +------------- + +Masterclass +=========== +We also offer a Masterclass to teach you the advanced uses of Lightning. + +.. image:: ../_static/images/general/PTL101_youtube_thumbnail.jpg + :width: 500 + :align: center + :alt: Masterclass + :target: https://www.youtube.com/playlist?list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2 diff --git a/docs/source/starter/rapid_prototyping_templates.rst b/docs/source/starter/rapid_prototyping_templates.rst new file mode 100644 index 00000000000000..5d760df961c366 --- /dev/null +++ b/docs/source/starter/rapid_prototyping_templates.rst @@ -0,0 +1,38 @@ +########################### +Rapid prototyping templates +########################### +Use these templates for rapid prototyping + +----------- + +*********** +General Use +*********** + +.. list-table:: + :widths: 18 15 25 + :header-rows: 1 + + * - Use case + - Description + - link + * - Scratch model + - To prototype quickly / debug with random data + - + .. raw:: html + +
+ + open in colab + +
+ * - Scratch model with manual optimization + - To prototype quickly / debug with random data + - + .. raw:: html + +
+ + open in colab + +
diff --git a/docs/source/starter/style_guide.rst b/docs/source/starter/style_guide.rst new file mode 100644 index 00000000000000..f922d900f80910 --- /dev/null +++ b/docs/source/starter/style_guide.rst @@ -0,0 +1,203 @@ +########### +Style guide +########### +A main goal of Lightning is to improve readability and reproducibility. Imagine looking into any GitHub repo, +finding a lightning module and knowing exactly where to look to find the things you care about. + +The goal of this style guide is to encourage Lightning code to be structured similarly. + +-------------- + +*************** +LightningModule +*************** +These are best practices about structuring your LightningModule + +Systems vs models +================= + +.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/model_system.png + :width: 400 + +The main principle behind a LightningModule is that a full system should be self-contained. +In Lightning we differentiate between a system and a model. + +A model is something like a resnet18, RNN, etc. + +A system defines how a collection of models interact with each other. Examples of this are: + +* GANs +* Seq2Seq +* BERT +* etc + +A LightningModule can define both a system and a model. + +Here's a LightningModule that defines a model: + +.. testcode:: + + class LitModel(LightningModule): + def __init__(self, num_layers: int = 3): + super().__init__() + self.layer_1 = nn.Linear() + self.layer_2 = nn.Linear() + self.layer_3 = nn.Linear() + +Here's a LightningModule that defines a system: + +.. testcode:: + + class LitModel(LightningModule): + def __init__(self, encoder: nn.Module = None, decoder: nn.Module = None): + super().__init__() + self.encoder = encoder + self.decoder = decoder + +For fast prototyping it's often useful to define all the computations in a LightningModule. For reusability +and scalability it might be better to pass in the relevant backbones. + +Self-contained +============== +A Lightning module should be self-contained. A good test to see how self-contained your model is, is to ask +yourself this question: + +"Can someone drop this file into a Trainer without knowing anything about the internals?" + +For example, we couple the optimizer with a model because the majority of models require a specific optimizer with +a specific learning rate scheduler to work well. + +Init +==== +The first place where LightningModules tend to stop being self-contained is in the init. Try to define all the relevant +sensible defaults in the init so that the user doesn't have to guess. + +Here's an example where a user will have to go hunt through files to figure out how to init this LightningModule. + +.. testcode:: + + class LitModel(LightningModule): + def __init__(self, params): + self.lr = params.lr + self.coef_x = params.coef_x + +Models defined as such leave you with many questions; what is coef_x? is it a string? a float? what is the range? etc... + +Instead, be explicit in your init + +.. testcode:: + + class LitModel(LightningModule): + def __init__(self, encoder: nn.Module, coeff_x: float = 0.2, lr: float = 1e-3): + ... + +Now the user doesn't have to guess. Instead they know the value type and the model has a sensible default where the +user can see the value immediately. + + +Method order +============ +The only required methods in the LightningModule are: + +* init +* training_step +* configure_optimizers + +However, if you decide to implement the rest of the optional methods, the recommended order is: + +* model/system definition (init) +* if doing inference, define forward +* training hooks +* validation hooks +* test hooks +* configure_optimizers +* any other hooks + +In practice, this code looks like: + +.. code-block:: python + + class LitModel(pl.LightningModule): + + def __init__(...): + + def forward(...): + + def training_step(...) + + def training_step_end(...) + + def training_epoch_end(...) + + def validation_step(...) + + def validation_step_end(...) + + def validation_epoch_end(...) + + def test_step(...) + + def test_step_end(...) + + def test_epoch_end(...) + + def configure_optimizers(...) + + def any_extra_hook(...) + +Forward vs training_step +======================== +We recommend using forward for inference/predictions and keeping training_step independent + +.. code-block:: python + + def forward(...): + embeddings = self.encoder(x) + + def training_step(...): + x, y = ... + z = self.encoder(x) + pred = self.decoder(z) + ... + +However, when using DataParallel, you will need to call forward manually + +.. code-block:: python + + def training_step(...): + x, y = ... + z = self(x) # < ---------- instead of self.encoder(x) + pred = self.decoder(z) + ... + +-------------- + +**** +Data +**** +These are best practices for handling data. + +Dataloaders +=========== +Lightning uses dataloaders to handle all the data flow through the system. Whenever you structure dataloaders, +make sure to tune the number of workers for maximum efficiency. + +.. warning:: Make sure not to use ddp_spawn with num_workers > 0 or you will bottleneck your code. + +DataModules +=========== +Lightning introduced datamodules. The problem with dataloaders is that sharing full datasets is often still challenging +because all these questions need to be answered: + +* What splits were used? +* How many samples does this dataset have? +* What transforms were used? +* etc... + +It's for this reason that we recommend you use datamodules. This is specially important when collaborating because +it will save your team a lot of time as well. + +All they need to do is drop a datamodule into a lightning trainer and not worry about what was done to the data. + +This is true for both academic and corporate settings where data cleaning and ad-hoc instructions slow down the progress +of iterating through ideas. diff --git a/docs/source/test_set.rst b/docs/source/test_set.rst deleted file mode 100644 index 7dfe40ddaa2daf..00000000000000 --- a/docs/source/test_set.rst +++ /dev/null @@ -1,38 +0,0 @@ -Test set -======== -Lightning forces the user to run the test set separately to make sure it isn't evaluated by mistake - - -Test after fit --------------- -To run the test set after training completes, use this method - -.. code-block:: python - - # run full training - trainer.fit(model) - - # run test set - trainer.test() - -Test pre-trained model ----------------------- -To run the test set on a pre-trained model, use this method. - -.. code-block:: python - - model = MyLightningModule.load_from_metrics( - weights_path='/path/to/pytorch_checkpoint.ckpt', - tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv', - on_gpu=True, - map_location=None - ) - - # init trainer with whatever options - trainer = Trainer(...) - - # test (pass in the model) - trainer.test(model) - -In this case, the options you pass to trainer will be used when -running the test set (ie: 16-bit, dp, ddp, etc...) \ No newline at end of file diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst deleted file mode 100644 index b2fb6e8571e266..00000000000000 --- a/docs/source/tpu.rst +++ /dev/null @@ -1,210 +0,0 @@ -TPU support -=========== - -Lightning supports running on TPUs. At this moment, TPUs are only available -on Google Cloud (GCP). For more information on TPUs -`watch this video `_. - ---------------- - -Live demo ----------- -Check out this `Google Colab `_ to see how to train MNIST on TPUs. - ---------------- - -TPU Terminology ---------------- -A TPU is a Tensor processing unit. Each TPU has 8 cores where each -core is optimized for 128x128 matrix multiplies. In general, a single -TPU is about as fast as 5 V100 GPUs! - -A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores! -You can request a full pod from Google cloud or a "slice" which gives you -some subset of those 2048 cores. - ---------------- - -How to access TPUs -------------------- -To access TPUs there are two main ways. - -1. Using google colab. -2. Using Google Cloud (GCP). - ---------------- - -Colab TPUs ------------ -Colab is like a jupyter notebook with a free GPU or TPU -hosted on GCP. - -To get a TPU on colab, follow these steps: - -1. Go to `https://colab.research.google.com/ `_. - -2. Click "new notebook" (bottom right of pop-up). - -3. Click runtime > change runtime settings. Select Python 3, and hardware accelerator "TPU". - This will give you a TPU with 8 cores. - -4. Next, insert this code into the first cell and execute. - This will install the xla library that interfaces between PyTorch and the TPU. - - .. code-block:: python - - import collections - from datetime import datetime, timedelta - import os - import requests - import threading - - _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server') - VERSION = "xrt==1.15.0" #@param ["xrt==1.15.0", "torch_xla==nightly"] - CONFIG = { - 'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'), - 'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format( - (datetime.today() - timedelta(1)).strftime('%Y%m%d'))), - }[VERSION] - DIST_BUCKET = 'gs://tpu-pytorch/wheels' - TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels) - TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels) - TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels) - - # Update TPU XRT version - def update_server_xrt(): - print('Updating server-side XRT to {} ...'.format(CONFIG.server)) - url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format( - TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0], - XRT_VERSION=CONFIG.server, - ) - print('Done updating server-side XRT: {}'.format(requests.post(url))) - - update = threading.Thread(target=update_server_xrt) - update.start() - - .. code-block:: - - # Install Colab TPU compat PyTorch/TPU wheels and dependencies - !pip uninstall -y torch torchvision - !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" . - !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" . - !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" . - !pip install "$TORCH_WHEEL" - !pip install "$TORCH_XLA_WHEEL" - !pip install "$TORCHVISION_WHEEL" - !sudo apt-get install libomp5 - update.join() - -5. Once the above is done, install PyTorch Lightning (v 0.7.0+). - - .. code-block:: - - !pip install pytorch-lightning - -6. Then set up your LightningModule as normal. - ---------------- - -DistributedSamplers -------------------- -Lightning automatically inserts the correct samplers - no need to do this yourself! - -Usually, with TPUs (and DDP), you would need to define a DistributedSampler to move the right -chunk of data to the appropriate TPU. As mentioned, this is not needed in Lightning - -.. note:: Don't add distributedSamplers. Lightning does this automatically - -If for some reason you still need to, this is how to construct the sampler -for TPU use - -.. code-block:: python - - import torch_xla.core.xla_model as xm - - def train_dataloader(self): - dataset = MNIST( - os.getcwd(), - train=True, - download=True, - transform=transforms.ToTensor() - ) - - # required for TPU support - sampler = None - if use_tpu: - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - shuffle=True - ) - - loader = DataLoader( - dataset, - sampler=sampler, - batch_size=32 - ) - - return loader - -Configure the number of TPU cores in the trainer. You can only choose 1 or 8. -To use a full TPU pod skip to the TPU pod section. - -.. code-block:: python - - import pytorch_lightning as pl - - my_model = MyLightningModule() - trainer = pl.Trainer(num_tpu_cores=8) - trainer.fit(my_model) - -That's it! Your model will train on all 8 TPU cores. - ---------------- - -Distributed Backend with TPU ----------------------------- -The ```distributed_backend``` option used for GPUs does not apply to TPUs. -TPUs work in DDP mode by default (distributing over each core) - ---------------- - -TPU Pod --------- -To train on more than 8 cores, your code actually doesn't change! -All you need to do is submit the following command: - -.. code-block:: bash - - $ python -m torch_xla.distributed.xla_dist - --tpu=$TPU_POD_NAME - --conda-env=torch-xla-nightly - -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data - ---------------- - -16 bit precision ------------------ -Lightning also supports training in 16-bit precision with TPUs. -By default, TPU training will use 32-bit precision. To enable 16-bit, also -set the 16-bit flag. - -.. code-block:: python - - import pytorch_lightning as pl - - my_model = MyLightningModule() - trainer = pl.Trainer(num_tpu_cores=8, precision=16) - trainer.fit(my_model) - -Under the hood the xla library will use the `bfloat16 type `_. - ---------------- - -About XLA ----------- -XLA is the library that interfaces PyTorch with the TPUs. -For more information check out `XLA `_. - -Guide for `troubleshooting XLA `_ diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst deleted file mode 100644 index 19c394db4854b3..00000000000000 --- a/docs/source/trainer.rst +++ /dev/null @@ -1,24 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Trainer -======= -.. automodule:: pytorch_lightning.trainer - :members: fit, test - :noindex: - :exclude-members: - run_pretrain_routine, - _abc_impl, - _Trainer__set_random_port, - _Trainer__set_root_gpu, - _Trainer__init_optimizers, - _Trainer__parse_gpu_ids, - _Trainer__configure_schedulers, - data_parallel, - num_gpus, - slurm_job_id, - tng_tqdm_dic, - training_tqdm_dict, - progress_bar_dict, - init_optimizers, - configure_schedulers diff --git a/docs/source/training_tricks.rst b/docs/source/training_tricks.rst deleted file mode 100644 index e97d7837e0eb4c..00000000000000 --- a/docs/source/training_tricks.rst +++ /dev/null @@ -1,36 +0,0 @@ -.. testsetup:: * - - from pytorch_lightning.trainer.trainer import Trainer - - -Training Tricks -================ -Lightning implements various tricks to help during training - -Accumulate gradients -------------------------------------- -Accumulated gradients runs K small batches of size N before doing a backwards pass. -The effect is a large effective batch size of size KxN. - -.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` - -.. testcode:: - - # DEFAULT (ie: no accumulated grads) - trainer = Trainer(accumulate_grad_batches=1) - - -Gradient Clipping -------------------------------------- -Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient -norm `_ computed over all model parameters together. - -.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` - -.. testcode:: - - # DEFAULT (ie: don't clip) - trainer = Trainer(gradient_clip_val=0) - - # clip gradients with norm above 0.5 - trainer = Trainer(gradient_clip_val=0.5) diff --git a/docs/source/weights_loading.rst b/docs/source/weights_loading.rst deleted file mode 100644 index 64a6950738ef1b..00000000000000 --- a/docs/source/weights_loading.rst +++ /dev/null @@ -1,141 +0,0 @@ -.. testsetup:: * - - import os - from pytorch_lightning.trainer.trainer import Trainer - from pytorch_lightning.core.lightning import LightningModule - - -Saving and loading weights -========================== - -Lightning can automate saving and loading checkpoints. - -Checkpoint saving ------------------ -A Lightning checkpoint has everything needed to restore a training session including: - -- 16-bit scaling factor (apex) -- Current epoch -- Global step -- Model state_dict -- State of all optimizers -- State of all learningRate schedulers -- State of all callbacks -- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace) - -Automatic saving -^^^^^^^^^^^^^^^^ - -Checkpointing is enabled by default to the current working directory. -To change the checkpoint path pass in: - -.. testcode:: - - trainer = Trainer(default_save_path='/your/path/to/save/checkpoints') - -To modify the behavior of checkpointing pass in your own callback. - -.. testcode:: - - from pytorch_lightning.callbacks import ModelCheckpoint - - # DEFAULTS used by the Trainer - checkpoint_callback = ModelCheckpoint( - filepath=os.getcwd(), - save_top_k=True, - verbose=True, - monitor='val_loss', - mode='min', - prefix='' - ) - - trainer = Trainer(checkpoint_callback=checkpoint_callback) - - -Or disable it by passing - -.. testcode:: - - trainer = Trainer(checkpoint_callback=False) - - -The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init. - -.. note:: hparams is a `Namespace `_. - -.. testcode:: - - from argparse import Namespace - - # usually these come from command line args - args = Namespace(learning_rate=0.001) - - # define you module to have hparams as the first arg - # this means your checkpoint will have everything that went into making - # this model (in this case, learning rate) - class MyLightningModule(LightningModule): - - def __init__(self, hparams, *args, **kwargs): - self.hparams = hparams - -Manual saving -^^^^^^^^^^^^^ -You can manually save checkpoints and restore your model from the checkpointed state. - -.. code-block:: python - - model = MyLightningModule(hparams) - trainer.fit(model) - trainer.save_checkpoint("example.ckpt") - new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt") - -Checkpoint Loading ------------------- - -To load a model along with its weights, biases and hyperparameters use following method. - -.. code-block:: python - - model = MyLightingModule.load_from_checkpoint(PATH) - model.eval() - y_hat = model(x) - -The above only works if you used `hparams` in your model definition - -.. testcode:: - - class LitModel(LightningModule): - - def __init__(self, hparams): - self.hparams = hparams - self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim) - -But if you don't and instead pass individual parameters - -.. testcode:: - - class LitModel(LightningModule): - - def __init__(self, in_dim, out_dim): - self.l1 = nn.Linear(in_dim, out_dim) - -you can restore the model like this - -.. code-block:: python - - model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10) - - -Restoring Training State ------------------------- - -If you don't just want to load weights, but instead restore the full training, -do the following: - -.. code-block:: python - - model = LitModel() - trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') - - # automatically restores model, epoch, step, LR schedulers, apex, etc... - trainer.fit(model) diff --git a/environment.yml b/environment.yml index 45e0e3da307b80..a724bcf7ff12a1 100644 --- a/environment.yml +++ b/environment.yml @@ -1,35 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # This is Conda environment file # Usage: `conda env update -f environment.yml` +name: + lightning + channels: - conda-forge - pytorch + - pytorch-test + - pytorch-nightly dependencies: - - python==3.7.6 - - pip==20.0.2 - - tqdm>=4.35.0 + - python>=3.6 + - pip>20.1 - numpy>=1.16.4 - - pytorch>=1.1 - - tensorboard>=1.14 + - pytorch>=1.4 - future>=0.17.1 + - PyYAML>=5.1 + - tqdm>=4.41.0 + - fsspec[http]>=0.8.1 + #- tensorboard>=2.2.0 # not needed, already included in pytorch + + # Optional + #- nvidia-apex # missing for py3.8 + - scikit-learn>=0.20.0 + - matplotlib>=3.1.1 + - omegaconf>=2.0.0 + - torchtext>=0.5 - # For dev and testing - - tox - - coverage - - codecov - - pytest>=3.0.5 - - pytest-cov - - pytest-flake8 - - flake8 - - autopep8 - - check-manifest - - twine==1.13.0 + # Examples + - torchvision>=0.5 - pip: - test-tube>=0.7.5 - mlflow>=1.0.0 - - comet_ml>=1.0.56 + - comet_ml>=3.1.12 - wandb>=0.8.21 - - neptune-client>=0.4.4 - - trains>=0.13.3 + - neptune-client>=0.4.109 + - horovod>=0.21.2 + - onnxruntime>=1.3.0 + - gym>=0.17.0 diff --git a/legacy/README.md b/legacy/README.md new file mode 100644 index 00000000000000..3ce6d15f655680 --- /dev/null +++ b/legacy/README.md @@ -0,0 +1,17 @@ +# Maintaining back-compatibility with come legacy versions + +The aim of this section is set some baselines and workflows/guidelines for maintaining back compatibility with some legacies version of PL + +At this moment we focus on ability running old checkpoints, so the flow here is to create a checkpoint with every release and store it in our public AWS storage and so each CI testing will pull this archive and test loading and resuming training with this model. + +If you want to pull all saved version-checkpoints for local testing/development, call +```bash +wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip +unzip -o checkpoints.zip +``` + +To back populate collection with past version you can use following bash: +```bash +bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4 +zip -r checkpoints.zip checkpoints/ +``` diff --git a/pl_examples/models/__init__.py b/legacy/checkpoints/.gitkeep similarity index 100% rename from pl_examples/models/__init__.py rename to legacy/checkpoints/.gitkeep diff --git a/legacy/generate_checkpoints.sh b/legacy/generate_checkpoints.sh new file mode 100644 index 00000000000000..7726c5b097c5c8 --- /dev/null +++ b/legacy/generate_checkpoints.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Sample call: +# bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4 + +LEGACY_PATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" + +echo $LEGACY_PATH +# install some PT version here so it does not need to reinstalled for each env +pip install virtualenv "torch==1.5" --quiet --no-cache-dir + +ENV_PATH="$LEGACY_PATH/vEnv" + +# iterate over all arguments assuming that each argument is version +for ver in "$@" +do + echo "processing version: $ver" + # mkdir "$LEGACY_PATH/$ver" + + # create local env + echo $ENV_PATH + virtualenv $ENV_PATH --system-site-packages + # activate and install PL version + source "$ENV_PATH/bin/activate" + # there are problem to load ckpt in older versions since they are saved the newer versions + pip install "pytorch_lightning==$ver" "torch==1.3" --quiet --no-cache-dir + + python --version + pip --version + pip list | grep torch + + python "$LEGACY_PATH/zero_training.py" + cp "$LEGACY_PATH/zero_training.py" ${LEGACY_PATH}/checkpoints/${ver} + + mv ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs/version_0/checkpoints/*.ckpt ${LEGACY_PATH}/checkpoints/${ver}/ + rm -rf ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs + + deactivate + # clear env + rm -rf $ENV_PATH + +done diff --git a/legacy/zero_training.py b/legacy/zero_training.py new file mode 100644 index 00000000000000..044fa9bc2b971a --- /dev/null +++ b/legacy/zero_training.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch +from torch.utils.data import Dataset + +import pytorch_lightning as pl + +PATH_LEGACY = os.path.dirname(__file__) + + +class RandomDataset(Dataset): + + def __init__(self, size, length: int = 100): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class DummyModel(pl.LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def _loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def _step(self, batch, batch_idx): + output = self.layer(batch) + loss = self._loss(batch, output) + # return {'loss': loss} # used for PL<1.0 + return loss # used for PL >= 1.0 + + def training_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + self._step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + self._step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +def main_train(dir_path, max_epochs: int = 5): + + trainer = pl.Trainer( + default_root_dir=dir_path, + checkpoint_callback=True, + max_epochs=max_epochs, + ) + + model = DummyModel() + trainer.fit(model) + + +if __name__ == '__main__': + path_dir = os.path.join(PATH_LEGACY, 'checkpoints', str(pl.__version__)) + main_train(path_dir) diff --git a/notebooks/01-mnist-hello-world.ipynb b/notebooks/01-mnist-hello-world.ipynb new file mode 100644 index 00000000000000..cdcc6cd28a486d --- /dev/null +++ b/notebooks/01-mnist-hello-world.ipynb @@ -0,0 +1,448 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "i7XbLCXGkll9" + }, + "source": [ + "# Introduction to Pytorch Lightning ⚡\n", + "\n", + "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2LODD6w9ixlT" + }, + "source": [ + "### Setup \n", + "Lightning is easy to install. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "zK7-Gg69kMnG" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "w4_TYnt_keJi" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import MNIST\n", + "from torchvision import transforms\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EHpyMPKFkVbZ" + }, + "source": [ + "## Simplest example\n", + "\n", + "Here's the simplest most minimal example with just a training loop (no validation, no testing).\n", + "\n", + "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "V7ELesz1kVQo" + }, + "outputs": [], + "source": [ + "class MNISTModel(pl.LightningModule):\n", + "\n", + " def __init__(self):\n", + " super(MNISTModel, self).__init__()\n", + " self.l1 = torch.nn.Linear(28 * 28, 10)\n", + "\n", + " def forward(self, x):\n", + " return torch.relu(self.l1(x.view(x.size(0), -1)))\n", + "\n", + " def training_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=0.02)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hIrtHg-Dv8TJ" + }, + "source": [ + "By using the `Trainer` you automatically get:\n", + "1. Tensorboard logging\n", + "2. Model checkpointing\n", + "3. Training and validation loop\n", + "4. early-stopping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4Dk6Ykv8lI7X" + }, + "outputs": [], + "source": [ + "# Init our model\n", + "mnist_model = MNISTModel()\n", + "\n", + "# Init DataLoader from MNIST Dataset\n", + "train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", + "train_loader = DataLoader(train_ds, batch_size=32)\n", + "\n", + "# Initialize a trainer\n", + "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", + "\n", + "# Train the model ⚡\n", + "trainer.fit(mnist_model, train_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KNpOoBeIjscS" + }, + "source": [ + "## A more complete MNIST Lightning Module Example\n", + "\n", + "That wasn't so hard was it?\n", + "\n", + "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", + "\n", + "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", + "\n", + "---\n", + "\n", + "### Note what the following built-in functions are doing:\n", + "\n", + "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n", + " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", + " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", + "\n", + "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#setup) ⚙️\n", + " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n", + " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", + " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", + " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", + "\n", + "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#data-hooks) ♻️\n", + " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4DNItffri95Q" + }, + "outputs": [], + "source": [ + "class LitMNIST(pl.LightningModule):\n", + " \n", + " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " # Set our init args as class attributes\n", + " self.data_dir = data_dir\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " # Hardcode some dataset specific attributes\n", + " self.num_classes = 10\n", + " self.dims = (1, 28, 28)\n", + " channels, width, height = self.dims\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # Define PyTorch model\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, self.num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + "\n", + " # Calling self.log will surface up scalars for you in TensorBoard\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " # Here we just reuse the validation_step for testing\n", + " return self.validation_step(batch, batch_idx)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer\n", + "\n", + " ####################\n", + " # DATA RELATED HOOKS\n", + " ####################\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Mb0U5Rk2kLBy" + }, + "outputs": [], + "source": [ + "model = LitMNIST()\n", + "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nht8AvMptY6I" + }, + "source": [ + "### Testing\n", + "\n", + "To test a model, call `trainer.test(model)`.\n", + "\n", + "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PA151FkLtprO" + }, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "T3-3lbbNtr5T" + }, + "source": [ + "### Bonus Tip\n", + "\n", + "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IFBwCbLet2r6" + }, + "outputs": [], + "source": [ + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8TRyS5CCt3n9" + }, + "source": [ + "In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wizS-QiLuAYo" + }, + "outputs": [], + "source": [ + "# Start tensorboard.\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir lightning_logs/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ", + "collapsed_sections": [], + "include_colab_link": true, + "name": "01-mnist-hello-world.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/02-datamodules.ipynb b/notebooks/02-datamodules.ipynb new file mode 100644 index 00000000000000..5438cd5dc5c2f4 --- /dev/null +++ b/notebooks/02-datamodules.ipynb @@ -0,0 +1,588 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2O5r7QvP8-rt" + }, + "source": [ + "# PyTorch Lightning DataModules ⚡\n", + "\n", + "With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.\n", + "\n", + "This notebook will walk you through how to start using Datamodules.\n", + "\n", + "The most up to date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html).\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6RYMhmfA9ATN" + }, + "source": [ + "### Setup\n", + "Lightning is easy to install. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lj2zD-wsbvGr" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8g2mbvy-9xDI" + }, + "source": [ + "# Introduction\n", + "\n", + "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "eg-xDlmDdAwy" + }, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import random_split, DataLoader\n", + "\n", + "# Note - you must have torchvision installed for this example\n", + "from torchvision.datasets import MNIST, CIFAR10\n", + "from torchvision import transforms" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DzgY7wi88UuG" + }, + "source": [ + "## Defining the LitMNISTModel\n", + "\n", + "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.\n", + "\n", + "Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢\n", + "\n", + "This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IQkW8_FF5nU2" + }, + "outputs": [], + "source": [ + "class LitMNIST(pl.LightningModule):\n", + " \n", + " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " # We hardcode dataset specific stuff here.\n", + " self.data_dir = data_dir\n", + " self.num_classes = 10\n", + " self.dims = (1, 28, 28)\n", + " channels, width, height = self.dims\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " # Build model\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, self.num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer\n", + "\n", + " ####################\n", + " # DATA RELATED HOOKS\n", + " ####################\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K7sg9KQd-QIO" + }, + "source": [ + "## Training the ListMNIST Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "QxDNDaus6byD" + }, + "outputs": [], + "source": [ + "model = LitMNIST()\n", + "trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "dY8d6GxmB0YU" + }, + "source": [ + "# Using DataModules\n", + "\n", + "DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eJeT5bW081wn" + }, + "source": [ + "## Defining The MNISTDataModule\n", + "\n", + "Let's go over each function in the class below and talk about what they're doing:\n", + "\n", + "1. ```__init__```\n", + " - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n", + " - Defines a transform that will be applied across train, val, and test dataset splits.\n", + " - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n", + "\n", + "\n", + "2. ```prepare_data```\n", + " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", + " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", + "\n", + "3. ```setup```\n", + " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n", + " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", + " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n", + " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", + "\n", + "\n", + "4. ```x_dataloader```\n", + " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DfGKyGwG_X9v" + }, + "outputs": [], + "source": [ + "class MNISTDataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './'):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # self.dims is returned when you call dm.size()\n", + " # Setting default dims here because we know them.\n", + " # Could optionally be assigned dynamically in dm.setup()\n", + " self.dims = (1, 28, 28)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "H2Yoj-9M9dS7" + }, + "source": [ + "## Defining the dataset agnostic `LitModel`\n", + "\n", + "Below, we define the same model as the `LitMNIST` model we made earlier. \n", + "\n", + "However, this time our model has the freedom to use any input data that we'd like 🔥." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PM2IISuOBDIu" + }, + "outputs": [], + "source": [ + "class LitModel(pl.LightningModule):\n", + " \n", + " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " # We take in input dimensions as parameters and use those to dynamically build model.\n", + " self.channels = channels\n", + " self.width = width\n", + " self.height = height\n", + " self.num_classes = num_classes\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + "\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "G4Z5olPe-xEo" + }, + "source": [ + "## Training the `LitModel` using the `MNISTDataModule`\n", + "\n", + "Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kV48vP_9mEli" + }, + "outputs": [], + "source": [ + "# Init DataModule\n", + "dm = MNISTDataModule()\n", + "# Init model from datamodule's attributes\n", + "model = LitModel(*dm.size(), dm.num_classes)\n", + "# Init trainer\n", + "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n", + "# Pass the datamodule as arg to trainer.fit to override model hooks :)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WNxrugIGRRv5" + }, + "source": [ + "## Defining the CIFAR10 DataModule\n", + "\n", + "Lets prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1tkaYLU7RT5P" + }, + "outputs": [], + "source": [ + "class CIFAR10DataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './'):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + " ])\n", + "\n", + " self.dims = (3, 32, 32)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " CIFAR10(self.data_dir, train=True, download=True)\n", + " CIFAR10(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", + " self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.cifar_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.cifar_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.cifar_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BrXxf3oX_gsZ" + }, + "source": [ + "## Training the `LitModel` using the `CIFAR10DataModule`\n", + "\n", + "Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n", + "\n", + "The point here is that we can see that our `LitModel` has no problem using a different datamodule as its input data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "sd-SbWi_krdj" + }, + "outputs": [], + "source": [ + "dm = CIFAR10DataModule()\n", + "model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n", + "trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "include_colab_link": true, + "name": "02-datamodules.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/03-basic-gan.ipynb b/notebooks/03-basic-gan.ipynb new file mode 100644 index 00000000000000..5cee735842a085 --- /dev/null +++ b/notebooks/03-basic-gan.ipynb @@ -0,0 +1,472 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "J37PBnE_x7IW" + }, + "source": [ + "# PyTorch Lightning Basic GAN Tutorial ⚡\n", + "\n", + "How to train a GAN!\n", + "\n", + "Main takeaways:\n", + "1. Generator and discriminator are arbitrary PyTorch modules.\n", + "2. training_step does both the generator and discriminator training.\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kg2MKpRmybht" + }, + "source": [ + "### Setup\n", + "Lightning is easy to install. Simply `pip install pytorch-lightning`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LfrJLKPFyhsK" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "BjEPuiVLyanw" + }, + "outputs": [], + "source": [ + "import os\n", + "from argparse import ArgumentParser\n", + "from collections import OrderedDict\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import MNIST\n", + "\n", + "import pytorch_lightning as pl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "OuXJzr4G2uHV" + }, + "source": [ + "### MNIST DataModule\n", + "\n", + "Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DOY_nHu328g7" + }, + "outputs": [], + "source": [ + "class MNISTDataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.batch_size = batch_size\n", + " self.num_workers = num_workers\n", + "\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # self.dims is returned when you call dm.size()\n", + " # Setting default dims here because we know them.\n", + " # Could optionally be assigned dynamically in dm.setup()\n", + " self.dims = (1, 28, 28)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tW3c0QrQyF9P" + }, + "source": [ + "### A. Generator" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0E2QDjl5yWtz" + }, + "outputs": [], + "source": [ + "class Generator(nn.Module):\n", + " def __init__(self, latent_dim, img_shape):\n", + " super().__init__()\n", + " self.img_shape = img_shape\n", + "\n", + " def block(in_feat, out_feat, normalize=True):\n", + " layers = [nn.Linear(in_feat, out_feat)]\n", + " if normalize:\n", + " layers.append(nn.BatchNorm1d(out_feat, 0.8))\n", + " layers.append(nn.LeakyReLU(0.2, inplace=True))\n", + " return layers\n", + "\n", + " self.model = nn.Sequential(\n", + " *block(latent_dim, 128, normalize=False),\n", + " *block(128, 256),\n", + " *block(256, 512),\n", + " *block(512, 1024),\n", + " nn.Linear(1024, int(np.prod(img_shape))),\n", + " nn.Tanh()\n", + " )\n", + "\n", + " def forward(self, z):\n", + " img = self.model(z)\n", + " img = img.view(img.size(0), *self.img_shape)\n", + " return img" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uyrltsGvyaI3" + }, + "source": [ + "### B. Discriminator" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ed3MR3vnyxyW" + }, + "outputs": [], + "source": [ + "class Discriminator(nn.Module):\n", + " def __init__(self, img_shape):\n", + " super().__init__()\n", + "\n", + " self.model = nn.Sequential(\n", + " nn.Linear(int(np.prod(img_shape)), 512),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Linear(512, 256),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Linear(256, 1),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", + " def forward(self, img):\n", + " img_flat = img.view(img.size(0), -1)\n", + " validity = self.model(img_flat)\n", + "\n", + " return validity" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BwUMom3ryySK" + }, + "source": [ + "### C. GAN\n", + "\n", + "#### A couple of cool features to check out in this example...\n", + "\n", + " - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n", + " - Lightning will put your dataloader data on the right device automatically\n", + " - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n", + " - `type_as` is the way we recommend to do this.\n", + " - This example shows how to use multiple dataloaders in your `LightningModule`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3vKszYf6y1Vv" + }, + "outputs": [], + "source": [ + " class GAN(pl.LightningModule):\n", + "\n", + " def __init__(\n", + " self,\n", + " channels,\n", + " width,\n", + " height,\n", + " latent_dim: int = 100,\n", + " lr: float = 0.0002,\n", + " b1: float = 0.5,\n", + " b2: float = 0.999,\n", + " batch_size: int = 64,\n", + " **kwargs\n", + " ):\n", + " super().__init__()\n", + " self.save_hyperparameters()\n", + "\n", + " # networks\n", + " data_shape = (channels, width, height)\n", + " self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n", + " self.discriminator = Discriminator(img_shape=data_shape)\n", + "\n", + " self.validation_z = torch.randn(8, self.hparams.latent_dim)\n", + "\n", + " self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n", + "\n", + " def forward(self, z):\n", + " return self.generator(z)\n", + "\n", + " def adversarial_loss(self, y_hat, y):\n", + " return F.binary_cross_entropy(y_hat, y)\n", + "\n", + " def training_step(self, batch, batch_idx, optimizer_idx):\n", + " imgs, _ = batch\n", + "\n", + " # sample noise\n", + " z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n", + " z = z.type_as(imgs)\n", + "\n", + " # train generator\n", + " if optimizer_idx == 0:\n", + "\n", + " # generate images\n", + " self.generated_imgs = self(z)\n", + "\n", + " # log sampled images\n", + " sample_imgs = self.generated_imgs[:6]\n", + " grid = torchvision.utils.make_grid(sample_imgs)\n", + " self.logger.experiment.add_image('generated_images', grid, 0)\n", + "\n", + " # ground truth result (ie: all fake)\n", + " # put on GPU because we created this tensor inside training_loop\n", + " valid = torch.ones(imgs.size(0), 1)\n", + " valid = valid.type_as(imgs)\n", + "\n", + " # adversarial loss is binary cross-entropy\n", + " g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n", + " tqdm_dict = {'g_loss': g_loss}\n", + " output = OrderedDict({\n", + " 'loss': g_loss,\n", + " 'progress_bar': tqdm_dict,\n", + " 'log': tqdm_dict\n", + " })\n", + " return output\n", + "\n", + " # train discriminator\n", + " if optimizer_idx == 1:\n", + " # Measure discriminator's ability to classify real from generated samples\n", + "\n", + " # how well can it label as real?\n", + " valid = torch.ones(imgs.size(0), 1)\n", + " valid = valid.type_as(imgs)\n", + "\n", + " real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n", + "\n", + " # how well can it label as fake?\n", + " fake = torch.zeros(imgs.size(0), 1)\n", + " fake = fake.type_as(imgs)\n", + "\n", + " fake_loss = self.adversarial_loss(\n", + " self.discriminator(self(z).detach()), fake)\n", + "\n", + " # discriminator loss is the average of these\n", + " d_loss = (real_loss + fake_loss) / 2\n", + " tqdm_dict = {'d_loss': d_loss}\n", + " output = OrderedDict({\n", + " 'loss': d_loss,\n", + " 'progress_bar': tqdm_dict,\n", + " 'log': tqdm_dict\n", + " })\n", + " return output\n", + "\n", + " def configure_optimizers(self):\n", + " lr = self.hparams.lr\n", + " b1 = self.hparams.b1\n", + " b2 = self.hparams.b2\n", + "\n", + " opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n", + " opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n", + " return [opt_g, opt_d], []\n", + "\n", + " def on_epoch_end(self):\n", + " z = self.validation_z.type_as(self.generator.model[0].weight)\n", + "\n", + " # log sampled images\n", + " sample_imgs = self(z)\n", + " grid = torchvision.utils.make_grid(sample_imgs)\n", + " self.logger.experiment.add_image('generated_images', grid, self.current_epoch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ey5FmJPnzm_E" + }, + "outputs": [], + "source": [ + "dm = MNISTDataModule()\n", + "model = GAN(*dm.size())\n", + "trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MlECc7cHzolp" + }, + "outputs": [], + "source": [ + "# Start tensorboard.\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir lightning_logs/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "include_colab_link": true, + "name": "03-basic-gan.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb new file mode 100644 index 00000000000000..957255969f6088 --- /dev/null +++ b/notebooks/04-transformers-text-classification.ipynb @@ -0,0 +1,599 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8ag5ANQPJ_j9" + }, + "source": [ + "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n", + "\n", + "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n", + "\n", + "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Ask a question on [GitHub Discussions](https://github.com/PyTorchLightning/pytorch-lightning/discussions/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", + "\n", + " - [HuggingFace datasets](https://github.com/huggingface/datasets)\n", + " - [HuggingFace transformers](https://github.com/huggingface/transformers)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fqlsVTj7McZ3" + }, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OIhHrRL-MnKK" + }, + "outputs": [], + "source": [ + "!pip install pytorch-lightning datasets transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "6yuQT_ZQMpCg" + }, + "outputs": [], + "source": [ + "from argparse import ArgumentParser\n", + "from datetime import datetime\n", + "from typing import Optional\n", + "\n", + "import datasets\n", + "import numpy as np\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from transformers import (\n", + " AdamW,\n", + " AutoModelForSequenceClassification,\n", + " AutoConfig,\n", + " AutoTokenizer,\n", + " get_linear_schedule_with_warmup,\n", + " glue_compute_metrics\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9ORJfiuiNZ_N" + }, + "source": [ + "## GLUE DataModule" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jW9xQhZxMz1G" + }, + "outputs": [], + "source": [ + "class GLUEDataModule(pl.LightningDataModule):\n", + "\n", + " task_text_field_map = {\n", + " 'cola': ['sentence'],\n", + " 'sst2': ['sentence'],\n", + " 'mrpc': ['sentence1', 'sentence2'],\n", + " 'qqp': ['question1', 'question2'],\n", + " 'stsb': ['sentence1', 'sentence2'],\n", + " 'mnli': ['premise', 'hypothesis'],\n", + " 'qnli': ['question', 'sentence'],\n", + " 'rte': ['sentence1', 'sentence2'],\n", + " 'wnli': ['sentence1', 'sentence2'],\n", + " 'ax': ['premise', 'hypothesis']\n", + " }\n", + "\n", + " glue_task_num_labels = {\n", + " 'cola': 2,\n", + " 'sst2': 2,\n", + " 'mrpc': 2,\n", + " 'qqp': 2,\n", + " 'stsb': 1,\n", + " 'mnli': 3,\n", + " 'qnli': 2,\n", + " 'rte': 2,\n", + " 'wnli': 2,\n", + " 'ax': 3\n", + " }\n", + "\n", + " loader_columns = [\n", + " 'datasets_idx',\n", + " 'input_ids',\n", + " 'token_type_ids',\n", + " 'attention_mask',\n", + " 'start_positions',\n", + " 'end_positions',\n", + " 'labels'\n", + " ]\n", + "\n", + " def __init__(\n", + " self,\n", + " model_name_or_path: str,\n", + " task_name: str ='mrpc',\n", + " max_seq_length: int = 128,\n", + " train_batch_size: int = 32,\n", + " eval_batch_size: int = 32,\n", + " **kwargs\n", + " ):\n", + " super().__init__()\n", + " self.model_name_or_path = model_name_or_path\n", + " self.task_name = task_name\n", + " self.max_seq_length = max_seq_length\n", + " self.train_batch_size = train_batch_size\n", + " self.eval_batch_size = eval_batch_size\n", + "\n", + " self.text_fields = self.task_text_field_map[task_name]\n", + " self.num_labels = self.glue_task_num_labels[task_name]\n", + " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", + "\n", + " def setup(self, stage):\n", + " self.dataset = datasets.load_dataset('glue', self.task_name)\n", + "\n", + " for split in self.dataset.keys():\n", + " self.dataset[split] = self.dataset[split].map(\n", + " self.convert_to_features,\n", + " batched=True,\n", + " remove_columns=['label'],\n", + " )\n", + " self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n", + " self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n", + "\n", + " self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n", + "\n", + " def prepare_data(self):\n", + " datasets.load_dataset('glue', self.task_name)\n", + " AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", + " \n", + " def train_dataloader(self):\n", + " return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n", + " \n", + " def val_dataloader(self):\n", + " if len(self.eval_splits) == 1:\n", + " return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n", + " elif len(self.eval_splits) > 1:\n", + " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", + "\n", + " def test_dataloader(self):\n", + " if len(self.eval_splits) == 1:\n", + " return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n", + " elif len(self.eval_splits) > 1:\n", + " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", + "\n", + " def convert_to_features(self, example_batch, indices=None):\n", + "\n", + " # Either encode single sentence or sentence pairs\n", + " if len(self.text_fields) > 1:\n", + " texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n", + " else:\n", + " texts_or_text_pairs = example_batch[self.text_fields[0]]\n", + "\n", + " # Tokenize the text/text pairs\n", + " features = self.tokenizer.batch_encode_plus(\n", + " texts_or_text_pairs,\n", + " max_length=self.max_seq_length,\n", + " pad_to_max_length=True,\n", + " truncation=True\n", + " )\n", + "\n", + " # Rename label to labels to make it easier to pass to model forward\n", + " features['labels'] = example_batch['label']\n", + "\n", + " return features" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jQC3a6KuOpX3" + }, + "source": [ + "#### You could use this datamodule with standalone PyTorch if you wanted..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JCMH3IAsNffF" + }, + "outputs": [], + "source": [ + "dm = GLUEDataModule('distilbert-base-uncased')\n", + "dm.prepare_data()\n", + "dm.setup('fit')\n", + "next(iter(dm.train_dataloader()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "l9fQ_67BO2Lj" + }, + "source": [ + "## GLUE Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gtn5YGKYO65B" + }, + "outputs": [], + "source": [ + "class GLUETransformer(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " model_name_or_path: str,\n", + " num_labels: int,\n", + " learning_rate: float = 2e-5,\n", + " adam_epsilon: float = 1e-8,\n", + " warmup_steps: int = 0,\n", + " weight_decay: float = 0.0,\n", + " train_batch_size: int = 32,\n", + " eval_batch_size: int = 32,\n", + " eval_splits: Optional[list] = None,\n", + " **kwargs\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters()\n", + "\n", + " self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n", + " self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n", + " self.metric = datasets.load_metric(\n", + " 'glue',\n", + " self.hparams.task_name,\n", + " experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n", + " )\n", + "\n", + " def forward(self, **inputs):\n", + " return self.model(**inputs)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " outputs = self(**batch)\n", + " loss = outputs[0]\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx, dataloader_idx=0):\n", + " outputs = self(**batch)\n", + " val_loss, logits = outputs[:2]\n", + "\n", + " if self.hparams.num_labels >= 1:\n", + " preds = torch.argmax(logits, axis=1)\n", + " elif self.hparams.num_labels == 1:\n", + " preds = logits.squeeze()\n", + "\n", + " labels = batch[\"labels\"]\n", + "\n", + " return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n", + "\n", + " def validation_epoch_end(self, outputs):\n", + " if self.hparams.task_name == 'mnli':\n", + " for i, output in enumerate(outputs):\n", + " # matched or mismatched\n", + " split = self.hparams.eval_splits[i].split('_')[-1]\n", + " preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n", + " labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n", + " loss = torch.stack([x['loss'] for x in output]).mean()\n", + " self.log(f'val_loss_{split}', loss, prog_bar=True)\n", + " split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n", + " self.log_dict(split_metrics, prog_bar=True)\n", + " return loss\n", + "\n", + " preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n", + " labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n", + " loss = torch.stack([x['loss'] for x in outputs]).mean()\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n", + " return loss\n", + "\n", + " def setup(self, stage):\n", + " if stage == 'fit':\n", + " # Get dataloader by calling it - train_dataloader() is called after setup() by default\n", + " train_loader = self.train_dataloader()\n", + "\n", + " # Calculate total steps\n", + " self.total_steps = (\n", + " (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n", + " // self.hparams.accumulate_grad_batches\n", + " * float(self.hparams.max_epochs)\n", + " )\n", + "\n", + " def configure_optimizers(self):\n", + " \"Prepare optimizer and schedule (linear warmup and decay)\"\n", + " model = self.model\n", + " no_decay = [\"bias\", \"LayerNorm.weight\"]\n", + " optimizer_grouped_parameters = [\n", + " {\n", + " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": self.hparams.weight_decay,\n", + " },\n", + " {\n", + " \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": 0.0,\n", + " },\n", + " ]\n", + " optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n", + "\n", + " scheduler = get_linear_schedule_with_warmup(\n", + " optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n", + " )\n", + " scheduler = {\n", + " 'scheduler': scheduler,\n", + " 'interval': 'step',\n", + " 'frequency': 1\n", + " }\n", + " return [optimizer], [scheduler]\n", + "\n", + " @staticmethod\n", + " def add_model_specific_args(parent_parser):\n", + " parser = parent_parser.add_argument_group(\"GLUETransformer\")", + " parser = ArgumentParser(parents=[parent_parser], add_help=False)\n", + " parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n", + " parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n", + " parser.add_argument(\"--warmup_steps\", default=0, type=int)\n", + " parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n", + " return parent_parser" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ha-NdIP_xbd3" + }, + "source": [ + "### ⚡ Quick Tip \n", + " - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3dEHnl3RPlAR" + }, + "outputs": [], + "source": [ + "def parse_args(args=None):\n", + " parser = ArgumentParser()\n", + " parser = pl.Trainer.add_argparse_args(parser)\n", + " parser = GLUEDataModule.add_argparse_args(parser)\n", + " parser = GLUETransformer.add_model_specific_args(parser)\n", + " parser.add_argument('--seed', type=int, default=42)\n", + " return parser.parse_args(args)\n", + "\n", + "\n", + "def main(args):\n", + " pl.seed_everything(args.seed)\n", + " dm = GLUEDataModule.from_argparse_args(args)\n", + " dm.prepare_data()\n", + " dm.setup('fit')\n", + " model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n", + " trainer = pl.Trainer.from_argparse_args(args)\n", + " return dm, model, trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PkuLaeec3sJ-" + }, + "source": [ + "# Training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QSpueK5UPsN7" + }, + "source": [ + "## CoLA\n", + "\n", + "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NJnFmtpnPu0Y" + }, + "outputs": [], + "source": [ + "mocked_args = \"\"\"\n", + " --model_name_or_path albert-base-v2\n", + " --task_name cola\n", + " --max_epochs 3\n", + " --gpus 1\"\"\".split()\n", + "\n", + "args = parse_args(mocked_args)\n", + "dm, model, trainer = main(args)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_MrNsTnqdz4z" + }, + "source": [ + "## MRPC\n", + "\n", + "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LBwRxg9Cb3d-" + }, + "outputs": [], + "source": [ + "mocked_args = \"\"\"\n", + " --model_name_or_path distilbert-base-cased\n", + " --task_name mrpc\n", + " --max_epochs 3\n", + " --gpus 1\"\"\".split()\n", + "\n", + "args = parse_args(mocked_args)\n", + "dm, model, trainer = main(args)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iZhbn0HzfdCu" + }, + "source": [ + "## MNLI\n", + "\n", + " - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n", + "\n", + " - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n", + "\n", + "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "AvsZMOggfcWW" + }, + "outputs": [], + "source": [ + "mocked_args = \"\"\"\n", + " --model_name_or_path distilbert-base-uncased\n", + " --task_name mnli\n", + " --max_epochs 1\n", + " --gpus 1\n", + " --limit_train_batches 10\n", + " --progress_bar_refresh_rate 20\"\"\".split()\n", + "\n", + "args = parse_args(mocked_args)\n", + "dm, model, trainer = main(args)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "04-transformers-text-classification.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/05-trainer-flags-overview.ipynb b/notebooks/05-trainer-flags-overview.ipynb new file mode 100644 index 00000000000000..eb604be5e15597 --- /dev/null +++ b/notebooks/05-trainer-flags-overview.ipynb @@ -0,0 +1,2926 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "goRmGIRI5cfC" + }, + "source": [ + "# Introduction to Lightning Flags ⚡🚩\n", + "\n", + "In this notebook, we'll go over the flags available in the `Trainer` object. Note that not everything will work in the Colab environment (multi-gpu, etc). This notebook accompanies the Trainer videos we'll be putting out.\n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jKj5lgdr5j48" + }, + "source": [ + "--- \n", + "### Setup \n", + "First thing first, we need to install Lightning. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UGjilEHk4vb7" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zaVUShmQ5n8Y" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from argparse import ArgumentParser\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.data import random_split\n", + "from torchvision.datasets import MNIST\n", + "from torchvision import transforms\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy\n", + "\n", + "from torchvision.datasets.mnist import MNIST\n", + "from torchvision import transforms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6tgkS8IYZwY_" + }, + "outputs": [], + "source": [ + "# ------------\n", + "# data\n", + "# ------------\n", + "pl.seed_everything(1234)\n", + "batch_size = 32\n", + "\n", + "# Init DataLoader from MNIST Dataset\n", + "\n", + "dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", + "mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())\n", + "mnist_train, mnist_val = random_split(dataset, [55000, 5000])\n", + "\n", + "train_loader = DataLoader(mnist_train, batch_size=batch_size)\n", + "val_loader = DataLoader(mnist_val, batch_size=batch_size)\n", + "test_loader = DataLoader(mnist_test, batch_size=batch_size)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gEulmrbxwaYL" + }, + "source": [ + "### Simple AutoEncoder Model\n", + "\n", + "Were gonna define a simple Lightning model so we can play with all the settings of the Lightning Trainer.\n", + "\n", + "LightningModule is simply pure Pytorch reorganized into hooks, that represents all the steps in the training process.\n", + "\n", + "You can use LightningModule hooks to control every part of your model, but for the purpose of this video we will use a very simple MNIST classifier, a model that takes 28*28 grayscale images of hand written images, and can predict the digit between 0-9.\n", + "\n", + "The LightningModule can encompass a single model, like an image classifier, or a deep learning system composed of multiple models, like this auto encoder that contains an encoder and a decoder.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x-34xKCI40yW" + }, + "outputs": [], + "source": [ + "class LitAutoEncoder(pl.LightningModule):\n", + "\n", + " def __init__(self, batch_size=32, lr=1e-3):\n", + " super().__init__()\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(28 * 28, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 3)\n", + " )\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(3, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 28 * 28)\n", + " )\n", + " self.batch_size=batch_size\n", + " self.learning_rate=lr\n", + "\n", + " def forward(self, x):\n", + " # in lightning, forward defines the prediction/inference actions\n", + " embedding = self.encoder(x)\n", + " return embedding\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " x = x.view(x.size(0), -1)\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " loss = F.mse_loss(x_hat, x)\n", + " self.log('train_loss', loss)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " x = x.view(x.size(0), -1)\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " loss = F.mse_loss(x_hat, x)\n", + " self.log('val_loss', loss)\n", + " \n", + " def test_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " x = x.view(x.size(0), -1)\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " loss = F.mse_loss(x_hat, x)\n", + " self.log('test_loss', loss)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VbxcRCrxiYly" + }, + "source": [ + "You'll notice the LightningModule doesn't have epoch and batch loops, we're not calling model.train() and model.eval(), and no mentions of CUDA or hardware. That's because it is all automated by the Lightning Trainer. All the engineering boilerplate is automated by the trainer: \n", + "\n", + "* Training loops\n", + "* Evaluation and test loops\n", + "* Calling model.train(), model.eval(), no_grad at the right time\n", + "* CUDA or to_device calls\n", + "\n", + "It also allows you to train your models on different hardware like GPUs and TPUs without changing your code!\n", + "\n", + "\n", + "### To use the lightning trainer simply:\n", + "\n", + "1. init your LightningModule and datasets\n", + "\n", + "2. init lightning trainer\n", + "\n", + "3. call trainer.fit\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HOk9c4_35FKg" + }, + "outputs": [], + "source": [ + "#####################\n", + "# 1. Init Model\n", + "#####################\n", + "\n", + "model = LitAutoEncoder()\n", + "\n", + "#####################\n", + "# 2. Init Trainer\n", + "#####################\n", + "\n", + "# these 2 flags are explained in the later sections...but for short explanation:\n", + "# - progress_bar_refresh_rate: limits refresh rate of tqdm progress bar so Colab doesn't freak out\n", + "# - max_epochs: only run 2 epochs instead of default of 1000\n", + "trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=2)\n", + "\n", + "#####################\n", + "# 3. Train\n", + "#####################\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3meDako-Qa_6" + }, + "source": [ + "Our model is training just like that, using the Lightning defaults. The beauty of Lightning is that everything is easily configurable.\n", + "In our next videos were going to show you all the ways you can control your Trainer to do things like controlling your training, validation and test loops, running on GPUs and TPUs, checkpointing, early stopping, and a lot more.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z_Wry2MckQkI" + }, + "source": [ + "# Training loop and eval loop Flags" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0MkI1xB2vsLj" + }, + "source": [ + "\n", + "To really scale up your networks, you can use accelerators like GPUs. GPUs or Graphical Processing Units, parallelize matrix multiplications which enable speed ups of at least 100x over training on CPUs.\n", + "\n", + "Let's say you have a machine with 8 GPUs on it. You can set this flag to 1, 4, or 8 GPUs and lightning will automatically distribute your training for you.\n", + "\n", + "```\n", + "trainer = pl.Trainer(gpus=1)\n", + "```\n", + "\n", + "---------\n", + "\n", + "Lightning makes your code hardware agnostic... This means, you can switch between CPUs, GPUs without code changes.\n", + "\n", + "However, it requires forming good PyTorch habits:\n", + "\n", + "1. First, remove the .cuda() or .to() calls in your code.\n", + "2. Second, when you initialize a new tensor, set the device=self.device in the call since every lightningModule knows what gpu index or TPU core it is on.\n", + "\n", + "You can also use type_as and or you can register the tensor as a buffer in your module’s __init__ method with register_buffer().\n", + "\n", + "```\n", + "# before lightning\n", + "def forward(self, x):\n", + " z = torch.Tensor(2, 3)\n", + " z = z.cuda(0)\n", + "\n", + "# with lightning\n", + "def forward(self, x):\n", + " z = torch.Tensor(2, 3)\n", + " z = z.type_as(x, device=self.device)\n", + "```\n", + "\n", + "\n", + "```\n", + "class LitModel(LightningModule):\n", + "\n", + " def __init__(self):\n", + " ...\n", + " self.register_buffer(\"sigma\", torch.eye(3))\n", + " # you can now access self.sigma anywhere in your module\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hw6jJhhjvlSL" + }, + "source": [ + "Lightning Trainer automates all the engineering boilerplate like iterating over epochs and batches, training eval and test loops, CUDA and to(device) calls, calling model.train and model.eval.\n", + "\n", + "You still have full control over the loops, by using the following trainer flags:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pT5-ETH9eUg6" + }, + "source": [ + "## Calling validation steps\n", + "Sometimes, training an epoch may be pretty fast, like minutes per epoch. In this case, you might not need to validate on every epoch. Instead, you can actually validate after a few epochs.\n", + "\n", + "Use `check_val_every_n_epoch` flag to control the frequency of validation step:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z-EMVvKheu3D" + }, + "outputs": [], + "source": [ + "# run val loop every 10 training epochs\n", + "trainer = pl.Trainer(check_val_every_n_epoch=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UOzZr9S2UcSO" + }, + "source": [ + "## val_check_interval\n", + "\n", + "In some cases where your epoch is very long, you might want to check validation within an epoch.\n", + "\n", + "You can also run validation step within your training epochs, by setting `val_check_interval` flag.\n", + "\n", + "Set `val_check_interval` to a float between [0.0 to 1.0] to check your validation set within a training epoch. For example, setting it to 0.25 will check your validation set 4 times during a training epoch.\n", + "\n", + "Default is set to 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9kbUbvrUVLrT" + }, + "outputs": [], + "source": [ + "# check validation set 4 times during a training epoch\n", + "trainer = pl.Trainer(val_check_interval=0.25)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Onm1gBsKVaw4" + }, + "source": [ + "When you have iterable data sets, or when streaming data for production use cases, it is useful to check the validation set every number of steps. \n", + "Set val_check_interval to an int:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "psn6DVb5Vi85" + }, + "outputs": [], + "source": [ + "# check validation set every 1000 training batches\n", + "# use this when using iterableDataset and your dataset has no length\n", + "# (ie: production cases with streaming data)\n", + "trainer = pl.Trainer(val_check_interval=1000)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QkoYonrWkb7-" + }, + "source": [ + "## num_sanity_val_steps \n", + "\n", + "You may have run into an issue, where you have a bug in your validation loop, but won't catch it until your training loop ends.\n", + "\n", + "and if your training loop takes hours or days, you will waste valuable compute.\n", + "\n", + "Instead, lightning automatically runs through 2 steps of validation in the beginning to catch these kinds of bugs up front.\n", + "\n", + "\n", + "The `num_sanity_val_steps` flag can help you run n batches of validation before starting the training routine.\n", + "\n", + "You can set it to 0 to turn it off" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zOcT-ugSkiKW" + }, + "outputs": [], + "source": [ + "# turn it off\n", + "trainer = pl.Trainer(num_sanity_val_steps=0)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zS0ob1ZmTw56" + }, + "source": [ + "Set it to -1 to check all validation data before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rzqvjA4UT263" + }, + "outputs": [], + "source": [ + "# check all validation data\n", + "trainer = pl.Trainer(num_sanity_val_steps=-1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uMB41wq4T3Z2" + }, + "source": [ + "Or use any arbitrary number of validation steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lGP78aQzT7VS" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(num_sanity_val_steps=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H-xaYRtd1rb-" + }, + "source": [ + "## Limit train, validation, and test batches\n", + "\n", + "You can set limits on how much of training, validation and test dataset you want your model to check. This is useful if you have really large validation or tests sets, for debugging or testing something that happens at the end of an epoch.\n", + "\n", + "Set the flag to int to specify the number of batches to run\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XiK5cFKL1rcA" + }, + "outputs": [], + "source": [ + "# run for only 10 batches\n", + "trainer = pl.Trainer(limit_test_batches=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y4LK0g65RrBm" + }, + "source": [ + "For example, some metrics need to be computed on the entire validation results, such as AUC ROC. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8MmeRs2DR3dD" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(limit_val_batches=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xmigcNa1A2Vy" + }, + "source": [ + "You can use a float to limit the batches be percentage of the set on every epoch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W7uGJt8nA4tv" + }, + "outputs": [], + "source": [ + "# run through only 25% of the test set each epoch\n", + "trainer = pl.Trainer(limit_test_batches=0.25)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YRI8THtUN7_e" + }, + "source": [ + "# Training on GPUs\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R8FFkX_FwlfE" + }, + "source": [ + "To run on 1 GPU set the flag to 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nnzkf3KaOE27" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cxBg47s5PB1P" + }, + "source": [ + "to run on 2 or 4 GPUs, set the flag to 2 or 4." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cSEM4ihLrohT" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=2)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZE6ZgwtNudro" + }, + "source": [ + "You can also select which GPU devices to run on, using a list of indices like [1, 4] \n", + "\n", + "or a string containing a comma separated list of GPU ids like '1,2'\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gQkJtq0urrjq" + }, + "outputs": [], + "source": [ + "# list: train on GPUs 1, 4 (by bus ordering)\n", + "# trainer = Trainer(gpus='1, 4') # equivalent\n", + "trainer = pl.Trainer(gpus=[1, 4])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XghDPad4us74" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=list(range(4)))\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6FVkKHpSPMTW" + }, + "source": [ + "You can use all the GPUs you have available by setting `gpus=-1`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r6cKQijYrtPe" + }, + "outputs": [], + "source": [ + "# trainer = Trainer(gpus='-1') - equivalent\n", + "trainer = pl.Trainer(gpus=-1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2C-fNLm3UGCV" + }, + "source": [ + "Lightning uses the PCI bus_id as the index for ordering GPUs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_V75s7EhOFhE" + }, + "source": [ + "### `auto_select_gpus`\n", + "\n", + "You can save on GPUs by running in “exclusive mode”, meaning only one process at a time can access them. If your not sure which GPUs you should use when running exclusive mode, Lightning can automatically find unoccupied GPUs for you. \n", + "\n", + "Simply specify the number of gpus as an integer `gpus=k`, and set the trainer flag `auto_select_gpus=True`. Lightning will automatically help you find k gpus that are not occupied by other processes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Sd3XFsAOIwd" + }, + "outputs": [], + "source": [ + "# enable auto selection (will find two available gpus on system)\n", + "trainer = pl.Trainer(gpus=2, auto_select_gpus=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a5JGSBMQhJNp" + }, + "source": [ + "## analyzing GPU usage\n", + "\n", + "### log_gpu_memory\n", + "\n", + "This is useful to analyze the memory usage of your GPUs.\n", + "\n", + "To get the GPU memory usage for every GPU on the master node, set the flag to log_gpu_memory=all.\n", + "\n", + "Under the hood, lightning uses the nvidia-smi command which may slow your training down.\n", + "\n", + "Your logs can become overwhelmed if you log the usage from many GPUs at once. In this case, you can also set the flag to min_max which will log only the min and max usage across all the GPUs of the master node.\n", + "\n", + "Note that lightning is not logging the usage across all nodes for performance reasons." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "idus3ZGahOki" + }, + "outputs": [], + "source": [ + "# log all the GPUs (on master node only)\n", + "trainer = Trainer(log_gpu_memory='all')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-mevgiy_hkip" + }, + "source": [ + "To avoid the performance decrease you can also set `log_gpu_memory=min_max` to only log the min and max memory on the master node.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SlvLJnWyhs7J" + }, + "outputs": [], + "source": [ + "# log only the min and max memory on the master node\n", + "trainer = Trainer(log_gpu_memory='min_max')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K82FLLIJVQG3" + }, + "source": [ + "\n", + "But what if you want to train on multiple machines and not just one?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YViQ6PXesAue" + }, + "source": [ + "# Training on multiple GPUs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WacbBQUivxQq" + }, + "source": [ + "Lightning makes your models hardware agnostic, and you can run on GPUs with a flip of a flag. Lightning also supports training on multiple GPUs across many machines.\n", + "\n", + "You can do this by setting the num_nodes flag.\n", + "\n", + "The world size, or the total number of GPUs you are using, will be gpus*num_nodes.\n", + "\n", + "If i set gpus=8 and num_nodes=32 then I will be training on 256 GPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5iKckmDvr8zZ" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=8, num_nodes=32)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GgcSbDjjlSTh" + }, + "source": [ + "## Accelerators\n", + "\n", + "Under the hood, Lightning uses distributed data parallel (or DDP) by default to distribute training across GPUs.\n", + "\n", + "This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment variables.\n", + "\n", + "Under the hood it's as if you had called your script like this:\n", + "\n", + "1. Each GPU across each node gets its own process.\n", + "2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset.\n", + "3. Each process inits the model. (Make sure to set the random seed so that each model initializes with the same weights.)\n", + "4. Each process performs a full forward and backward pass in parallel.\n", + "5. The gradients are synced and averaged across all processes.\n", + "6. Each process updates its optimizer.\n", + "If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n_Brr7F5wdtj" + }, + "outputs": [], + "source": [ + "# ddp = DistributedDataParallel\n", + "# trainer = pl.Trainer(gpus=2, num_nodes=2) equivalent\n", + "trainer = pl.Trainer(gpus=2, num_nodes=2, accelerator='ddp')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "edxHyttC5J3e" + }, + "source": [ + "DDP is the fastest and recommended way to distribute your training, but you can pass in other backends to `accelerator` trainer flag, when DDP is not supported.\n", + "\n", + "DDP isn't available in\n", + "* Jupyter Notebook, Google COLAB, Kaggle, etc.\n", + "* If You have a nested script without a root package\n", + "* or if Your script needs to invoke .fit or .test multiple times" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZDh96mavxHxf" + }, + "source": [ + "### DDP_SPAWN\n", + "\n", + "In these cases, you can use `ddp_spawn` instead. `ddp_spawn` is exactly like DDP except that it uses `.spawn()` to start the training processes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JM5TKtgLxo37" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=2, num_nodes=2, accelerator='ddp_spawn')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sebhVE3qrhKK" + }, + "source": [ + "We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):\n", + "\n", + "* Since .spawn() trains the model in subprocesses, the model on the main process does not get updated.\n", + "\n", + "* Dataloader(num_workers=N), where N is large, bottlenecks training with DDP… ie: it will be VERY slow or won’t work at all. This is a PyTorch limitation.\n", + "\n", + "* Forces everything to be picklable.\n", + "\n", + "DDP is MUCH faster than DDP_spawn. To be able to use DDP we recommend you: \n", + "\n", + "1. Install a top-level module for your project using setup.py\n", + "\n", + "```\n", + "# setup.py\n", + "#!/usr/bin/env python\n", + "\n", + "from setuptools import setup, find_packages\n", + "\n", + "setup(name='src',\n", + " version='0.0.1',\n", + " description='Describe Your Cool Project',\n", + " author='',\n", + " author_email='',\n", + " url='https://github.com/YourSeed', # REPLACE WITH YOUR OWN GITHUB PROJECT LINK\n", + " install_requires=[\n", + " 'pytorch-lightning'\n", + " ],\n", + " packages=find_packages()\n", + " )\n", + "\n", + "```\n", + "\n", + "2. Setup your project like so:\n", + "\n", + "```\n", + "/project\n", + " /src\n", + " some_file.py\n", + " /or_a_folder\n", + " setup.py\n", + "```\n", + "3. Install as a root-level package\n", + "```\n", + "cd /project\n", + "pip install -e .\n", + "```\n", + "4. You can then call your scripts anywhere\n", + "```\n", + "cd /project/src\n", + "\n", + "python some_file.py --accelerator 'ddp' --gpus 8\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cmB3I_oyw7a8" + }, + "source": [ + "### DP\n", + "\n", + "If you're using windows, DDP is not supported. You can use `dp` for DataParallel instead: DataParallel uses multithreading, instead of multiprocessing. It splits a batch across k GPUs. That is, if you have a batch of 32 and use DP with 2 gpus, each GPU will process 16 samples, after which the root node will aggregate the results.\n", + "\n", + "DP use is discouraged by PyTorch and Lightning. Use DDP which is more stable and at least 3x faster.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OO-J0ISvlVCg" + }, + "outputs": [], + "source": [ + "# dp = DataParallel\n", + "trainer = pl.Trainer(gpus=2, accelerator='dp')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y7E2eHZKwUn9" + }, + "source": [ + "### DDP2\n", + "\n", + "In certain cases, it’s advantageous to use ***all*** batches on the same machine, instead of a subset. For instance, in self-supervised learning, a common performance boost comes from increasing the number of negative samples.\n", + "\n", + "In this case, we can use DDP2 which behaves like DP in a machine and DDP across nodes. DDP2 does the following:\n", + "\n", + "* Copies a subset of the data to each node.\n", + "* Inits a model on each node.\n", + "* Runs a forward and backward pass using DP.\n", + "* Syncs gradients across nodes.\n", + "* Applies the optimizer updates.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Y4xweqL3xHER" + }, + "outputs": [], + "source": [ + "# ddp2 = DistributedDataParallel + dp\n", + "trainer = pl.Trainer(gpus=2, num_nodes=2, accelerator='ddp2')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lhKNCnveeeq5" + }, + "source": [ + "- The second mode is ddp_spawn. This works like ddp, but instead of calling your script multiple times, lightning will use multiprocessing spawn to start a subprocess per GPU. \n", + "\n", + "However, you should be careful of mixing this mode with num_workers > 0 in your dataloaders because it will bottleneck your training. This is a current known limitation of PyTorch which is why we recommend using our ddp implementation instead.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HUf9ANyQkFFO" + }, + "source": [ + "\n", + "### mocking ddp\n", + "\n", + "Testing or debugging DDP can be hard, so we have a distributed backend that simulates ddp on cpus to make it easier. Set `num_processes` to a number greater than 1 when using accelerator=\"ddp_cpu\" to mimic distributed training on a machine without GPUs. Note that while this is useful for debugging, it will not provide any speedup, since single-process Torch already makes efficient use of multiple CPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZSal5Da9kHOf" + }, + "outputs": [], + "source": [ + "# Simulate DDP for debugging on your GPU-less laptop\n", + "trainer = Trainer(accelerator=\"ddp_cpu\", num_processes=2)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Br_btCy5lgES" + }, + "source": [ + "# Training on TPUS\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DXkBNITdv44d" + }, + "source": [ + "Another option for accelerating your training is using TPUs.\n", + "A TPU is a Tensor processing unit, designed specifically for deep learning. Each TPU has 8 cores where each core is optimized for 128x128 matrix multiplies. Google estimates that 8 TPU cores are about as fast as 4 V100 GPUs!\n", + "\n", + "A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores! You can request a full pod from Google cloud or a “slice” which gives you some subset of those 2048 cores.\n", + "\n", + "At this moment, TPUs are available on Google Cloud (GCP), Google Colab and Kaggle Environments.\n", + "\n", + "Lightning supports training on TPUs without any code adjustments to your model. Just like when using GPUs, Lightning automatically inserts the correct samplers - no need to do this yourself!\n", + "\n", + "Under the hood, lightning uses the XLA framework developed jointly by the facebook and google XLA teams. And we want to recognize their efforts in advancing TPU adoption of PyTorch.\n", + "\n", + "## tpu_cores\n", + "To train on TPUs, set the tpu_cores flag.\n", + "\n", + "When using colab or kaggle, the allowed values are 1 or 8 cores. When using google cloud, any value above 8 is allowed.\n", + "\n", + "Your effective batch size is the batch size passed into a dataloader times the total number of tpu cores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "itP9y70gmD9M" + }, + "outputs": [], + "source": [ + "# int: train on a single core\n", + "trainer = pl.Trainer(tpu_cores=1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NJKnzPb3mKEg" + }, + "outputs": [], + "source": [ + "# int: train on all cores few cores\n", + "trainer = pl.Trainer(tpu_cores=8)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8a4exfWUmOHq" + }, + "source": [ + "You can also choose which TPU core to train on, by passing a list [1-8]. This is not an officially supported use case but we are working with the XLA team to improve this user experience.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "S6OrjE_bmT-_" + }, + "outputs": [], + "source": [ + "# list: train on a single selected core\n", + "trainer = pl.Trainer(tpu_cores=[2])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Afqx3sFUmfWD" + }, + "source": [ + "To train on more than 8 cores (ie: a POD), submit this script using the xla_dist script.\n", + "\n", + "\n", + "\n", + "```\n", + "python -m torch_xla.distributed.xla_dist\n", + "--tpu=$TPU_POD_NAME\n", + "--conda-env=torch-xla-nightly\n", + "--env=XLA_USE_BF16=1\n", + "-- python your_trainer_file.py\n", + "```\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ncPvbUVQqKOh" + }, + "source": [ + "# Advanced distributed training\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4MP7bEgnv7qK" + }, + "source": [ + "\n", + "Lightning supports distributed training across multiple GPUs and TPUs out of the box by setting trainer flags, but it also allows you to control the way sampling is done if you need to." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wdHiTfAMepKH" + }, + "source": [ + "## replace_sampler_ddp\n", + "In PyTorch, you must use torch.nn.DistributedSampler for multi-node or GPU training. The sampler makes sure each GPU sees the appropriate part of your data.\n", + "\n", + "```\n", + "# without lightning\n", + "def train_dataloader(self):\n", + " dataset = MNIST(...)\n", + " sampler = None\n", + "\n", + " if self.on_tpu:\n", + " sampler = DistributedSampler(dataset)\n", + "\n", + " return DataLoader(dataset, sampler=sampler)\n", + "```\n", + "Lightning adds the correct samplers when needed, so no need to explicitly add samplers. By default it will add `shuffle=True` for train sampler and `shuffle=False` for val/test sampler.\n", + "\n", + "If you want to customize this behaviour, you can set `replace_sampler_ddp=False` and add your own distributed sampler.\n", + "\n", + "(note: For iterable datasets, we don’t do this automatically.)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZfmcB_e_7HbE" + }, + "outputs": [], + "source": [ + "sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)\n", + "dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)\n", + "\n", + "trainer = pl.Trainer(gpus=2, num_nodes=2, replace_sampler_ddp=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-IOhk1n0lL3_" + }, + "source": [ + "## prepare_data_per_node\n", + "\n", + "When doing multi NODE training, if your nodes share the same file system, then you don't want to download data more than once to avoid possible collisions. \n", + "\n", + "Lightning automatically calls the prepare_data hook on the root GPU of the master node (ie: only a single GPU).\n", + "\n", + "In some cases where your nodes don't share the same file system, you need to download the data on each node. In this case you can set this flag to true and lightning will download the data on the root GPU of each node.\n", + "\n", + "This flag is defaulted to True." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WFBMUR48lM04" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=2, num_nodes=2, prepare_data_per_node=False)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FKBwXqo4q-Vp" + }, + "source": [ + "## sync_batchnorm\n", + "\n", + "Batch norm is computed per GPU/TPU. This flag enables synchronization between batchnorm layers across all GPUs.\n", + "It is recommended if you have small batch sizes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GhaCLTEZrAQi" + }, + "outputs": [], + "source": [ + "trainer = Trainer(gpus=4, sync_batchnorm=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XuFA7VTFMY9-" + }, + "source": [ + "# Debugging flags\n", + "\n", + "Lightning offers a couple of flags to make debugging your models easier:\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AKoS3fdml4Jx" + }, + "source": [ + "## Fast Dev Run\n", + "\n", + "To help you save time debugging, your first run should use the fast_dev_run flag.\n", + "\n", + "This won't generate logs or save checkpoints but will touch every line of your code to make sure that it is working as intended.\n", + "\n", + "Think about this flag like a compiler. You make changes to your code, and run Trainer with this flag to verify that your changes are bug free.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L5vuG7GSmhzK" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(fast_dev_run=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HRP1qQR5nT4p" + }, + "source": [ + "## overfit_batches\n", + "\n", + "Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. If the training dataloaders have shuffle=True, Lightning will automatically disable it.\n", + "\n", + "Useful for quickly debugging or trying to overfit on purpose." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NTM-dqGMnXms" + }, + "outputs": [], + "source": [ + "# use only 1% of the train set (and use the train set for val and test)\n", + "trainer = pl.Trainer(overfit_batches=0.01)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c0LV0gC3nl1X" + }, + "outputs": [], + "source": [ + "# overfit on 10 of the same batches\n", + "trainer = pl.Trainer(overfit_batches=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lt3UHU6WgtS_" + }, + "source": [ + "Or a float to represent percentage of data to run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "K3yUqADhgnkf" + }, + "outputs": [], + "source": [ + "# run through only 25% of the test set each epoch\n", + "trainer = pl.Trainer(limit_test_batches=0.25)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ODN66NeVg_2o" + }, + "source": [ + "In the case of multiple test dataloaders, the limit applies to each dataloader individually.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8aQx5SLeMz1R" + }, + "source": [ + "# accumulate_grad_batches\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g8GczZXFwKC7" + }, + "source": [ + "The batch size controls the accuracy of the estimate of the gradients. Small batch size use less memory, but decrease accuracy. When training large models, such as NLP transformers, it is useful to accumulate gradients before calling backwards(). It allows for bigger batch sizes than what can actually fit on a GPU/TPU in a single step.\n", + "\n", + "Use accumulate_grad_batches to accumulate gradients every k batches or as set up in the dict. Trainer also calls optimizer.step() for the last indivisible step number.\n", + "\n", + "For example, set accumulate_grad_batches to 4 to accumulate every 4 batches. In this case the effective batch size is batch_size*4, so if your batch size is 32, effectively it will be 128." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2jB6-Z_yPhhf" + }, + "outputs": [], + "source": [ + "# accumulate every 4 batches (effective batch size is batch*4)\n", + "trainer = pl.Trainer(accumulate_grad_batches=4)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_Yi-bdTOgINC" + }, + "source": [ + "You can also pass a dictionary to specify different accumulation per epoch. We can set it to `{5: 3, 10: 20}` to have no accumulation for epochs 1 to 4, accumulate 3 batches for epoch 5 to 10, and 20 batches after that." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3xsoZ3YPgBv" + }, + "outputs": [], + "source": [ + "# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that\n", + "trainer = pl.Trainer(accumulate_grad_batches={5: 3, 10: 20})\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "myzH8mV4M1_9" + }, + "source": [ + "# 16 bit precision\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v9EaFAonwOk6" + }, + "source": [ + "Most deep learning frameworks like PyTorch, train with 32-bit floating point arithmetic. \n", + "\n", + "But many models can still achieve full accuracy using half the precision.\n", + "\n", + "In 2017, NVIDIA researchers successfully used a combination of 32 and 16 bit precision (also known as mixed precision) and achieved the same accuracy as 32 bit precision training.\n", + "\n", + "The main two advantages are:\n", + "\n", + "- a reduction in memory requirements which enables larger batch sizes and models.\n", + "- and a speed up in compute. On ampere, turing and volta architectures 16 bit precision models can train at least 3 times faster.\n", + "\n", + "As of PyTorch 1.6, NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, torch.cuda.amp. \n", + "\n", + "This package supersedes the apex package developed by NVIDIA." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjNypZPHnxvJ" + }, + "source": [ + "## precision\n", + "\n", + "Use precision flag to switch between full precision (32) to half precision (16). Can be used on CPU, GPU or TPUs.\n", + "\n", + "When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit.\n", + "\n", + "If used on TPU will use torch.bfloat16 but tensor printing will still show torch.float32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kBZKMVx1nw-D" + }, + "outputs": [], + "source": [ + "# 16-bit precision\n", + "trainer = pl.Trainer(gpus=1, precision=16)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VJGj3Jh7oQXU" + }, + "source": [ + "In earlier version of Lightning, we use NVIDIA Apex for 16-bit precision. Apex was the first library to attempt 16-bit and the automatic mixed precision library (amp), has since been merged into core PyTorch as of 1.6.\n", + "\n", + "If you insist in using Apex, you can set the amp_backend flag to 'apex' and install Apex on your own." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BDV1trAUPc9h" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HK5c_aVfNV4e" + }, + "source": [ + "## amp_level\n", + "Apex includes 4 optimization levels:\n", + "O0 (FP32 training)\n", + "O1 (Conservative Mixed Precision): only some whitelist ops are done in FP16.\n", + "O2 (Fast Mixed Precision): this is the standard mixed precision training. It maintains FP32 master weights and optimizer.step acts directly on the FP32 master weights.\n", + "O3 (FP16 training): full FP16. Passing keep_batchnorm_fp32=True can speed things up as cudnn batchnorm is faster anyway.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FshMFPowNbWt" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex', amp_level='O2')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y8KEr1YvNgkC" + }, + "source": [ + "# `auto_scale_batch_size`\n", + "\n", + " \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7F1pKFIuwSFl" + }, + "source": [ + "Lightning can help you improve your model by using auto_scale_batch_size flag, which tries to find the largest batch size that fits into memory, before you start your training.\n", + "Larger batch size often yields better estimates of gradients, but may also result in longer training time. \n", + "\n", + "Set it to True to initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9_jE-iyyheIv" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(auto_scale_batch_size=True)\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yaHsJvwFhNJt" + }, + "source": [ + "You can set the value to `power`. `power` scaling starts from a batch size of 1 and keeps doubling the batch size until an out-of-memory (OOM) error is encountered.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Qx0FbQrphgw1" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(auto_scale_batch_size='power')\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8bwgVF9zhZ75" + }, + "source": [ + "You can also set it to `binsearch`, that continues to finetune the batch size by performing a binary search.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QObXNs3yNrg9" + }, + "outputs": [], + "source": [ + "# run batch size scaling, result overrides hparams.batch_size\n", + "trainer = pl.Trainer(auto_scale_batch_size='binsearch')\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5OWdhSsZjqW7" + }, + "source": [ + "This feature expects that a batch_size field in the hparams of your model, i.e., model.hparams.batch_size should exist and will be overridden by the results of this algorithm. \n", + "\n", + "Additionally, your train_dataloader() method should depend on this field for this feature to work.\n", + "\n", + "The algorithm in short works by:\n", + "1. Dumping the current state of the model and trainer\n", + "\n", + "2. Iteratively until convergence or maximum number of tries max_trials (default 25) has been reached:\n", + "* Call fit() method of trainer. This evaluates steps_per_trial (default 3) number of training steps. Each training step can trigger an OOM error if the tensors (training batch, weights, gradients etc.) allocated during the steps have a too large memory footprint.\n", + " * If an OOM error is encountered, decrease the batch size\n", + " * Else increase it.\n", + "* How much the batch size is increased/decreased is determined by the chosen strategy.\n", + "\n", + "3. The found batch size is saved to model.hparams.batch_size\n", + "\n", + "4. Restore the initial state of model and trainer\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "q4CvxfZmOWBd" + }, + "source": [ + "# `auto_lr_find`\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j85e8usNwdBV" + }, + "source": [ + "Selecting a good learning rate for your deep learning training is essential for both better performance and faster convergence.\n", + "\n", + "Even optimizers such as Adam that are self-adjusting the learning rate can benefit from more optimal choices.\n", + "\n", + "To reduce the amount of guesswork concerning choosing a good initial learning rate, you can use Lightning auto learning rate finder.\n", + "\n", + "The learning rate finder does a small run where the learning rate is increased after each processed batch and the corresponding loss is logged. The result of this is a lr vs. loss plot that can be used as guidance for choosing an optimal initial lr.\n", + "\n", + "\n", + "warning: For the moment, this feature only works with models having a single optimizer. LR support for DDP is not implemented yet, it is coming soon.\n", + "\n", + "\n", + "***auto_lr_find=***\n", + "\n", + "In the most basic use case, this feature can be enabled during trainer construction with Trainer(auto_lr_find=True).\n", + "When .fit(model) is called, the LR finder will automatically run before any training is done. The lr that is found and used will be written to the console and logged together with all other hyperparameters of the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iuhve9RBOfFh" + }, + "outputs": [], + "source": [ + "# default used by the Trainer (no learning rate finder)\n", + "trainer = pl.Trainer(mnist_model, auto_lr_find=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BL-gjXNCPDXk" + }, + "source": [ + "This flag sets your learning rate which can be accessed via self.lr or self.learning_rate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wEb-vIMmPJQf" + }, + "outputs": [], + "source": [ + "class LitModel(LightningModule):\n", + "\n", + " def __init__(self, learning_rate):\n", + " self.learning_rate = learning_rate\n", + "\n", + " def configure_optimizers(self):\n", + " return Adam(self.parameters(), lr=(self.lr or self.learning_rate))\n", + "\n", + "# finds learning rate automatically\n", + "# sets hparams.lr or hparams.learning_rate to that learning rate\n", + "trainer = pl.Trainer(mnist_model, auto_lr_find=True)\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RweqvpnVPPSh" + }, + "source": [ + "To use an arbitrary value set it as auto_lr_find\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4LKI39IfPLJv" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(mnist_model, auto_lr_find='my_value')\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9VAhPRKbPX-m" + }, + "source": [ + "Under the hood, when you call tune it runs the learning rate finder.\n", + "\n", + "If you want to inspect the results of the learning rate finder before doing any actual training or just play around with the parameters of the algorithm, this can be done by invoking the lr_find method of the trainer. A typical example of this would look like\n", + "\n", + "\n", + "```\n", + "trainer = pl.Trainer(auto_lr_find=True)\n", + "\n", + "# Run learning rate finder\n", + "lr_finder = trainer.lr_find(model)\n", + "\n", + "# Results can be found in\n", + "lr_finder.results\n", + "\n", + "# Plot with\n", + "fig = lr_finder.plot(suggest=True)\n", + "fig.show()\n", + "\n", + "# Pick point based on plot, or get suggestion\n", + "new_lr = lr_finder.suggestion()\n", + "\n", + "# update hparams of the model\n", + "model.hparams.lr = new_lr\n", + "\n", + "# Fit model\n", + "trainer.fit(model)\n", + "```\n", + "\n", + "The figure produced by lr_finder.plot() should look something like the figure below. It is recommended to not pick the learning rate that achieves the lowest loss, but instead something in the middle of the sharpest downward slope (red point). This is the point returned py lr_finder.suggestion().\n", + "\n", + "![image.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAgAElEQVR4Ae3dB3hUZb7H8bheey94r+URdMWOuq66a1mVtazX3nVX17p617bqursGUCMo9gIWUFBRBBGwoEAgtEDoLbQgIQklCS0kkJAQkpDyv8//xRlnkkkyM++UM3O+53nykDlz3jPnfM5/8v44NUUYEEAAAQQQQAABBFwlkOKqtWVlEUAAAQQQQAABBIQASBEggAACCCCAAAIuEyAAumyDs7oIIIAAAggggAABkBpAAAEEEEAAAQRcJkAAdNkGZ3URQAABBBBAAAECIDWAAAIIIIAAAgi4TIAA6LINzuoigAACCCCAAAIEQGoAAQQQQAABBBBwmQAB0GUbnNVFAAEEEEAAAQQIgNQAAggggAACCCDgMgECoMs2OKuLAAIIIIAAAggQAKkBBBBAAAEEEEDAZQIEQJdtcFYXAQQQQAABBBAgAFIDCCCAAAIIIICAywQIgC7b4KwuAggggAACCCBAAKQGEEAAAQQQQAABlwkQAF22wVldBBBAAAEEEECAAEgNIIAAAggggAACLhMgALpsg7O6CCCAAAIIIIAAAZAaQAABBBBAAAEEXCZAAHTZBmd1EUAAAQQQQAABAiA1gAACCCCAAAIIuEyAAOiyDc7qIoAAAggggAACBEBqAAEEEEAAAQQQcJkAAdBlG5zVRQABBBBAAAEECIDUAAIIIIAAAggg4DIBAqDLNjiriwACCCCAAAIIEACpAQQQQAABBBBAwGUCBECXbXBWFwEEEEAAAQQQIABSAwgggAACCCCAgMsECIAu2+CsLgIIIIAAAgggQACkBhBAAAEEEEAAAZcJEABdtsFZXQQQQAABBBBAgABIDSCAAAIIIIAAAi4TIAC6bIOzuggggAACCCCAAAGQGkAAAQQQQAABBFwmQAB02QZndRFAAAEEEEAAAQIgNYAAAggggAACCLhMgADosg3O6iKAAAIIIIAAAgRAagABBBBAAAEEEHCZAAHQZRuc1UUAAQQQQAABBAiA1AACCCCAAAIIIOAyAQKgyzY4q4sAAggggAACCBAAqQEEEEAAAQQQQMBlAgRAl21wVhcBBBBAAAEEECAAUgMIIIAAAggggIDLBAiALtvgrC4CCCCAAAIIIEAApAYQQAABBBBAAAGXCRAAXbbBWV0EEEAAAQQQQIAASA0ggAACCCCAAAIuEyAAumyDs7oIIIAAAggggAABkBpAAAEEEEAAAQRcJkAAdNkGZ3URQAABBBBAAAECIDWAAAIIIIAAAgi4TIAA6LINzuoigAACCCCAAAIEQGoAAQQQQAABBBBwmQAB0GKDNzY2SnFxsVRUVMi2bdv4wYAaoAaoAWqAGkiAGtB+W/tv7cfdOhAALba8Fk9KSgo/GFAD1AA1QA1QAwlYA9qPu3UgAFpsef0fhAZALSD2ALIHlBqgBqgBaoAaSIwa8OzA0X7crQMB0GLL6xddA6D+y4AAAggggAACiSFA/y1CALSoVQrIAo+mCCCAAAIIxEmA/psAaFV6FJAVH40RQAABBBCIiwD9NwHQqvAoICs+GiOAAAIIIBAXAfpvAqBV4VFAVnw0RgABBBBAIC4C9N8EQKvCo4Cs+GiMAAIIIIBAXATovwmAVoVHAVnx0RgBBBBAAIG4CNB/EwCtCo8CsuKjMQIIIIAAAnERoP8mAFoVHgVkxUdjBBBAAAEE4iJA/00AtCo8CsiKj8YIIIAAAgjERYD+mwBoVXgUkBUfjRFAAAEEEIiLAP03AdCq8CggKz4aI4AAAgggEBcB+m8CoFXhUUBWfDRGAAEEEEAgLgL03wRAq8KjgKz4aIwAAggggECrAiPmF8k/hmXLuGUbWp0m3DfovwmA4daOaUcBWfHRGAEEEEAAgVYFUr9dIh2fHSN9J+W1Ok24b9B/EwDDrR3TjgKy4qMxAggggAACrQo8PHi+CYCDZ69tdZpw36D/JgCGWzumHQVkxUdjBBBAAAEEWhW4tf9MEwDHLuUQcKtIFm+kWLR1fVMCoOtLAAAEEEAAgSgJ/PGtTBMAZxWURfwT6L/ZA2hVVBSQFR+NEUAAAQQQaFXgrJ4ZJgCu3FTZ6jThvkH/TQAMt3ZMOwrIio/GCCCAAAIIBBRoaGySTqljTAAsraoNOI3NSPpvAqBN/QgFZMVHYwQQQAABBAIKlFXVmvCnVwHXNzQGnMZmJP03AdCmfgiAVno0RgABBBBAILBA3qZKEwDP7JkReALLsQRAAqBVCVFAVnw0RgABBBBAIKDAnFVlJgB2fTMz4Pu2I+m/CYBWNUQBWfHRGAEEEEAAgYAC6Us3mAB4S7+ZAd+3HUn/TQC0qiEKyIqPxggggAACCAQUGDJnrQmAf/tifsD3bUfSfxMArWqIArLiozECCCCAAAIBBd6blGcC4LPfLAn4vu1I+m8CoFUNUUBWfDRGAAEEEEAgoMCLP+aYAPjauBUB37cdSf9NALSqIQrIio/GCCCAAAIIBBT4x7BsEwAHZq0K+L7tSPpvAqBVDVFAVnw0RgABBBBAIKDA3Z/MMQHwmwXFAd+3HUn/TQC0qiEKyIqPxggggAACCAQUuLpvlgmAU3JLAr5vO5L+mwBoVUMUkBUfjRFAAAEEEAgocP4rk0wAXFxUHvB925H03wRAqxqigKz4aIwAAggggEALgaamJjmxR7oJgEVbqlu8H4kR9N8EQKs6ooCs+GiMAAIIIIBAC4HttfUm/OlzgPX3aAz03y4PgOvWrZO77rpLDj30UNl7773l9NNPl/nzg7/pJAUUja8l80QAAQQQcLOA7vXT8Kd7AXVvYDQG+m8XB8CtW7dKx44d5b777pO5c+fK6tWrJSMjQwoKCoKuNQooaComRAABBBBAICgBPe9PA+DvX5kU1PThTET/7eIA+Oyzz8pFF10UTt1421BAXgp+QQABBBBAICICeuWvBkC9EjhaA/23iwPgKaecIk899ZTceuut0qFDBznrrLNkwIABbdZabW2taNF4foqLiyUlJcW8brMhbyKAAAIIIIBAUAJ67z8NgHovwGgNBEAXB8C99tpL9Kdbt26SnZ0tH3/8sTkP8PPPP2+13tLS0kzg09Dn+6OFxIAAAggggAAC9gL69A8NgPo0kGgNBEAXB8A99thDzj//fL/aeuKJJ+T3v/+93zjfF+wB9NXgdwQQQAABBCIvoM//1QCY9kNO5Gf+8xwJgC4OgMcee6w8+OCDfsXVr18/Oeqoo/zGtfWCAmpLh/cQQAABBBAIXeDZb5aYAPjepLzQGwfZgv7bxQHwz3/+c4uLQPScwOZ7BduqJQqoLR3eQwABBBBAIHSBv30x3wTAIXPWht44yBb03y4OgPPmzZP/+q//kt69e0t+fr4MHTpU9t13XxkyZEiQ5SPm4g8uAgmaiwkRQAABBBBoV+CWfjNNAExfuqHdacOdgADo4gCoRTN69Ghz82e9GOTkk09u9yrg5oVGATUX4TUCCCCAAAJ2Al3fzDQBcM6qMrsZtdGa/tvlAbCN2gjqLQooKCYmQgABBBBAIGiBM3tmmACYt6ky6DahTkj/TQAMtWb8pqeA/Dh4gQACCCCAgJVAfUOjCX96FXBZVa3VvNpqTP9NAGyrPtp9jwJql4gJEEAAAQQQCFpgc2WtCYCdUsdIQ2N0ngOsC0P/TQAMuigDTUgBBVJhHAIIIIAAAuEJrNxUaQLgWT0zwptBkK3ovwmAQZZK4MkooMAujEUAAQQQQCAcgVkFZSYAdn0rM5zmQbeh/yYABl0sgSakgAKpMA4BBBBAAIHwBMYu3WAC4K39Z4Y3gyBb0X8TAIMslcCTUUCBXRiLAAIIIIBAOAKDZ681AfChL+aH0zzoNvTfBMCgiyXQhBRQIBXGIYAAAgggEJ5A30l5JgCmfrskvBkE2Yr+mwAYZKkEnowCCuzCWAQQQAABBMIRSPshxwTA18etCKd50G3ovwmAQRdLoAkpoEAqjEMAAQQQQCA8gSe+yjYBcGDWqvBmEGQr+m8CYJClEngyCiiwC2MRQAABBBAIR+CugXNMAPx2YXE4zYNuQ/9NAAy6WAJNSAEFUmEcAggggAAC4Qn8b58sEwAzc0vCm0GQrei/CYBBlkrgySigwC6MRQABBBBAIByB3/WeZALgkuLycJoH3Yb+mwAYdLEEmpACCqTCOAQQQAABBEIXaGpqks490k0ALN5aHfoMQmhB/00ADKFcWk5KAbU0YQwCCCCAAALhCFTV1pvw1/HZMVJdVx/OLIJuQ/9NAAy6WAJNSAEFUmEcAggggAACoQsUllWbAHjSc+mhNw6xBf03ATDEkvGfnALy9+AVAggggAAC4QosKio3AfCCVyeHO4ug29F/EwCDLpZAE1JAgVQYhwACCCCAQOgCk1dsMgHwmveyQm8cYgv6bwJgiCXjPzkF5O/BKwQQQAABBMIVGLmg2ATAv346N9xZBN2O/psAGHSxBJqQAgqkwjgEEEAAAQRCF/h4WoEJgE8Oyw69cYgt6L8JgCGWjP/kFJC/B68QQAABBBAIV+DV9BUmAPb8cXm4swi6Hf03ATDoYgk0IQUUSIVxCCCAAAIIhC7w75GLTQB8f3Je6I1DbEH/TQAMsWT8J6eA/D14hQACCCCAQLgCD34+3wTAoXMKw51F0O3ovwmAQRdLoAkpoEAqjEMAAQQQQCB0gZs+nGEC4LhlG0JvHGIL+m8CYIgl4z85BeTvwSsEEEAAAQTCFbj0zUwTAOeu3hLuLIJuR/9NAAy6WAJNSAEFUmEcAggggAACoQt0SRtvAmB+SWXojUNsQf9NAAyxZPwnp4D8PXiFAAIIIIBAOAI7GxpN+NPnAG/ZXhfOLEJqQ/9NAAypYJpPTAE1F+E1AggggAACoQuUVNaYANgpdYw0NDaFPoMQW9B/EwBDLBn/ySkgfw9eIYAAAgggEI7Aio3bTAD8Ta8J4TQPuQ39NwEw5KLxbUAB+WrwOwIIIIAAAuEJzCwoNQHwj29lhjeDEFvRfxMAQywZ/8kpIH8PXiGAAAIIIBCOwOgl600AvK3/rHCah9yG/psAGHLR+DaggHw1+B0BBBBAAIHwBAbPWmMC4MOD54c3gxBb0X8TAEMsGf/JKSB/D14hgAACCCAQjsC7E1eaAJj67dJwmofchv6bABhy0fg2oIB8NfgdAQQQQACB8AReGLXMBMA3xq8IbwYhtqL/JgCGWDL+k1NA/h68QgABBBBAIByBx4YuNAHwk+mrw2kechv6bwJgyEXj24AC8tXgdwQQQAABBMITuP2jWSYAjlq0LrwZhNiK/psAGGLJ+E9OAfl78AoBBBBAAIFwBC54dbIJgAvWRv85wLp89N8EwHDq1NuGAvJS8AsCCCCAAAJhCdQ3NMrx3caaALhpW01Y8wi1Ef03ATDUmvGbngLy4+AFAggggAACIQsUbak24a9z93RpjMFj4HQB6b8JgCEXqm8DCshXg98RQAABBBAIXWBWQZkJgJe+GZungOgS0n8TAEOvVJ8WFJAPBr8igAACCCAQhsCI+UUmAN79yZwwWofXhP6bABhe5fzcigKy4qMxAggggAAC8s6E2N4EWsnpvwmAVl89CsiKj8YIIIAAAgjIP4cvNnsAP5iSHzMN+m8CoFWxUUBWfDRGAAEEEEBAYn0PQCWn/yYAWn31KCArPhojgAACCCAgv9wDcGvMNOi/CYBWxUYBWfHRGAEEEEDA5QK+9wAsidE9AJWc/psAaPXVo4Cs+GiMAAIIIOByAe89AHvE7h6ASk7/TQC0+upRQFZ8NEYAAQQQcLnAzIJScwFI1xjeA1DJ6b8JgFZfPQrIio/GCCCAAAIuFxgeh3sAKjn9NwHQ6qtHAVnx0RgBBBBAwOUCb8fhHoBKTv9NALT66lFAVnw0RgABBBBwucDTwxfF/B6ASk7/TQC0+upRQFZ8NEYAAQQQcLnAbR/NMgFw1KJ1MZWg/yYAWhUcBWTFR2MEEEAAAZcLeO4BuLAwdvcAVHL6bwKg1VePArLiozECCCCAgIsFdjY0ynGpY8wewJLKmphK0H8TAK0KjgKy4qMxAggggICLBTz3ADyxR7o0NTXFVIL+mwBoVXAUkBUfjRFAAAEEXCzgvQfgW5kxV6D/JgBaFR0FZMVHYwQQQAABFwt47gH410/nxlyB/psAaFV0FJAVH40RQAABBFws4LkHYLfvlsZcgf6bAGhVdBSQFR+NEUAAAQRcLOC5B+CHmfkxV6D/dnkATEtLk5SUFL+fk046KehCpICCpmJCBBBAAAEE/AQ89wD8YfF6v/GxeEH/TQCU0047TTZu3Oj9KS0tDbr2KKCgqZgQAQQQQAABP4F43QNQF4L+mwAoZ555pl9BhvKCAgpFi2kRQAABBBDYJRDPewDqEtB/EwBl3333lSOPPFKOO+44+ctf/iKFhYVBfz8poKCpmBABBBBAAAGvQDzvAagLQf/t8gCYnp4uI0aMkCVLlsj48ePl/PPPl2OPPVYqKyu9Rer7S21trSkaLRz9KS4uNucP6u8MCCCAAAIIIBCcwMz8UvMEkD/G4R6AuoQEQJcHwOZlWl5eLgceeKB88sknzd8yrwNdNKIXkRAAA3IxEgEEEEAAgYACw+cVmQB4TxzuAagLRAAkALYozHPOOUdSU1NbjNcR7AEMyMJIBBBAAAEEQhJ4OyPXBMDucbgHoC4oAZAA6FewVVVVcsghh0jfvn39xrf2ggJqTYbxCCCAAAIItC7w9NeLTADsl1nQ+kRRfIf+2+UB8JlnnpGpU6fKmjVrZObMmXL55ZfL4YcfLps3bw6q7CigoJiYCAEEEEAAAT+B2/rPMgHwxzjcA1AXhP7b5QHwjjvuMFcA77nnnnL00UeLvi4oCP5/IxSQ3/eZFwgggAACCAQlcP4rk0wAzC7cGtT0kZ6I/tvlAdC2oCggW0HaI4AAAgi4TaCuvlGOSx1jAuDmytq4rD79NwHQqvAoICs+GiOAAAIIuFCgsKzahL8Te6RLU1NTXATovwmAVoVHAVnx0RgBBBBAwIUC8b4HoJLTfxMArb56FJAVH40RQAABBFwoEO97ACo5/TcB0OqrRwFZ8dEYAQQQQMCFAvG+B6CS038TAK2+ehSQFR+NEUAAAQRcKOC5B2D/qcHfdSPSTPTfBECrmqKArPhojAACCCDgQgHPPQBHL1kft7Wn/yYAWhUfBWTFR2MEEEAAAZcJVNfVy0nPpZurgH/asC1ua0//TQC0Kj4KyIqPxggggAACLhMYs2SDCX9/eH1K3G4Bo+T03wRAq68eBWTFR2MEEEAAAZcJPDpkoQmAr6T/FNc1p/8mAFoVIAVkxUdjBBBAAAEXCeyoa5CTnxtnAuCS4vK4rjn9NwHQqgApICs+GiOAAAIIuEhg7NJdh38vfG1yXA//Kjn9NwHQ6qtHAVnx0RgBBBBAwEUCjw79+fDv2Pge/lVy+m8CoNVXjwKy4qMxAggggIBLBPTw7ynP7zr8u7govod/lZz+mwBo9dWjgKz4aIwAAggg4BKBccucc/hXyem/CYBWXz0KyIqPxggggAACLhF4/Ktsc/FHbwcc/lVy+m8CoNVXjwKy4qMxAggggIALBGp2/nL4d5EDDv8qOf03AdDqq0cBWfHRGAEEEEDABQLjlm00e/8ueDX+V/96uOm/CYCeWgjrXwooLDYaIYAAAgi4SOCJnw//vjR6uWPWmv6bAGhVjBSQFR+NEUAAAQSSXEAP/57689W/Cwu3OmZt6b8JgFbFSAFZ8dEYAQQQQCDJBcbn7Dr8e/4rk+J+82dfavpvAqBvPYT8OwUUMhkNEEAAAQRcJPCPYbuu/u3loMO/yk//TQC0+hpSQFZ8NEYAAQQQSGKBTdtqvM/+XbDWOYd/lZz+mwBo9dWjgKz4aIwAAgggkMQCT/689+/GD2c46vCvktN/EwCtvnoUkBUfjRFAAAEEklRg3pot5tYvnVLHyNLiCsetJf03AdCqKCkgKz4aI4AAAggkoUBDY5P8b58sEwBTv13iyDWk/yYAWhUmBWTFR2MEEEAAgSQU+HL2WhP+uqSNl7KqWkeuIf03AdCqMCkgKz4aI4AAAggkmcDW7XVyZs8MEwAHzVjt2LWj/yYAWhUnBWTFR2MEEEAAgSQTeO77ZSb8XfnONKlvaHTs2tF/EwCtipMCsuKjMQIIIIBAEgksX79NjksdYwLgrIIyR68Z/TcB0KpAKSArPhojgAACCCSJwPbaernpwxkm/D06dKHj14r+mwBoVaQUkBUfjRFAAAEEkkBgUVG5XPLGFBP+Tn5unKwv3+H4taL/JgBaFSkFZMVHYwQQQACBBBbQ2718MCVfft1trAl/+rxfvf9fIgz03wRAqzqlgKz4aIwAAgggkKAC68p3yG0fzTLBr+OzY0QP+1ZU70yYtaH/JgBaFSsFZMVHYwQQQACBBBGo2dkgMwtK5e0JK+X2j2ZJ5+7pJvyd+vw4Gbmg2HGPemuPlf6bANhejbT5PgXUJg9vIoAAAggkuIDeyPmugXO8gU/39nl+9Bm/a0q3J+Qa0n8TAK0KlwKy4qMxAggggIDDBXSPnyfwndd7ojzxVbYMnVMoBZurEm6vny81/TcB0LceQv6dAgqZjAYIIIAAAgkioBd56IUdGgC/nleY0IGvOTn9NwGweU2E9JoCComLiRFAAAEEEkggM7fEhL8zXswQPQcwmQb6bwKgVT1TQFZ8NEYAAQQQcLDAI0MWmACY9kOOg5cyvEWj/yYAhlc5P7eigKz4aIwAAggg4FABvfjjhO677u+nj3hLtoH+mwBoVdMUkBUfjRFAAAEEHCowMGuV2ft33fvTHbqEdotF/00AtKogCsiKj8YIIIAAAg4UaGpqksvfnmoC4Jez1zpwCe0Xif6bAGhVRRSQFR+NEUAAAQQcKLCwcKsJfyc9ly7bahLn6R6hUNJ/EwBDqZcW01JALUgYgQACCCCQ4AL/GbnEBMCnhy9K8DVpffHpvwmArVdHEO9QQEEgMQkCCCCAQMIIVNXWyynPjzMBcO7qLQmz3KEuKP03ATDUmvGbngLy4+AFAggggECCC+gNn/XGz13fzEyqGz833yz03wTA5jUR0msKKCQuJkYAAQQQcLjATR/OMAGw/9QChy+p3eLRfxMArSqIArLiozECCCCAgIME8jZVmvB3fLexUlJZ46Ali/yi0H8TAK2qigKy4qMxAggggICDBN6duNIEwAc/n++gpYrOotB/EwCtKosCsuKjMQIIIICAgwTu/Hi2CYBD5iTnvf98qem/CYC+9RDy7xRQyGQ0QAABBBBwoEBdfaOc2CPdBMD8kkoHLmFkF4n+mwBoVVEUkBUfjRFAAAEEHCKwYO0WE/5+02tCUl/96+Gm/yYAemohrH8poLDYaIQAAggg4DCBDzPzTQD8v8ELHLZk0Vkc+m8CoFVlUUBWfDRGAAEEEHCIwL2fzTUB8NPpqx2yRNFdDPpvAqBVhVFAVnw0RgABBBBwgEBDY5Oc9sJ4EwCXratwwBJFfxHovwmAVlVGAVnx0RgBBBBAwAECS4srTPg7PW28aBh0w0D/TQC0qnMKyIqPxggggAACDhAYmLXKBMD7B81zwNLEZhHovwmAVpVGAVnx0RgBBBBAwAECD30x3wTAZH/8my81/XeCBsCioiIpLi72bsu5c+fKk08+KR9//LF3XKi/vPrqq5KSkmLmE2xbCihYKaZDAAEEEHCiQGNjk5zVM8MEwIWFW524iFFZJvrvBA2AF110kQwePNgUxcaNG+XAAw+U888/Xw4//HDp2bNnyMUyb9486dSpk5xxxhkEwJD1aIAAAgggkKgCK39+/u/Jz40TvRm0WwYCYIIGwIMPPlhyc3NNnfbt21cuuOAC83tGRoYcd9xxIdVvVVWVdO7cWSZOnCiXXHIJATAkPSZGAAEEEEhkgcGz15q9f38ZODuRVyPkZScAJmgA3G+//WTNmjVmg1933XXy2muvmd8LCwtl7733DqkQ7rnnHnnqqadMGwJgSHRMjAACCCCQ4AKPf5VtAmCfiXkJviahLT4BMEED4HnnnSfPPvusZGVlmcC3ePFis+Vnz54tRx99dNBVMGzYMDn99NOlpqbGtGkvANbW1ooWjedHz0PU8wb1NQMCCCCAAAKJJNDU1CTnvjzRBMBZBWWJtOjWy0oATNAAmJmZKXoY+Fe/+pXcf//93kLo1q2b3HTTTd7Xbf2iF5IcccQRsmTJEu9k7QXAtLQ0E/g09Pn+EAC9hPyCAAIIIJAgAmtKt5vwd0L3sVKzsyFBljoyi0kATNAAqJu/oaFBtm71v2JJDwuXlJQEVR3ff/+9CXG77767eH401O22227mtc6/+cAewOYivEYAAQQQSFSB4fOKTAC8pd/MRF2FsJebAJigAXDHjh1SXV3t3fBr166Vd999V8aPH+8d194vlZWVsmzZMr+fc845R+6++24zrr32+j4FFIwS0yCAAAIIOFHgn8MXmwD4xvgVTly8qC4T/XeCBsArrrhC+vfvb4qjvLxc/vu//1uOOeYYcz5gv379wi6a9g4BN58xBdRchNcIIIAAAokicNHrk00AnLpyc6IscsSWk/47QQPgYYcdJjk5OaYQBg4caO7f19jYKCNGjJCTTz457AIhAIZNR0MEEEAAgQQSWF++w4S/41LHSFVtfQIteWQWlQCYoAFwn332Eb3liw633XabvPjii+Z3vbBD34vVQAHFSprPQQABBBCIpMCoRetMALzu/emRnG3CzIv+O0EDYJcuXURvAK2BT58CMmvWLFN0CxYsMIeDY1WBFFCspPkcBBBAAIFICvT4fqkJgL1GL4/kbBNmXvTfCRoAR44cKXvssYe5Dczll1/uLbhXXnlFrrrqKu/raP9CAUVbmPkjgAACCERD4Nr3ppsAOHrJ+mjM3vHzpP9O0AColaXPAM7OzhY9988zzJ07V1asiN3VTBSQR55/EUAAAQQSRaC2vkH03n8dn204e50AACAASURBVB0jRVt+uaNGoix/JJaT/juBA6CnAPRpHPoTj4ECioc6n4kAAgggYCOwqKjchL+zemaIPg3EjQP9d4IGQN3r17NnT3P+nz4NRH8OOugg6dWrl98ewWgXNQUUbWHmjwACCCAQaYEvZq0xAfCeT+dGetYJMz/67wQNgKmpqdKhQwfRe/7po9z058MPPzTjunfvHrMCpIBiRs0HIYAAAghESOCZEbtuAP1WRm6E5ph4s6H/TtAAeOSRR8oPP/zQouJGjRolRx11VIvx0RpBAUVLlvkigAACCERL4Ip3ppo9gBOWb4rWRzh+vvTfCRoA99prL1m5cmWLAsvNzTVPA2nxRpRGUEBRgmW2CCCAAAJREdheWy9682e9AKRkW01UPiMRZkr/naAB8LzzzpMnnniiRY09/vjjou/FaqCAYiXN5yCAAAIIREJgzqoyE/5+13tSJGaXsPOg/07QADh16lTZb7/95JRTTpEHHnjA/Ojv+++/v2RlZcWsICmgmFHzQQgggAACERAYmLXKBMCHvpgfgbkl7izovxM0AGrJrV+/XvSCj5tvvtn89OjRwzwe7qGHHopZRVJAMaPmgxBAAAEEIiDw+FfZJgC+PzkvAnNL3FnQfydwAAxUdosXLza3hAn0XjTGUUDRUGWeCCCAAALRErj4jSkmAE5buTlaH5EQ86X/JgBaFSoFZMVHYwQQQACBGAqUV9eZ8KcXgOjvbh7ovwmAVvVPAVnx0RgBBBBAIIYCutdPw5/uBXT7QP9NALT6DlBAVnw0RgABBBCIocAHU/JNANTzAN0+0H8nWAC86aabpK2frl27cg6g27/VrD8CCCCAQEABvfJX9wAOmLYq4PtuGkkATLAAeN9990kwP7EqYgooVtJ8DgIIIICArYDe+08DoN4L0O0D/XeCBUCnFSwF5LQtwvIggAACCAQS0Kd+aPjTp4Do00DcPtB/EwCtvgMUkBUfjRFAAAEEYiSgz/3VAKjPAWYQof8mAFp9DyggKz4aI4AAAgjESODtjFwTAJ8ZsThGn+jsj6H/JgBaVSgFZMVHYwQQQACBGAnc8+lcEwC/mLUmRp/o7I+h/yYAWlUoBWTFR2MEEEAAgRgINDU1yVk9M0wAXFRUHoNPdP5H0H8TAK2qlAKy4qMxAggggEAMBIq2VJvwd0L3sVJb3xCDT3T+R9B/EwCtqpQCsuKjMQIIIIBADATGLNlgAuC1702PwaclxkfQfxMArSqVArLiozECCCCAQAwEXhn7kwmA3b9bGoNPS4yPoP8mAFpVKgVkxUdjBBBAAIEYCFzdN8sEwOHzi2LwaYnxEfTfBECrSqWArPhojAACCCAQZYH8kkoT/n7dbaxs2V4X5U9LnNnTfxMAraqVArLiozECCCCAQJQF3hy/6/5/DwyaF+VPSqzZ038TAK0qlgKy4qMxAggggEAUBfT2Lxe+NtnsAfxx8fooflLizZr+mwBoVbUUkBUfjRFAAAEEoigwf80WE/5OfX6c7Kjj9i++1PTfBEDfegj5dwooZDIaIIAAAgjESECv+tXn//5zOI9/a05O/00AbF4TIb2mgELiYmIEEEAAgRgJ1NU3ypk/P/1jel5pjD41cT6G/psAaFWtFJAVH40RQAABBKIkMGH5JrP379yXJ0pDY1OUPiVxZ0v/TQC0ql4KyIqPxggggAACURJ4dMhCEwBfGr08Sp+Q2LOl/yYAWlUwBWTFR2MEEEAAgSgIbKvZKSf2SDcBcNm6iih8QuLPkv6bAGhVxRSQFR+NEUAAAQSiIKBP/NCLPy57e6rorWAYWgrQfxMAW1ZFCGMooBCwmBQBBBBAICYCfx4w2wTAD6bkx+TzEvFD6L8JgFZ1SwFZ8dEYAQQQQCDCAhsraqRT6hgTAIu2VEd47skzO/pvAqBVNVNAVnw0RgABBBCIsMDH0wpM+Lu1/8wIzzm5Zkf/TQC0qmgKyIqPxggggAACERa4/v3pJgAOmbM2wnNOrtnRfxMArSqaArLiozECCCCAQAQFanY2yPHdxpoAuK58RwTnnHyzov8mAFpVNQVkxUdjBBBAAIEICixYu+vZv+e8PJGrf9txpf8mALZTIm2/TQG17cO7CCCAAAKxExiYtcrs/Xvw8/mx+9AE/ST6bwKgVelSQFZ8NEYAAQQQiKDA419lmwDI7V/aR6X/JgC2XyVtTEEBtYHDWwgggAACMRX4w+tTTACcnlca089NxA+j/yYAWtUtBWTFR2MEEEAAgQgJlFXVmvCnTwCp2LEzQnNN3tnQfxMAraqbArLiozECCCCAQIQEpqwoMQHwj29lRmiOyT0b+m8CoFWFU0BWfDRGAAEEEIiQwNsTVpoA+PTwRRGaY3LPhv6bAGhV4RSQFR+NEUAAAQQiJHDPp3NNABw8a02E5pjcs6H/JgBaVTgFZMVHYwQQQACBCAg0NTXJmT0zTABcUlwegTkm/yzovwmAVlVOAVnx0RgBBBBAIAICa0q3m/DXuUe61NU3RmCOyT8L+m8CoFWVU0BWfDRGAAEEEIiAwKhF60wAvPHDGRGYmztmQf9NALSqdArIio/GCCCAAAIREHjxxxwTANN+yInA3NwxC/pvAqBVpVNAVnw0RgABBBCIgIDu+dP7/32fvS4Cc3PHLOi/CYBWlU4BWfHRGAEEEEDAUkDP+dNz/zQA6rmADMEJ0H8TAIOrlFamooBagWE0AggggEBMBJYWV5jwp1cB69XADMEJ0H8TAIOrlFamooBagWE0AggggEBMBPS+f7r3T+8DyBC8AP03ATD4agkwJQUUAIVRCCCAAAIxE/jn8MUmAOqTQBiCF6D/dnkA7Nevn3Tp0kUOOOAA8/P73/9e0tPTg64gCihoKiZEAAEEEIiCgD77V/cATl6xKQpzT95Z0n+7PAD++OOPMnbsWMnLy5OVK1dK9+7dZY899pCcnOAupaeAkvePA2uGAAIIOF1gW81OE/40AJZV1Tp9cR21fPTfLg+AgarxkEMOkU8++STQWy3GUUAtSBiBAAIIIBAjgRn5pSYAXvT65Bh9YvJ8DP03AdBbzQ0NDTJs2DDZc889Zfny5d7xbf1CAbWlw3sIIIAAAtEU+GBKvgmAj3+VHc2PScp5038TAGXp0qWy3377ye677y4HHXSQOSTcWrXX1taKFo3np7i4WFJSUszr1towHgEEEEAAgWgI/O2L+SYADsxaFY3ZJ/U8CYAEQKmrq5P8/HxZsGCBpKamyuGHH97qHsC0tDQT+DT0+f5oITEggAACCCAQK4EddQ3ym14TTACcv2ZLrD42aT6HAEgAbFHMl112mTz88MMtxusI9gAGZGEkAggggECMBd7KyDXh74JXJ4s+DYQhNAECIAGwRcV07dpV7r333hbjA42ggAKpMA4BBBBAIJoCq0u3S+fuux7/Nm7Zxmh+VNLOm/7b5QFQD/lOmzZN1qxZY84F1Ne77babTJgwIaiip4CCYmIiBBBAAIEICejj3u79bK7Z+/fXT+fy+LcwXem/XR4AH3jgAenYsaO58rdDhw6ih3+DDX9acxRQmN88miGAAAIIhCUwPmejCX+6B1D3BDKEJ0D/7fIAGF7Z/NKKAvrFgt8QQAABBKIroBd+6Dl/euPnN8aviO6HJfnc6b8JgFYlTgFZ8dEYAQQQQCAEAd8LP6rr6kNoyaTNBei/CYDNayKk1xRQSFxMjAACCCAQpsAavws/NoQ5F5p5BOi/CYCeWgjrXwooLDYaIYAAAgiEKHAfF36EKNb25PTfBMC2K6SddymgdoB4GwEEEEDAWqCkssac93dc6hhZtbnKen7MgIs4tQZSKITwBQiA4dvREgEEEEAgOIHZq8pMALz4jSnBNWCqdgXovwmA7RZJWxNQQG3p8B4CCCCAQCQEhs0tNAHwnk/nRmJ2zIPbuJkaYA+gxVeBAGiBR1MEEEAAgaAEXkn/yQTAF0YtC2p6JmpfgP6bPYDtV0kbU1BAbeDwFgIIIIBARAQeHjzfBMDPZqyOyPyYCecAag2wB9Dim0AAtMCjKQIIIIBAUAJXvjPNBMApuSVBTc9E7QvQfxMA26+SNqaggNrA4S0EEEAAAWuBxsYmObFHugmAei9AhsgI0H8TAK0qiQKy4qMxAggggEA7AuvLd5jw9+tuY6W+obGdqXk7WAH6bwJgsLUScDoKKCALIxFAAAEEIiQwM7/UBMBL38yM0ByZjQrQfxMArb4JFJAVH40RQAABBNoRGDJnrQmA+iQQhsgJ0H8TAK2qiQKy4qMxAggggEA7Ai+PWW4C4Is/5rQzJW+HIkD/TQAMpV5aTEsBtSBhBAIIIIBABAUe/HzXLWC+mLUmgnNlVvTfBECrbwEFZMVHYwQQQACBdgQue3uq2QM4beXmdqbk7VAE6L8JgKHUS4tpKaAWJIxAAAEEEIiQQENjk3TuvusWMEVbqiM0V2ajAvTfBECrbwIFZMVHYwQQQACBNgSKt1abvX8aAjUMMkROgP6bAGhVTRSQFR+NEUAAAQTaEJiet+sWMH98i1vAtMEU1lv03wTAsArH04gC8kjwLwIIIIBApAUGz951C5gHP58X6Vm7fn703wRAqy+BEwtozqoyue2jWfL3LxfIxooaq/WjMQIIIIBA/AR6jd51C5iXRi+P30Ik6Sc7sf+ONXVKrD8wmT7PSQWk54o8OmShOV+k47NjzL9nvJghPyxe3yp5U1OTbNleJ3py8fL122Temi2iDxufsHyTLFi7VQrLqqW6rr7V9pF8Q593qc+5zFlfIZu21UhdfeiPPNL10XZVtfVmvXQ+68p3xGwdIunBvBBAAIEHBs0zf8u/nL0WjAgLOKn/jvCqBT07AmDQVC0njFYBfbOgWG7uN1P0/k//GblEXk1fIQOmrRIdP2VFiSwuKjehbXttvQk3b2fkeh8WflzqGHn2myVy7XvTvWHw0aELZev2OrMCeiKx7iVM+yFHftd7kncaT2gM9O8pz4+TC1+bLFe8M1Wue3+62cP410/nysOD58s/hy8283orI1c+mlogQ+cUyugl60VvWbCoqFxWba6SzZW1UlJZYwLlyk2VsqS43CzD8HlF8sKoZXJr/5ly2gvjWyzL6WnjRR9/9L99sqTrW5lywauT5Te9Jogujz4X8/huY0XXV386pe4KvYGWX8dpm4vfmGJcdbm7f7dU1O3zmWvM8urjltR2xPwisx69x/5k1u2Jr7LlyWHZ8vTXi+SZEYvl3yMXS7fvlkrPH5fLa+NWSJ+JeWb6fpkFom3+NWKx6OEa3X5X982Suz+ZY9rqex9PK5BvFxbL5BW7AnbB5iopraqVnTzfs+WXizEIIGD+7unfrxn5pWhEWCBa/XeEFzOqsyMAWvBGq4DeGL+iRRhqLdj4Bp87Pp5l9uTpKmmoeHfiShOStO05L080YfK3L01sMe+TnxsnOv6SN6bINe9lyfXvTzeB76Tndt1+oLXPjvT4zj3SzXJosIvEvHU+J3SPzLwisTxtzUMDrYbg3740wQRdvffXXQPnmKCqwXF8zkZZsXGbbKvZKbqnkwEBBJJboL6h0fv3S4/wMERWIFr9d2SXMrpzIwBa+EargHSvWfrSDaLPgHxvUp7oI4D+MSzb7E3SvUrnvzLJu8dPQ4XunRu3bEPAYKB72/QKMt/w0SVtvNm7NXH5JqnZ2dCqgAYNPZyqh2YXFm4V3Uume6/GLt1g9kbqYYn+UwtEA+vzo5bJU18vMnu/9BzEP707zQQZ3Yunn61BVYOm7sHT5e/6Zqbc+fFs0XNbvssultyNlaJ/8HTQw8G6xzK/pMrsKczMLZHZq8rMnk+dTg9N6/mNeoi3RH8qd/3o3rSKHTvNOnlumeBZh9Wl280hbnXVO+q/M2Gl9Ph+qTlX8rb+s4yR2uqeTV0PXa4PM/Pl0+mrZWDWKrP3TvfyfTAl3wRr3fun20X3Bj49fJHx1Db6/ldzC8328OxR1HZ6Lo9uQw11GrIven2y6Hbw3S7B/q57M3XP6O0fzRLdQ6l7I7VOdHvotplVUGb2EBMUWy1t3kDA8QL6d07/Juh/jPVvIkNkBaLVf0d2KaM7NwKghW88C0g7dz0/b335Dm9wam1VNORpMNGQpmEqnPPrWpt3MOP1jxdhJLCUBtWK6p0m0GrQ1r18euh8ZkGpORz95vhcefyrbHPoPdTAqHsT9TQC3fY6Pw3zDAggkBgCU1duNgHw8renJsYCJ9hSxrP/dgoVAdBiS1BAFng0DUtAQ78GRT2PUy/w0b2TujdSz/t86Iv55nxK3eMb6NC37oXVQ8t6PuOgGavNhT476lrfAxzWAtIIAQQiIqDnJ+sewL99MT8i82Mm/gL039wGxr8iQnxFAYUIxuQxE9C9vnoltwZEvQhID7u3dohZD8tf1SdL7vtsrqR+u9QcTtbTAzZU7GDPbcy2GB+EgL+AXqin31m9gIwh8gL03wRAq6qigKz4aBxjAT1XUs/h1IuD9PYSgS4Iah4Sz+41wZx7qnsZta1ehMKAAALRF7j3s7kmAOqdFRgiL0D/TQC0qioKyIqPxnEW0PMy9WKbnzZsM/d/HDa30Fwco4eI9SKeQFdj6y139BZDL49ZLpN+2iR6KyIGBBCIvIDelUH/Q6bn7zJEXoD+mwBoVVUUkBUfjR0uoIeR9Z6TejW63o/S0yH57iU8sUe6OUdJ71GpF7MwIICAvYDexsvzHzA9FYMh8gL03wRAq6qigKz4aJyAAnr7nVGL1knqt0vMrWx8w6Dey1BvfP3J9NXmfpTcuiIBNzCL7AgBvW2Vfrf0Xqx8j6KzSei/CYBWlUUBWfHROMEF9BCyHj7Weype+c4002H5BsIze2aYp8XoFcf5JZVcUJLg25vFj52A3kNUv0t6KgZDdATovwmAVpVFAVnx0TjJBPQG5vo4wHs+nWsevecbBvX3P7w+xTw2MCtvc8zvRZlk1KxOkgt8NmO1CYD/N3hBkq9p/FaP/psAaFV9FJAVH42TWEDPYdKnx+jTVPSwcOfu/o8VPPX5ceYG1/qEF24SnsSFwKqFJaDPSNf/NOlz4BmiI0D/TQC0qiwKyIqPxi4S0KuF9XnGejFJ89vP6JMO9Ka33GLGRQXBqrYpoI+k1AD49TxuAdMmlMWb9N8EQIvyEaGArPho7FIBPak9u3Cruem0Ph/ac6hYn3GsF5csW1fhUhlWG4FdAnq6hH4v9Ik/DNERoP8mAFpVFgVkxUdjBMxeP937p3sBPUFQ/73hgxkyckGx6K1oGBBwk4A+q13vt6nfg5JtNW5a9ZiuK/03AdCq4CggKz4aI+AV0PMAdW/H419l+z3HWK8k1ptOF22p9k7LLwgks0DepkoT/nSPOOfHRm9L038TAK2qiwKy4qMxAgEFNlfWygdT8uWCVyd79wrqHpFHhiwwzzcO2IiRCCSJgF44pXv//jxgdpKskTNXg/6bAGhVmRSQFR+NEWhToKGxyTx/WK8i9j08fOOHM2TMkg2i7zMgkGwCeu8/rXd9NCND9ATovwmAVtVFAVnx0RiBoAVWbNwm/x652O92Mlf1yZJpKzcHPQ8mRMDpAp7Dvyd0Hyvl1XVOX9yEXj76bwKgVQFTQFZ8NEYgZAE9PPx2Rq6cnjbeu1dQ9xDmrOfK4ZAxaeA4gbcyck1dPzBonuOWLdkWiP6bAGhV0xSQFR+NEQhbYOv2Ouk1ern3gpFOqWPk6a8XyfryHWHPk4YIxFNAL/i4+I1dt3/R520zRFeA/psAaFVhFJAVH40RsBYoLKs2Vw57zhHU+wr2nZTH7WOsZZlBrAWWFJebvX8nPZcueuN0hugK0H8TAK0qjAKy4qMxAhETWFxULrf2n+k9LKxXEI9duoHbaERMmBlFW+Cl0ctN/T42dGG0P4r5Cw9y0CJIoRLCFyAAhm9HSwQiLaCH0H5YvF5+/8okbxC84+NZsnz9tkh/FPNDIKIC+nSc3/XeVbcZORsjOm9mFliA/psAGLgyghxLAQUJxWQIxFCguq5e3p6wUk7skW6CoJ4fqFcQb+KpCjHcCnxUKAKzV5WZWu2SNl5q63n6TSh24U5L/00ADLd2TDsKyIqPxghEVUCfHvLo0IXevYF6fuC7E1eKBkQGBJwk0O27paZO9T8qDLERoP8mAFpVGgVkxUdjBGIisGDtVrnpwxneIHhe74kyYn6R6GE3BgTiLbCzoVHO6plh6nN6Xmm8F8c1n0//TQC0KnYKyIqPxgjETEDPD9Snh1z0+i+Pl7v+/ek8Wi5mW4APak1gyooSE/5++9JEnm7TGlIUxtN/EwCtyooCsuKjMQIxF9Dzqz6aWiCnvfDLjaSf+nqRbKyoifmy8IEIqIDWn97GKO2HHEBiKED/TQC0KjcKyIqPxgjETaCkssZcGKIXiGjnq+cHvj+5nfsHNjWJlJaKrFmz6199zYCAhcCOugY59flxpgb1VAWG2AnQfxMAraqNArLiozECcRfQm+/e3O+X+wde+NpkSW9+/8DycpE+fUR+/WuRlJRffvS1jtf3GRAIQ0BrTf8Dovet1NMUGGInQP9NALSqNgrIio/GCDhCQDteffRW8/sH/rRhm8j48SL77Sey2267fnwDoGecvq/TMSAQooDn8O/LY5aH2JLJbQXovwmAVjVEAVnx0RgBRwk0v3/gPbf3lMZf/UqafvWrX/b6+QZAz+/6/u67EwIdtTWdvzB69a/e90/3AM5bs8X5C5xkS0j/TQC0KmkKyIqPxgg4UqB4a7U8M2CqbN9jb2lI2a3t8OcbAnVPIIeDHblNnbhQM/JLTfg7u9cErv6Nwwai/yYAWpUdBWTFR2MEnCvQp4806SFeT8AL5l+dvm9f564TS+YogRdGLTMB8D8jlzhqudyyMPTfLg+Ar7zyipxzzjmy//77S4cOHeSGG26Q3NzcoOufAgqaigkRSBwBPRlfL/AIJwBqO07mT5xtHacl1fNOPeecTl6xKU5L4e6Ppf92eQD805/+JIMGDZKcnBxZvHixXH311XLsscfK9u3bg/pmUEBBMTERAokloLd6CWaPX2vTlJUl1vqytDEXWFpcYfb+nfL8OKnZybN/Y74BRIT+2+UBsHnRbd68WVJSUmTatGnN3wr4mgIKyMJIBBJbQO/z11q4C2a8tmdAoA2BtzJyTQB8ZMiCNqbirWgK0H8TAP3qKz8/3wTAZcuW+Y1v7QUF1JoM4xFIYAH2ACbwxkuMRb/ynWkmAH6fvS4xFjgJl5L+mwDoLevGxka55ppr5MILL/SOa/5LbW2t2W2shaM/xcXFJjDq7wwIIJAkAmGeA6gXjTRxDmCSFEH0VmNN6XYT/n7dbaxUVO+M3gcx5zYFCIAEQG+B/P3vf5eOHTuaUOcd2eyXtLQ0E/j0MLHvDwGwGRQvEUh0AX3CR4gXgTSm7CYDbn1SctZXJPras/xRFBgwbZUJgHcNnBPFT2HW7QkQAAmApkYee+wxOeaYY2T16tVt1gx7ANvk4U0EkkdA7+en9/Vr7ybQP58T2Ljbr6R6j72ly5Nfy3GpY6TX6OWyvbY+eTxYk4gJ3PLzowe/mMW5ohFDDWNGBECXB0C9FF/D31FHHSV5eXkhlxAFFDIZDRBIHAF9vJs+4aO9EPjzk0C2fjdaHh2y0Ozd0ac76G0+vllQzE1+E2eLR31JN1fWSqfUMaZG1pfviPrn8QGtC9B/uzwAPvLII3LQQQfJ1KlTZePGjd6fHTuC+2JSQK1/uXgHgaQQCPZZwBkZ3tWdklsiF70+2RsE9YT/ics3if6Hk8HdAsPmFpq6uO796e6GcMDa03+7PAD6nsfn+7veGzCYgQIKRolpEEhwAT0crE/40As8fG8Do691fEXLc/521DVIv8wC77NedY+gHvrjma8JXguWi3//oHkmAL4/OfQjTpYfTfNmAvTfLg+Azeoh5JcUUMhkNEAgcQV0D57e5Fnv86f/BrFHT6/yfDV9hZzYI927R1Dv/bahIrijDImLxZI3F6iqrZfOP9fByk2Vzd/mdYwF6L8JgFYlRwFZ8dEYAdcIbKyokdRvl5oLRHRvoD4B4qOpBVJX3+gaA7ev6KhF68x/Ai55YwqnAzigGOi/CYBWZUgBWfHRGAHXCSxfv80cCtYQqD+XvT1VZhXw6LhkL4SdDY3S9a1Ms83fzgj+efPJ7hLP9aP/JgBa1R8FZMVHYwRcKdDY2CQjFxTL2b0meA8LP/X1IimrqnWlhxtW+rMZq8221m1eWcPNn52wzem/CYBWdUgBWfHRGAFXC+j5gc+PWua9LciZPTNk+PwiDg8mWVWUV9fJGS9mmAA4dE5hkq1d4q4O/TcB0Kp6KSArPhojgICILC4ql6v6ZHn3Bt7x8SxZtbkKmyQRSPshx2zbP707jXtCOmib0n8TAK3KkQKy4qMxAgj8LFDf0CgfTyuQk57bdbVw5+7p8sb4FVKxg8OFiVwkBZurRJ/5q+d7Ts8rTeRVSbplp/8mAFoVNQVkxUdjBBBoJlC0pVru+XSud29gl7Tx8sGUfB4r18wpUV4+8PN9/x78fF6iLLJrlpP+mwBoVewUkBUfjRFAIICAPjFkfM5G0SeIeK4W/u1LE+ST6aulZmdDgBaMcqJAVt5ms/10DyCH9J23hei/CYBWVUkBWfHRGAEE2hBoaGwSvXec3jfOEwTP6z1R9IpSgmAbcA54Sw/pewL8iz/mOGCJWITmAvTfBMDmNRHSawooJC4mRgCBMAT0HnL6DNnzX5nkDYLnvjyRPYJhWMaqycCsVWZb6ZXdehUwg/ME6L8JgFZVSQFZ8dEYAQRCEKitb5AvZ6+VC16d7A2Cv31pomjY0GcPMzhDYNJPm7xPfBk8e60zFoqlaCFA/00AbFEUoYyggELRYloEEIiEgD4+Tu8n5x8EJ8iAaaukuq4+Eh/BPMIUWFpcISc/N84E9P+MXMI9HcN0jEUz+m8CoFWdUUBWfDRGAAELAQ2CzY3rMAAAFkdJREFUemj4wtd+2SOoT5rQZwwTBC1gw2y6rnyHnPPyRBP+7v5kjuihewbnCtB/EwCtqpMCsuKjMQIIREBAg8bweUXyh9d/uVjkN70mSL/MAm4fEwHfYGaxrWanXPHOVBP+9IbP+prB2QL03wRAqwqlgKz4aIwAAhEU0CA4Yn6RXOxz1fBZPTPMfQSrajk0HEFqv1npnti/DJxtwp9enLO+fIff+7xwpgD9NwHQqjIpICs+GiOAQBQE9BYk3ywolkvfzDShRG8ho1ej9p2UxxWpEfbW0P3Y0IXG+dTnx8mydRUR/gRmFy0B+m8CoFVtUUBWfDRGAIEoCmgQ/D57nXR965cgqBcovDBqmRSWVUfxk90xa70Xoz7hQwP2Cd3HypTcEneseJKsJf03AdCqlCkgKz4aI4BADAQ8N5S+qk+Wd4/gcalj5JEhC2Rh4dYYLEHyfYQeUr/z412HfU/skU74S8BNTP9NALQqWwrIio/GCCAQQwF9xNyM/FK/Zw3r3qvr3p8uw+cX8XSRILeF3tj5hg9mmDB92gvjZfaqsiBbMpmTBOi/CYBW9UgBWfHRGAEE4iSwYuM2eWbEYuncPd27V/CMFzPkpdHLZXXp9jgtlfM/dnNlrehVvp7zKhcXlTt/oVnCgAL03wTAgIUR7EgKKFgppkMAAScKlFXVmtvF+N5UWsPN376YL4sIN36bLLtwq/fm23q/v9yNlX7v8yKxBOi/CYBWFUsBWfHRGAEEHCKg5wlOXrFJ7vtsrnRKHePdK6i3N5mZX+rqJ1roofPPZqw2F3poOL7kjSmyhr2kDqnc8BeD/psAGH71iAgFZMVHYwQQcKBAfkml/HP4Yjm+21hvELz+gxny7sSV8uPi9bJ8/TbXPHtYb+isF8to8NOfv3+5gJs8O7Bmw1kk+m8CYDh1421DAXkp+AUBBJJMoHhrtblljF7l6glAvv/qYePb+s+SR4culBd/zDGHkvX+g3roWG+OnOiDPtdX9/bpOuttXnQvoO4NZEgOAfpvAqBVJVNAVnw0RgCBBBDQCx8GZq2Sf49cLDf3m2luKu0bBAP93rlHuplWLyoZs2SDlFbVJsCa7lpEvbBDz4H0rJcGXT3/jyG5BOi/CYBWFU0BWfHRGAEEElRgy/Y6mb9mi4xesl4+nb5aXk1fIU8PX2QeiaaPn/OEJ8+/ejj5gUHzZNyyDY7cO6h79mYVlMndn8zxLrueC/n4V9mydXtdgm4lFrstAfpvAmBb9dHuexRQu0RMgAACLhPQMKW3kvl2YbH0+H6p97YpnjCoATHthxzRQ6yxOqSqn6OHtPUcxl6jl8sTX2XL/YPmyW0fzZKr+2Z5r+7VZdSwqmFWz4VkSF4B+m8CoFV1U0BWfDRGAAGXCOSXVJm9hOe+PNG7h03Dlh5e1TCoe9/00XWRGvTiDb3p9YeZ+fLQF/NFb9viCaCt/auHrTWwFm3hMXmR2g5Ong/9NwHQqj4pICs+GiOAgMsENOTpM3P1whF9LrFvGDuzZ4Y8OSzbXGyhh5er6+rb1dFHsulNrScu32TOU9T2Xd/85dnHvvP/dbexcv37082FLZ9MXy1fzys0h7Azc0vM4Ww9rM3gHgH6bwKgVbVTQFZ8NEYAARcL7KhrkAnLN5knkgQ6b1CfV3z521PNrVf+b/ACc49CvS/hrf1nmsO2gdr4Br4LX5tsbuEyYNoqE/Bqdja4WJtVby5A/00AbF4TIb2mgELiYmIEEEAgoIDuGdTDwO9MWGkuFjmvd/uHbD1hT/ccXvNelgmKfSbmmT2M+oQTBgTaEqD/JgC2VR/tvkcBtUvEBAgggEBYAiWVNebpJHr/vcGz18rweUXyffY6Gbt0gxn/04ZtUlmzM6x50wgB+m8CoNW3gAKy4qMxAggggAACcRGg/yYAWhUeBWTFR2MEEEAAAQTiIkD/TQC0KjwKyIqPxggggAACCMRFgP6bAGhVeBSQFR+NEUAAAQQQiIsA/TcB0KrwKCArPhojgAACCCAQFwH6bwKgVeFRQFZ8NEYAAQQQQCAuAvTfBECrwqOArPhojAACCCCAQFwE6L8JgFaFRwFZ8dEYAQQQQACBuAjQfxMArQqPArLiozECCCCAAAJxEaD/JgBaFR4FZMVHYwQQQAABBOIiQP9NALQqPArIio/GCCCAAAIIxEWA/psAaFV4FJAVH40RQAABBBCIiwD9NwHQqvAoICs+GiOAAAIIIBAXAfpvAqBV4VFAVnw0RgABBBBAIC4C9N8EQKvCq6iokJSUFCkuLhYtJn4woAaoAWqAGqAGnF8D2m9r/639uFuHFLeueCTW21NAWkT8YEANUAPUADVADSRWDWg/7taBAGix5RsbG83eP/0fhP6P78QTT2yxF7C9cc3f97z2hEv9NxL/m/TMN5h5tTdta+8HGt/euObve16z/rv+d8r2T+z6D/R3wVPjnu+i72vP74lY/4HWNdA4zzqy/rv2kvl6eH5n+0f/75/22+qs/bhbBwJgBLf8Kaec0mJu7Y1r/r7ntf5x1P9J6r+RGDzzDWZe7U3b2vuBxrc3rvn7ntesP9s/Gepfv2+emvZ899p67XkvEes/0LoGGudZx0AenvdY/8T7/gfa1oHGebaxU7e/Z7nc8C8BMIJb+YMPPmgxt/bGNX/f8zrSfwA9822xgAFGtDdta+8HGt/euObve16z/pHtADyuATZ3i1HtTdva+4HGtzeu+fue18my/RXXs04e6LZee95LxPUPtK6BxnnWMZCH5z3WP/G+/4G2daBxnm3s1O3vWS43/EsAdOhWjvQfQIeuZquLxfpHtgNoFdqhb7D92f6R3APs0DJvdbGof3fXf6uFEeE3CIARBo3U7GprayUtLU30XzcOrD/bn/rn+8/fP/7+u7H/i9U6EwBjJc3nIIAAAggggAACDhEgADpkQ7AYCCCAAAIIIIBArAQIgLGS5nMQQAABBBBAAAGHCBAAHbIhWAwEEEAAAQQQQCBWAgTAWEnzOQgggAACCCCAgEMECIAO2RAsBgIIIIAAAgggECsBAmCspKP4Oe+8846ceuqp5okDTzzxhDQ1NUXx05w169zcXDnzzDO9P3vvvbd8//33zlrIKC/N6tWr5dJLLzXb//TTT5ft27dH+ROdNfuOHTtKly5dTA2ogxuH6upqOfbYY+WZZ55x1eqXl5fLb3/7W7PtTzvtNBkwYICr1r+oqEguueQS893X78CIESNctf66sjfeeKMcfPDBcsstt7hu3W1XmABoKxjn9ps3b5bjjz9eampqpKGhQS644AKZNWtWnJcqPh9fVVUlhx12mOsC0MUXXyxZWVkGfcuWLVJfXx+fDRCnT9UAqNvezUP37t3l9ttvd10A1L95Gn510P/4dOrUScrKylxTChs2bJBFixaZ9d24caMcddRRrvv7l5mZKT/++CMBMIyqJwCGgeakJhoA9X/++j9hDYHnnnuuFBQUOGkRY7YsQ4cONZ1gzD7QAR+Uk5Mjl112mQOWJH6L4PYAmJeXJzfffLMMGjTIdQHQt+r0Pz9aC6Wlpb6jXfX7GWecIbpX0G2DhkD2AIa+1QmAoZuF1GLatGly7bXXypFHHin6aKNAhyf12Yj6h2uvvfaS8847T+bOnRvSZ7z33ntywAEHyCGHHCLdunULqW20J47F+nvW4YYbbpBvv/3W89IR/0Z7/bWedL21xn7zm99I7969HbHenoWI9vrr5+hen7PPPlvOOeccGTJkiOejHfFvLNb/+uuvl5UrVzoyAMZi/fU/vxp89tlnnxbPXY53EcRi/T3ruGDBAtHD4E4aYrX+BMDwtjoBMDy3oFulp6dLjx495LvvvgsYAL/++mvZc8895bPPPpPly5fLQw89ZM5nKCkp8X6GnuOmX+zmP+vXr5etW7fKlVdeKfq/3x07dpjzQfRL55Qh2uvvWU99dmaHDh3MXlDPOCf8G+31HzlypBx66KHmf/362Cw9B27ChAlOWHWzDNFef/2QdevWmc/Sw2F6LuySJUtcs/6jRo2Sf/3rX2Z9nbgHMBbb37OxN23aZE6B0X+dMsRq/fXvv9b+zJkznbLqZjlitf4EwPA2OwEwPLewWgXaA6h7/B577DHv/BobG815HK+++qp3XFu/6Em/jz76qHeSN954Q15//XXvayf9Eo3196zf4MGD5a677vK8dOS/0Vh/Pd9T/wPgGXT7648Th2isf/P11DCkQciJQzTWPzU1VY455hhzBEHPfz3wwAOlZ8+eTlz9gP8Btv3713xFH3nkEdH/FDlxiMb21/XU//j94Q9/EP0b6OQhWuuv60wADG/LEwDDcwurVfMvQF1dney+++4tDgvfc889ood1ghlmz54tZ511lvcikKuvvlp0r4ATh2isv2c99RCongjs5CEa668XfOj21z3B+p8HdRg9erQjGaKx/nrif2VlpVlfvRBEDwXPmzfPNevvu6JO3APou3zR2P66t8+z/SsqKsxRkqVLl/p+rGN+j8b66x0f7rzzTklLS3PMera2INFYf89nEQA9EqH9SwAMzctq6uZfAD2Eq+OaX7X773//25wLGOyH6RWAJ598sjkE4OTbwERr/fUP/xFHHCEaqJ08RGv99TCL3v5FTxF4+umnHUsQjfVftWqVOf9LzwHT9e/Tp4+r1t93ZRMtAEbi75+eL62nyOj219ugfPTRR74kjvo9GvU/ffp02W233by3wVKLRAnAkdj+uoH1IrjDDz/cnAN69NFHt+hPHVUEDlsYAmAMN0g0/gDEcPGtP4r1978IKFJ/AK03TIxmwPZn+/teBEf9R2YHQIy+vtYf4/bvvzVgFGZAAIwCamuzbP4FiMQh4NY+y4njWX//AMD2tz8Fwol13toyUf/Uv28A5vvvru9/a38X4jmeABhD/eYdgH60ngT9+OOPe5dCz+PS3djBXgTibZgAv7D+/h0g25/65/vP3z/+/ruj/3NiF00AjPJW0RPT9U7t+qMBSB/bpr8XFhaaT9bbwOj9/z7//HP56aef5OGHHza3gXHSrQxsiFh/tj/1z/efv3/8/Xdj/2fTd8aiLQEwysp6dZIWfvOfe++91/vJ77//vnmah94PUPcIzJkzx/teov/C+rP9m9e+vqb++f57/rbx94+///o0q2Ts/zw17tR/CYBO3TIsFwIIIIAAAgggECUBAmCUYJktAggggAACCCDgVAECoFO3DMuFAAIIIIAAAghESYAAGCVYZosAAggggAACCDhVgADo1C3DciGAAAIIIIAAAlESIABGCZbZIoAAAggggAACThUgADp1y7BcCCCAAAIIIIBAlAQIgFGCZbYIIIAAAggggIBTBQiATt0yLBcCCCCAAAIIIBAlAQJglGCZLQIIJIZAx44d5d13302MhWUpEUAAgQgJEAAjBMlsEECgdQF99NsNN9zQ+gRxfGfz5s1SXV0dxyVo+6OdbNf2kvMuAgg4WYAA6OStw7IhkCQC8QgxO3fudLResMsXDztHw7FwCCAQEQECYEQYmQkCCLQl0F6IWbZsmVx11VWy3377yRFHHCF33323lJaWemc5btw4ufDCC+Wggw6SQw89VK655hopKCjwvr9mzRpJSUmRr7/+Wi6++GLZa6+9ZNCgQeL53DfffFP+53/+x7R99NFHxTd8NT8ErPMZOHCg3HjjjbLPPvvICSecID/88IP3s/QXfa3j9XMuvfRS+fzzz83nl5eX+03n+0Ln269fP7nuuutk3333lbS0NGloaJAHHnhAOnXqJHvvvbeceOKJ0qdPH28znUbb+f5kZmaa94uKiuS2224zJocccohcf/31og4MCCCAQDACBMBglJgGAQSsBDxBLNBMNDR16NBBunXrJitWrJDs7Gy54oorpGvXrt7Jv/nmG/n2228lPz9fFi1aZEJUly5dpLGx0UzjCYAapHS61atXy4YNG0wAPPDAA+Xvf/+7mffo0aNN+BowYIB33oEC4DHHHCNfffWV+bx//OMfsv/++8uWLVtMG533HnvsIf/6178kNzdXhg0bJkcffXRQAVDD7WeffSarVq2SwsJCE0RfeOEFmT9/vlnmIUOGmOUbPny4+ayqqiq5/fbbTTjeuHGj6E9dXZ1pd8opp5jwuHTpUvnpp5/kL3/5i5x00knmfe/K8QsCCCDQigABsBUYRiOAQOQE2gqAL730klx55ZV+H1ZcXGwC1cqVK/3Ge17o3kHdK6Z7DnXwBEDfvWc6Xj9XA57uafMMutfsjjvu8Lw07/teBKLzfe6557zvb9++3XyW7oXU4dlnn5XTTz/d+77+0qNHj6AC4FNPPeXXLtCLxx57TG655RbvW4HsvvzySxP2mpqavNNpMNQ9lhkZGd5x/IIAAgi0JkAAbE2G8QggEDGBQCHGM/Nbb73V7FHTw7++PxrE0tPTzWR5eXly5513ynHHHScHHHCAmU7fHzt2rHnfEwBnzJjhma35Vz/36quv9hune/R89y4G2gM4YsQIvza6F/GLL74w4/TQ8P333+/3vh4S1uVp7xCw7uFrPnzwwQdy9tlny+GHH27WS/cunnvuud7JAtnp3sfdd9/dz0vtdtttN3OY2duYXxBAAIFWBAiArcAwGgEEIicQKMR45q7n/t18883mcKse4vX90b1vOuihTd1LOGnSJHO4MycnxwSu77//3rzvCYB6eNh3CPS5Tz75pFxyySXeyQIFQM98PRPpuYd6TqEONgGw+Xz18LGe+/fhhx+aQ9+67g8//LCceeaZno/2nsfoHSFiDmmfd955flYet4qKCt9J+R0BBBAIKEAADMjCSAQQiKRAoCDmmX/37t1NwKuvr/eM8vu3rKzMhL2srCzv+OnTp8ctAOohYD3/0HfQQ8bB7AFsHgAff/xx+eMf/+g7K7nsssv8AuBDDz0k1157rd80eg6jXvixbds2v/G8QAABBIIVIAAGK8V0CCAQtoAGQL1aVvfQ+f7olazr1683F4HooeB58+aZq3vHjx8v9913nzl3Ty/0OOyww8yVwbqXa/LkyeYQqQYuT6CK5R5Az0Ug//nPf0TPUdQLNvSiEV2etva++S6vB7Jv376ih5d1fXVeGiT1te8ewN69e8uxxx5rLjjRcx/1Cma9b2Hnzp2NqQZjXSa9OviJJ54QPX+SAQEEEGhPgADYnhDvI4CAtYAGQA1AzX8efPBBM289x++mm26Sgw8+2FzIcPLJJ4teMOG5yGHixImiV73qbVfOOOMMmTp1qplXPAKgLnDz28D079/fLE9NTU2rVoECYG1trQm6eohZ1/2RRx6R1NRUvwCoN6rWq6L1SmSdh+c2MHpF8D333GPOHVSX448/XnRvIXsFW90EvIEAAj4CBEAfDH5FAAEEwhF4+eWXzV7AcNrSBgEEEIiHAAEwHup8JgIIJLSAXrShh6v1fn6DBw82N2PWW8EwIIAAAokiQABMlC3FciKAgGME9PD0kUceaQ5J67l4vXr1ktYuYnHMQrMgCCCAgI8AAdAHg18RQAABBBBAAAE3CBAA3bCVWUcEEEAAAQQQQMBHgADog8GvCCCAAAIIIICAGwQIgG7YyqwjAggggAACCCDgI0AA9MHgVwQQQAABBBBAwA0C/w+ELQeExqjNywAAAABJRU5ErkJggg==)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tn1RV-jfOjt1" + }, + "source": [ + "# `benchmark`\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rsmTl5zfwjM3" + }, + "source": [ + "You can try to speed your system by setting `benchmark=True`, which enables cudnn.benchmark. This flag is likely to increase the speed of your system if your input sizes don’t change. This flag makes cudnn auto-tuner look for the optimal set of algorithms for the given hardware configuration. This usually leads to faster runtime.\n", + "But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears, possibly leading to worse runtime performances." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dWr-OCBgQCeb" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1, benchmark=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qwAvSKYGa24K" + }, + "source": [ + "# `deterministic`\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tl5mfmafwmat" + }, + "source": [ + "PyTorch does not guarantee reproducible results, even when using identical seeds. To guarentee reproducible results, you can remove most of the randomness from your process by setting the `deterministic` flag to True.\n", + "\n", + "Note that it might make your system slower." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mhv5LZ3HbNCK" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1, deterministic=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u_5eJSvTf60f" + }, + "source": [ + "# Exploding and vanishing gradients" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B6drjh4pq6Jv" + }, + "source": [ + "## track_grad_norm\n", + "\n", + "You can debug your grad norm to identify exploding or vanishing gradients using the `track_grad_norm` flag.\n", + "\n", + "Set value to 2 to track the 2-norm. or p to any p-norm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2taHUir8rflR" + }, + "outputs": [], + "source": [ + "# track the 2-norm\n", + "trainer = pl.Trainer(track_grad_norm=2)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3vHKxmruk62f" + }, + "source": [ + "May be set to ‘inf’ infinity-norm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "g7TbD6SxlAjP" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(track_grad_norm='inf')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TcMlRe7ywpe6" + }, + "source": [ + "## Gradient clipping\n", + "\n", + "\n", + "Exploding gradients refer to the problem that the gradients get too large and overflow in training, making the model unstable. Gradient clipping will ‘clip’ the gradients or cap them to a Threshold value to prevent the gradients from getting too large. To avoid this, we can set `gradient_clip_val` (default is set to 0.0).\n", + "\n", + "[when to use it, what are relevant values]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jF9JwmbOgOWF" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gradient_clip_val=0.1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ggb4MkkQrr1h" + }, + "source": [ + "# truncated_bptt_steps\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s1Iu6PyAw9_r" + }, + "source": [ + "If you have a large recurrent model, you can use truncated_bptt_steps flag to split up the backprop over portions of the sequence. This flag will automatically truncate your batches and the trainer will apply Truncated Backprop to it.\n", + "\n", + "Make sure your batches have a sequence dimension.\n", + "\n", + "Lightning takes care of splitting your batch along the time-dimension.\n", + "```\n", + "# we use the second as the time dimension\n", + "# (batch, time, ...)\n", + "sub_batch = batch[0, 0:t, ...]\n", + "Using this feature requires updating your LightningModule’s pytorch_lightning.core.LightningModule.training_step() to include a hiddens arg with the hidden\n", + "\n", + "# Truncated back-propagation through time\n", + "def training_step(self, batch, batch_idx, hiddens):\n", + " # hiddens are the hiddens from the previous truncated backprop step\n", + " out, hiddens = self.lstm(data, hiddens)\n", + "\n", + " return {\n", + " \"loss\": ...,\n", + " \"hiddens\": hiddens # remember to detach() this\n", + " }\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WiTF1VMtruMU" + }, + "outputs": [], + "source": [ + "# backprop every 5 steps in a batch\n", + "trainer = pl.Trainer(truncated_bptt_steps=5)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8XI_kEWkS-nT" + }, + "source": [ + "To modify how the batch is split, override pytorch_lightning.core.LightningModule.tbptt_split_batch():\n", + "\n", + "```\n", + "class LitMNIST(LightningModule):\n", + " def tbptt_split_batch(self, batch, split_size):\n", + " # do your own splitting on the batch\n", + " return splits\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oLbEmbmupwQ8" + }, + "source": [ + "# reload_dataloaders_every_epoch\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CLdNGVv9xD_L" + }, + "source": [ + "Set to True to reload dataloaders every epoch (instead of loading just once in the beginning of training).\n", + "\n", + "```\n", + "# if False (default)\n", + "train_loader = model.train_dataloader()\n", + "for epoch in epochs:\n", + " for batch in train_loader:\n", + " ...\n", + "\n", + "# if True\n", + "for epoch in epochs:\n", + " train_loader = model.train_dataloader()\n", + " for batch in train_loader:\n", + "\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "10AXthXxp311" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(reload_dataloaders_every_epoch=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f513EYl0bmmL" + }, + "source": [ + "# Callbacks\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2pt7iGh4xNs5" + }, + "source": [ + "\n", + "Lightning Callbacks are self-contained programs that can be reused across projects.\n", + "Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run. Lightning includes some a few built-in callbacks that can be used with flags like early stopping and Model Checkpointing, but you can also create your own callbacks to add any functionality to your models.\n", + "\n", + "The callback API includes hooks that allow you to add logic at every point of your training:\n", + "setup, teardown, on_epoch_start, on_epoch_end, on_batch_start, on_batch_end, on_init_start, on_keyboard_interrupt etc. \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1t84gvDNsUuh" + }, + "source": [ + "## callbacks\n", + "\n", + "Use **callbacks=** to pass a list of user defined callbacks. These callbacks DO NOT replace the built-in callbacks (loggers or EarlyStopping). \n", + "\n", + "In this example, we create a dummy callback that prints a message when training starts and ends, using on_train_start and on_train_end hooks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oIXZYabub3f0" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import Callback\n", + "\n", + "class PrintCallback(Callback):\n", + " def on_train_start(self, trainer, pl_module):\n", + " print(\"Training is started!\")\n", + " def on_train_end(self, trainer, pl_module):\n", + " print(\"Training is done.\")\n", + "\n", + "# a list of callbacks\n", + "callbacks = [PrintCallback()]\n", + "trainer = pl.Trainer(callbacks=callbacks)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cNF74CLYfJJu" + }, + "source": [ + "# Model checkpointing\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2blgquBrxLtS" + }, + "source": [ + "Checkpoints capture the exact value of all parameters used by a model.\n", + "\n", + "Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.\n", + "\n", + "Lightning automates saving and loading checkpoints so you restore a training session, saving all the required parameters including: \n", + "* 16-bit scaling factor (apex)\n", + "* Current epoch\n", + "* Global step\n", + "* Model state_dict\n", + "* State of all optimizers\n", + "* State of all learningRate schedulers\n", + "* State of all callbacks\n", + "* The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)\n", + "\n", + "By default Lightning will save a checkpoint in the working directory, which will be updated every epoch.\n", + "\n", + "### Automatic saving\n", + "By default Lightning will save a checkpoint in the end of the first epoch in the working directory, which will be updated every epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGu0JULrg9l7" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(default_root_dir=os.getcwd())\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3s9OjkGuhq1W" + }, + "source": [ + "To change the checkpoint path pass in **default_root_dir=**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DgdxkrIQhvfw" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(default_root_dir='/your/path/to/save/checkpoints')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qyvj_bkWrJiE" + }, + "source": [ + "\n", + "You can also have Lightning update your checkpoint based on a specific metric that you are logging (using self.log), by passing the key to `monitor=`. For example, if we want to save checkpoint based on the validation loss, logged as `val_loss`, you can pass:\n", + "\n", + "\n", + "```\n", + "checkpoint_callback = ModelCheckpoint(\n", + " filepath=os.getcwd(),\n", + " save_top_k=1,\n", + " verbose=True,\n", + " monitor='val_loss',\n", + " mode='min',\n", + " prefix=''\n", + ")\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YzYMivw1rO1O" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "\n", + "trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_loss')])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5hYs_FV8iDMn" + }, + "source": [ + "You can modify the behavior of checkpointing by creating your own callback, and passing it to the trainer. \n", + "You can control\n", + "* filepath- where logs are saved\n", + "* save_top_k- save k top models\n", + "* verbose\n", + "* monitor- the metric to monitor\n", + "* mode\n", + "* prefix\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Tb1K2VYDiNTu" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "\n", + "# DEFAULTS used by the Trainer\n", + "checkpoint_callback = ModelCheckpoint(\n", + " filepath=os.getcwd(),\n", + " save_top_k=3,\n", + " verbose=True,\n", + " monitor='val_loss',\n", + " mode='min',\n", + " prefix='',\n", + ")\n", + "\n", + "trainer = Trainer(callbacks=[checkpoint_callback])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YKhZ6xRojJcl" + }, + "source": [ + "You can disable checkpointing it by passing\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yt8zd2ZFjOXX" + }, + "outputs": [], + "source": [ + "trainer = Trainer(checkpoint_callback=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HcLy8asCjrj9" + }, + "source": [ + "### Manual saving\n", + "\n", + "You can manually save checkpoints and restore your model from the checkpointed state.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kZSkMJf0jR4x" + }, + "outputs": [], + "source": [ + "trainer.fit(model)\n", + "trainer.save_checkpoint(\"example.ckpt\")\n", + "new_model = LitAutoEncoder.load_from_checkpoint(checkpoint_path=\"example.ckpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X2d9cjVPj7CP" + }, + "source": [ + "### Checkpoint Loading\n", + "To load a model along with its weights, biases and module_arguments use following method:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BpAFfg5zkFmH" + }, + "outputs": [], + "source": [ + "model = LitAutoEncoder.load_from_checkpoint(PATH)\n", + "\n", + "print(model.learning_rate)\n", + "# prints the learning_rate you used in this checkpoint\n", + "\n", + "model.eval()\n", + "y_hat = model(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jTQ3mxSJkhFN" + }, + "source": [ + "But if you don’t want to use the values saved in the checkpoint, pass in your own here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IoMcOh9-kfUP" + }, + "outputs": [], + "source": [ + "class LitAutoEncoder(LightningModule):\n", + "\n", + " def __init__(self, in_dim, out_dim):\n", + " super().__init__()\n", + " self.save_hyperparameters()\n", + " self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ITPVY8mNknut" + }, + "source": [ + "you can restore the model like this\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "H7XeRJzVkuY8" + }, + "outputs": [], + "source": [ + "# if you train and save the model like this it will use these values when loading\n", + "# the weights. But you can overwrite this\n", + "LitAutoEncoder(in_dim=32, out_dim=10)\n", + "\n", + "# uses in_dim=32, out_dim=10\n", + "model = LitAutoEncoder.load_from_checkpoint(PATH)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "14WwGpnVk0a4" + }, + "outputs": [], + "source": [ + "# uses in_dim=128, out_dim=10\n", + "model = LitAutoEncoder.load_from_checkpoint(PATH, in_dim=128, out_dim=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bY5s6wP_k1CU" + }, + "source": [ + "\n", + "\n", + "## Restoring Training State (resume_from_checkpoint)\n", + "If your training was cut short for some reason, you can resume exactly from where you left off using the `resume_from_checkpoint` flag, which will automatically restore model, epoch, step, LR schedulers, apex, etc..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9zfhHtyrk3rO" + }, + "outputs": [], + "source": [ + "model = LitAutoEncoder()\n", + "trainer = pl.Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')\n", + "\n", + "# automatically restores model, epoch, step, LR schedulers, apex, etc...\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xkKdvALFsmT2" + }, + "source": [ + "## weights_save_path\n", + "You can specify a directory for saving weights file using `weights_save_path`.\n", + "\n", + "(If you are using a custom checkpoint callback, the checkpoint callback will override this flag)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9OwHHFcCsrgT" + }, + "outputs": [], + "source": [ + "# save to your custom path\n", + "trainer = pl.Trainer(weights_save_path='my/path')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PbNtlJ9Wsscf" + }, + "outputs": [], + "source": [ + "# if checkpoint callback used, then overrides the weights path\n", + "# **NOTE: this saves weights to some/path NOT my/path\n", + "checkpoint = ModelCheckpoint(filepath='some/path')\n", + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint],\n", + " weights_save_path='my/path'\n", + ")\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uDdxCuyHdWQt" + }, + "source": [ + "# Early stopping\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fqAy3ihRxTfR" + }, + "source": [ + "The EarlyStopping callback can be used to monitor a validation metric and stop the training when no improvement is observed, to help you avoid overfitting.\n", + "\n", + "To enable Early Stopping you can init the EarlyStopping callback, and pass it to `callbacks=` trainer flag. The callback will look for a logged metric to early stop on.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lFx976CheH93" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "\n", + "trainer = pl.Trainer(callbacks=[EarlyStopping('val_loss')])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MwpJfTvjeOwF" + }, + "source": [ + "You can customize the callback using the following params:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V6I9h6HteK2U" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "\n", + "early_stop_callback = EarlyStopping(\n", + " monitor='val_accuracy',\n", + " min_delta=0.00,\n", + " patience=3,\n", + " verbose=False,\n", + " mode='max'\n", + ")\n", + "trainer = pl.Trainer(callbacks=[early_stop_callback])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7TAIerPYe_Q1" + }, + "source": [ + "The EarlyStopping callback runs at the end of every validation check, which, under the default configuration, happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example check_val_every_n_epoch and val_check_interval. It must be noted that the patience parameter counts the number of validation checks with no improvement, and not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VoKrX2ENh9Fg" + }, + "source": [ + "# Logging" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-CQTPKd7iKLm" + }, + "source": [ + "Lightning has built in integration with various loggers such as TensorBoard, wandb, commet, etc.\n", + "\n", + "\n", + "You can pass any metrics you want to log during training to `self.log`, such as loss or accuracy. Similarly, pass in to self.log any metric you want to log during validation step.\n", + "\n", + "These values will be passed in to the logger of your choise. simply pass in any supported logger to logger trainer flag.\n", + "\n", + "\n", + "\n", + "Use the as`logger=` trainer flag to pass in a Logger, or iterable collection of Loggers, for experiment tracking.\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ty5VPS3AiS8L" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "\n", + "# default logger used by trainer\n", + "logger = TensorBoardLogger(\n", + " save_dir=os.getcwd(),\n", + " version=1,\n", + " name='lightning_logs'\n", + ")\n", + "trainer = pl.Trainer(logger=logger)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jc5oWNpoiuuc" + }, + "source": [ + "Lightning supports the use of multiple loggers, just pass a list to the Trainer.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BlYwMRRyivp_" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger\n", + "logger1 = TensorBoardLogger('tb_logs', name='my_model')\n", + "logger2 = TestTubeLogger('tb_logs', name='my_model')\n", + "trainer = pl.Trainer(logger=[logger1, logger2])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a7EyspQPh7iQ" + }, + "source": [ + "## flush_logs_every_n_steps\n", + "\n", + "Use this flag to determine when logging to disc should happen." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Em_XvsmyiBbk" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(flush_logs_every_n_steps=100)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_vDeKE98qsl1" + }, + "source": [ + "## log_every_n_steps\n", + "How often to add logging rows (does not write to disk)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HkqD7D_0w1Tt" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(log_every_n_steps=1000)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9uw0gfe422CT" + }, + "source": [ + "# info logging" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dQXpt0aatDGo" + }, + "source": [ + "### default_root_dir\n", + "\n", + "---\n", + "\n", + "\n", + "\n", + "Default path for logs and weights when no logger or pytorch_lightning.callbacks.ModelCheckpoint callback passed. On certain clusters you might want to separate where logs and checkpoints are stored. If you don’t then use this argument for convenience. Paths can be local paths or remote paths such as s3://bucket/path or ‘hdfs://path/’. Credentials will need to be set up to use remote filepaths." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CMmID2Bts5W3" + }, + "source": [ + "## weights_summary\n", + "Prints a summary of the weights when training begins. Default is set to `top`- print summary of top level modules.\n", + "\n", + "Options: ‘full’, ‘top’, None." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KTl6EdwDs6j2" + }, + "outputs": [], + "source": [ + "\n", + "# print full summary of all modules and submodules\n", + "trainer = pl.Trainer(weights_summary='full')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R57cSLl9w9ma" + }, + "outputs": [], + "source": [ + "# don't print a summary\n", + "trainer = Trainer(weights_summary=None)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bSc2hU5AotAP" + }, + "source": [ + "# progress bar" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GgvbyDsBxcH6" + }, + "source": [ + "## process_position\n", + "\n", + "Orders the progress bar. Useful when running multiple trainers on the same node.\n", + "\n", + "(This argument is ignored if a custom callback is passed to callbacks)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6ekz8Es8owDn" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(process_position=0)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "itivQFgEphBU" + }, + "source": [ + "## progress_bar_refresh_rate\n", + "\n", + "How often to refresh the progress bar (in steps). In notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates, so raise it to 50 or more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GKe6eVxmplL5" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(progress_bar_refresh_rate=1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8rDHJOJbxNtf" + }, + "outputs": [], + "source": [ + "# disable progress bar\n", + "trainer = Trainer(progress_bar_refresh_rate=0)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NCNvYLwjpWne" + }, + "source": [ + "# profiler" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pRknrG_zpY6M" + }, + "outputs": [], + "source": [ + "# to profile standard training events\n", + "trainer = pl.Trainer(profiler=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ji6aWpU73kMM" + }, + "source": [ + "You can also use Lightning AdvancedProfiler if you want more detailed information about time spent in each function call recorded during a given action. The output is quite verbose and you should only use this if you want very detailed reports.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "layG55pt316C" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.profiler import AdvancedProfiler\n", + "\n", + "trainer = Trainer(profiler=AdvancedProfiler())\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "05-trainer-flags-overview.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/06-mnist-tpu-training.ipynb b/notebooks/06-mnist-tpu-training.ipynb new file mode 100644 index 00000000000000..a0dfdceece9b14 --- /dev/null +++ b/notebooks/06-mnist-tpu-training.ipynb @@ -0,0 +1,368 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "06-mnist-tpu-training.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "WsWdLFMVKqbi" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qXO1QLkbRXl0" + }, + "source": [ + "# TPU training with PyTorch Lightning ⚡\n", + "\n", + "In this notebook, we'll train a model on TPUs. Changing one line of code is all you need to that.\n", + "\n", + "The most up to documentation related to TPU training can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/tpu.html).\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", + " - Ask a question on our [GitHub Discussions](https://github.com/PyTorchLightning/pytorch-lightning/discussions/)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UmKX0Qa1RaLL" + }, + "source": [ + "### Setup\n", + "\n", + "Lightning is easy to install. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "vAWOr0FZRaIj" + }, + "source": [ + "! pip install pytorch-lightning -qU" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zepCr1upT4Z3" + }, + "source": [ + "### Install Colab TPU compatible PyTorch/TPU wheels and dependencies" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AYGWh10lRaF1" + }, + "source": [ + "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "SNHa7DpmRZ-C" + }, + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import random_split, DataLoader\n", + "\n", + "# Note - you must have torchvision installed for this example\n", + "from torchvision.datasets import MNIST\n", + "from torchvision import transforms\n", + "\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rjo1dqzGUxt6" + }, + "source": [ + "### Defining The `MNISTDataModule`\n", + "\n", + "Below we define `MNISTDataModule`. You can learn more about datamodules in [docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html) and [datamodule notebook](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "pkbrm3YgUxlE" + }, + "source": [ + "class MNISTDataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './'):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # self.dims is returned when you call dm.size()\n", + " # Setting default dims here because we know them.\n", + " # Could optionally be assigned dynamically in dm.setup()\n", + " self.dims = (1, 28, 28)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nr9AqDWxUxdK" + }, + "source": [ + "### Defining the `LitModel`\n", + "\n", + "Below, we define the model `LitMNIST`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "YKt0KZkOUxVY" + }, + "source": [ + "class LitModel(pl.LightningModule):\n", + " \n", + " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters()\n", + "\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " self.log('train_loss', loss, prog_bar=False)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n", + " return optimizer" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Uxl88z06cHyV" + }, + "source": [ + "### TPU Training\n", + "\n", + "Lightning supports training on a single TPU core or 8 TPU cores.\n", + "\n", + "The Trainer parameters `tpu_cores` defines how many TPU cores to train on (1 or 8) / Single TPU core to train on [1].\n", + "\n", + "For Single TPU training, Just pass the TPU core ID [1-8] in a list. Setting `tpu_cores=[5]` will train on TPU core ID 5." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UZ647Xg2gYng" + }, + "source": [ + "Train on TPU core ID 5 with `tpu_cores=[5]`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bzhJ8g_vUxN2" + }, + "source": [ + "# Init DataModule\n", + "dm = MNISTDataModule()\n", + "# Init model from datamodule's attributes\n", + "model = LitModel(*dm.size(), dm.num_classes)\n", + "# Init trainer\n", + "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n", + "# Train\n", + "trainer.fit(model, dm)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "slMq_0XBglzC" + }, + "source": [ + "Train on single TPU core with `tpu_cores=1`." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "31N5Scf2RZ61" + }, + "source": [ + "# Init DataModule\n", + "dm = MNISTDataModule()\n", + "# Init model from datamodule's attributes\n", + "model = LitModel(*dm.size(), dm.num_classes)\n", + "# Init trainer\n", + "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n", + "# Train\n", + "trainer.fit(model, dm)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_v8xcU5Sf_Cv" + }, + "source": [ + "Train on 8 TPU cores with `tpu_cores=8`. You might have to restart the notebook to run it on 8 TPU cores after training on single TPU core." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "EFEw7YpLf-gE" + }, + "source": [ + "# Init DataModule\n", + "dm = MNISTDataModule()\n", + "# Init model from datamodule's attributes\n", + "model = LitModel(*dm.size(), dm.num_classes)\n", + "# Init trainer\n", + "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n", + "# Train\n", + "trainer.fit(model, dm)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m2mhgEgpRZ1g" + }, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ] +} diff --git a/notebooks/07-cifar10-baseline.ipynb b/notebooks/07-cifar10-baseline.ipynb new file mode 100644 index 00000000000000..9f3209a8bbc02c --- /dev/null +++ b/notebooks/07-cifar10-baseline.ipynb @@ -0,0 +1,394 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "07-cifar10-baseline.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "qMDj0BYNECU8" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ECu0zDh8UXU8" + }, + "source": [ + "# PyTorch Lightning CIFAR10 ~94% Baseline Tutorial ⚡\n", + "\n", + "Train a Resnet to 94% accuracy on Cifar10!\n", + "\n", + "Main takeaways:\n", + "1. Experiment with different Learning Rate schedules and frequencies in the configure_optimizers method in pl.LightningModule\n", + "2. Use an existing Resnet architecture with modifications directly with Lightning\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HYpMlx7apuHq" + }, + "source": [ + "### Setup\n", + "Lightning is easy to install. Simply `pip install pytorch-lightning`.\n", + "Also check out [bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts/) for pre-existing data modules and models." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ziAQCrE-TYWG" + }, + "source": [ + "! pip install pytorch-lightning pytorch-lightning-bolts -qU" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "L-W_Gq2FORoU" + }, + "source": [ + "# Run this if you intend to use TPUs\n", + "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n", + "# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wjov-2N_TgeS" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.optim.lr_scheduler import OneCycleLR\n", + "from torch.optim.swa_utils import AveragedModel, update_bn\n", + "import torchvision\n", + "\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import LearningRateMonitor\n", + "from pytorch_lightning.metrics.functional import accuracy\n", + "from pl_bolts.datamodules import CIFAR10DataModule\n", + "from pl_bolts.transforms.dataset_normalizations import cifar10_normalization" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "54JMU1N-0y0g" + }, + "source": [ + "pl.seed_everything(7);" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FA90qwFcqIXR" + }, + "source": [ + "### CIFAR10 Data Module\n", + "\n", + "Import the existing data module from `bolts` and modify the train and test transforms." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "S9e-W8CSa8nH" + }, + "source": [ + "batch_size = 32\n", + "\n", + "train_transforms = torchvision.transforms.Compose([\n", + " torchvision.transforms.RandomCrop(32, padding=4),\n", + " torchvision.transforms.RandomHorizontalFlip(),\n", + " torchvision.transforms.ToTensor(),\n", + " cifar10_normalization(),\n", + "])\n", + "\n", + "test_transforms = torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " cifar10_normalization(),\n", + "])\n", + "\n", + "cifar10_dm = CIFAR10DataModule(\n", + " batch_size=batch_size,\n", + " train_transforms=train_transforms,\n", + " test_transforms=test_transforms,\n", + " val_transforms=test_transforms,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SfCsutp3qUMc" + }, + "source": [ + "### Resnet\n", + "Modify the pre-existing Resnet architecture from TorchVision. The pre-existing architecture is based on ImageNet images (224x224) as input. So we need to modify it for CIFAR10 images (32x32)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GNSeJgwvhHp-" + }, + "source": [ + "def create_model():\n", + " model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n", + " model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " model.maxpool = nn.Identity()\n", + " return model" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HUCj5TKsqty1" + }, + "source": [ + "### Lightning Module\n", + "Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#configure-optimizers) method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "03OMrBa5iGtT" + }, + "source": [ + "class LitResnet(pl.LightningModule):\n", + " def __init__(self, lr=0.05):\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters()\n", + " self.model = create_model()\n", + "\n", + " def forward(self, x):\n", + " out = self.model(x)\n", + " return F.log_softmax(out, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = F.log_softmax(self.model(x), dim=1)\n", + " loss = F.nll_loss(logits, y)\n", + " self.log('train_loss', loss)\n", + " return loss\n", + "\n", + " def evaluate(self, batch, stage=None):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + "\n", + " if stage:\n", + " self.log(f'{stage}_loss', loss, prog_bar=True)\n", + " self.log(f'{stage}_acc', acc, prog_bar=True)\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " self.evaluate(batch, 'val')\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " self.evaluate(batch, 'test')\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", + " steps_per_epoch = 45000 // batch_size\n", + " scheduler_dict = {\n", + " 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),\n", + " 'interval': 'step',\n", + " }\n", + " return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "3FFPgpAFi9KU" + }, + "source": [ + "model = LitResnet(lr=0.05)\n", + "model.datamodule = cifar10_dm\n", + "\n", + "trainer = pl.Trainer(\n", + " progress_bar_refresh_rate=20,\n", + " max_epochs=40,\n", + " gpus=1,\n", + " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),\n", + " callbacks=[LearningRateMonitor(logging_interval='step')],\n", + ")\n", + "\n", + "trainer.fit(model, cifar10_dm)\n", + "trainer.test(model, datamodule=cifar10_dm);" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lWL_WpeVIXWQ" + }, + "source": [ + "### Bonus: Use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to get a boost on performance\n", + "\n", + "Use SWA from torch.optim to get a quick performance boost. Also shows a couple of cool features from Lightning:\n", + "- Use `training_epoch_end` to run code after the end of every epoch\n", + "- Use a pretrained model directly with this wrapper for SWA" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bsSwqKv0t9uY" + }, + "source": [ + "class SWAResnet(LitResnet):\n", + " def __init__(self, trained_model, lr=0.01):\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters('lr')\n", + " self.model = trained_model\n", + " self.swa_model = AveragedModel(self.model)\n", + "\n", + " def forward(self, x):\n", + " out = self.swa_model(x)\n", + " return F.log_softmax(out, dim=1)\n", + "\n", + " def training_epoch_end(self, training_step_outputs):\n", + " self.swa_model.update_parameters(self.model)\n", + "\n", + " def validation_step(self, batch, batch_idx, stage=None):\n", + " x, y = batch\n", + " logits = F.log_softmax(self.model(x), dim=1)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + "\n", + " self.log(f'val_loss', loss, prog_bar=True)\n", + " self.log(f'val_acc', acc, prog_bar=True)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", + " return optimizer\n", + "\n", + " def on_train_end(self):\n", + " update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "cA6ZG7C74rjL" + }, + "source": [ + "swa_model = SWAResnet(model.model, lr=0.01)\n", + "swa_model.datamodule = cifar10_dm\n", + "\n", + "swa_trainer = pl.Trainer(\n", + " progress_bar_refresh_rate=20,\n", + " max_epochs=20,\n", + " gpus=1,\n", + " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),\n", + ")\n", + "\n", + "swa_trainer.fit(swa_model, cifar10_dm)\n", + "swa_trainer.test(swa_model, datamodule=cifar10_dm);" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "RRHMfGiDpZ2M" + }, + "source": [ + "# Start tensorboard.\n", + "%reload_ext tensorboard\n", + "%tensorboard --logdir lightning_logs/" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RltpFGS-s0M1" + }, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ] +} diff --git a/notebooks/README.md b/notebooks/README.md new file mode 100644 index 00000000000000..a72e154c364105 --- /dev/null +++ b/notebooks/README.md @@ -0,0 +1,15 @@ +# Lightning Notebooks ⚡ + +## Official Notebooks + +You can easily run any of the official notebooks by clicking the 'Open in Colab' links in the table below :smile: + +| Notebook | Description | Colab Link | +| :----------------------- | :----------------------------------------------------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| **MNIST Hello World** | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) | +| **Datamodules** | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb) | +| **GAN** | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) | +| **BERT** | Fine-tune HuggingFace Transformers models on the GLUE Benchmark | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) | +| **Trainer Flags** | Overview of the available Lightning `Trainer` flags | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/05-trainer-flags-overview.ipynb) | +| **TPU Training** | Train a model on MNIST using TPUs with Lightning | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb) | +| **94% Baseline CIFAR10** | Establish a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/07-cifar10-baseline.ipynb) | diff --git a/pl_examples/README.md b/pl_examples/README.md index 93715b0e44661f..bed553322edf31 100644 --- a/pl_examples/README.md +++ b/pl_examples/README.md @@ -1,67 +1,19 @@ -# Examples -This folder has 3 sections: - -## Basic Examples -Use these examples to test how lightning works. - -#### Test on CPU -```bash -python cpu_template.py -``` - ---- -#### Train on a single GPU -```bash -python gpu_template.py --gpus 1 -``` - ---- -#### DataParallel (dp) -Train on multiple GPUs using DataParallel. - -```bash -python gpu_template.py --gpus 2 --distributed_backend dp -``` - ---- -#### DistributedDataParallel (ddp) - -Train on multiple GPUs using DistributedDataParallel -```bash -python gpu_template.py --gpus 2 --distributed_backend ddp -``` +# Examples +Our most robust examples showing all sorts of implementations +can be found in our sister library [PyTorch-Lightning-Bolts](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2). --- -#### DistributedDataParallel+DP (ddp2) - -Train on multiple GPUs using DistributedDataParallel + dataparallel. -On a single node, uses all GPUs for 1 model. Then shares gradient information -across nodes. -```bash -python gpu_template.py --gpus 2 --distributed_backend ddp2 -``` -## Multi-node example +## Basic examples +In this folder we add 3 simple examples: -This demo launches a job using 2 GPUs on 2 different nodes (4 GPUs total). -To run this demo do the following: +* [MNIST Classifier](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py) (defines the model inside the `LightningModule`). +* [Image Classifier](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/backbone_image_classifier.py) (trains arbitrary datasets with arbitrary backbones). +* [Autoencoder](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/autoencoder.py) (shows how the `LightningModule` can be used as a system) -1. Log into the jumphost node of your SLURM-managed cluster. -2. Create a conda environment with Lightning and a GPU PyTorch version. -3. Choose a script to submit - -### DDP -Submit this job to run with DistributedDataParallel (2 nodes, 2 gpus each) -```bash -sbatch ddp_job_submit.sh YourEnv -``` - -### DDP2 -Submit this job to run with a different implementation of DistributedDataParallel. -In this version, each node acts like DataParallel but syncs across nodes like DDP. -```bash -sbatch ddp2_job_submit.sh YourEnv -``` +--- -## Domain templates -These are templates to show common approaches such as GANs and RL. +## Domain examples +This folder contains older examples. You should instead use the examples +in [PyTorch-Lightning-Bolts](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2) +for advanced use cases. diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index 1c5908539cfdc6..150ac309ddcebd 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -1,147 +1,62 @@ +import os +from urllib.error import HTTPError + +from six.moves import urllib + +from pytorch_lightning.utilities import _module_available + +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) + +_EXAMPLES_ROOT = os.path.dirname(__file__) +_PACKAGE_ROOT = os.path.dirname(_EXAMPLES_ROOT) +_DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') + +_TORCHVISION_AVAILABLE = _module_available("torchvision") +_TORCHVISION_MNIST_AVAILABLE = not bool(os.environ.get("PL_USE_MOCKED_MNIST", False)) +_DALI_AVAILABLE = _module_available("nvidia.dali") + +if _TORCHVISION_MNIST_AVAILABLE: + try: + from torchvision.datasets.mnist import MNIST + MNIST(_DATASETS_PATH, download=True) + except HTTPError: + _TORCHVISION_MNIST_AVAILABLE = False + +LIGHTNING_LOGO = """ + #### + ########### + #################### + ############################ + ##################################### +############################################## +######################### ################### +####################### ################### +#################### #################### +################## ##################### +################ ###################### +##################### ################# +###################### ################### +##################### ##################### +#################### ####################### +################### ######################### +############################################## + ##################################### + ############################ + #################### + ########## + #### """ -Template model definition -------------------------- -In 99% of cases you want to just copy `one of the examples -`_ -to start a new lightningModule and change the core of what your model is actually trying to do. -.. code-block:: bash +def nice_print(msg, last=False): + print() + print("\033[0;35m" + msg + "\033[0m") + if last: + print() - # get a copy of the module template - wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/pl_examples/new_project_templates/lightning_module_template.py # noqa: E501 - -Trainer Example ---------------- - -**`__main__` function** - -Normally, we want to let the `__main__` function start the training. - Inside the main we parse training arguments with whatever hyperparameters we want. - Your LightningModule will have a chance to add hyperparameters. - -.. code-block:: python - - from test_tube import HyperOptArgumentParser - - if __name__ == '__main__': - - # use default args given by lightning - root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0] - parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False) - add_default_args(parent_parser, root_dir) - - # allow model to overwrite or extend args - parser = ExampleModel.add_model_specific_args(parent_parser) - hyperparams = parser.parse_args() - - # train model - main(hyperparams) - -**Main Function** - -The main function is your entry into the program. This is where you init your model, checkpoint directory, - and launch the training. The main function should have 3 arguments: - -- hparams: a configuration of hyperparameters. -- slurm_manager: Slurm cluster manager object (can be None) -- dict: for you to return any values you want (useful in meta-learning, otherwise set to) - -.. code-block:: python - - def main(hparams, cluster, results_dict): - # build model - model = MyLightningModule(hparams) - - # configure trainer - trainer = Trainer() - - # train model - trainer.fit(model) - - -The `__main__` function will start training on your **main** function. - If you use the HyperParameterOptimizer in hyper parameter optimization mode, - this main function will get one set of hyperparameters. If you use it as a simple - argument parser you get the default arguments in the argument parser. - -So, calling main(hyperparams) runs the model with the default argparse arguments.:: - - main(hyperparams) - - -CPU hyperparameter search -------------------------- - -.. code-block:: python - - # run a grid search over 20 hyperparameter combinations. - hyperparams.optimize_parallel_cpu( - main_local, - nb_trials=20, - nb_workers=1 - ) - - -Hyperparameter search on a single or multiple GPUs --------------------------------------------------- - -.. code-block:: python - - # run a grid search over 20 hyperparameter combinations. - hyperparams.optimize_parallel_gpu( - main_local, - nb_trials=20, - nb_workers=1, - gpus=[0,1,2,3] - ) - - -Hyperparameter search on a SLURM HPC cluster --------------------------------------------- - -.. code-block:: python - - def optimize_on_cluster(hyperparams): - # enable cluster training - cluster = SlurmCluster( - hyperparam_optimizer=hyperparams, - log_path=hyperparams.tt_save_path, - test_tube_exp_name=hyperparams.tt_name - ) - - # email for cluster coms - cluster.notify_job_status(email='add_email_here', on_done=True, on_fail=True) - - # configure cluster - cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus - cluster.job_time = '48:00:00' - cluster.gpu_type = '1080ti' - cluster.memory_mb_per_node = 48000 - - # any modules for code to run in env - cluster.add_command('source activate pytorch_lightning') - - # name of exp - job_display_name = hyperparams.tt_name.split('_')[0] - job_display_name = job_display_name[0:3] - - # run hopt - logging.info('submitting jobs...') - cluster.optimize_parallel_cluster_gpu( - main, - nb_trials=hyperparams.nb_hopt_trials, - job_name=job_display_name - ) - - # run cluster hyperparameter search - optimize_on_cluster(hyperparams) - -""" - -from pl_examples.models.lightning_template import LightningTemplateModel - -__all__ = [ - 'LightningTemplateModel' -] +def cli_lightning_logo(): + nice_print(LIGHTNING_LOGO) diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md index 63fdc7f8c47c71..b02ea21c7940dd 100644 --- a/pl_examples/basic_examples/README.md +++ b/pl_examples/basic_examples/README.md @@ -1,62 +1,73 @@ -## Basic Examples -Use these examples to test how lightning works. +## Basic Examples +Use these examples to test how lightning works. -#### Test on CPU +#### MNIST +Trains MNIST where the model is defined inside the `LightningModule`. ```bash -python cpu_template.py -``` +# cpu +python simple_image_classifier.py ---- -#### Train on a single GPU -```bash -python gpu_template.py --gpus 1 -``` +# gpus (any number) +python simple_image_classifier.py --gpus 2 ---- -#### DataParallel (dp) -Train on multiple GPUs using DataParallel. +# dataparallel +python simple_image_classifier.py --gpus 2 --distributed_backend 'dp' +``` +--- +#### MNIST with DALI +The MNIST example above using [NVIDIA DALI](https://developer.nvidia.com/DALI). +Requires NVIDIA DALI to be installed based on your CUDA version, see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html). ```bash -python gpu_template.py --gpus 2 --distributed_backend dp -``` +python dali_image_classifier.py +``` --- -#### DistributedDataParallel (ddp) - -Train on multiple GPUs using DistributedDataParallel +#### Image classifier +Generic image classifier with an arbitrary backbone (ie: a simple system) ```bash -python gpu_template.py --gpus 2 --distributed_backend ddp +# cpu +python backbone_image_classifier.py + +# gpus (any number) +python backbone_image_classifier.py --gpus 2 + +# dataparallel +python backbone_image_classifier.py --gpus 2 --distributed_backend 'dp' ``` --- -#### DistributedDataParallel+DP (ddp2) - -Train on multiple GPUs using DistributedDataParallel + DataParallel. -On a single node, uses all GPUs for 1 model. Then shares gradient information -across nodes. +#### Autoencoder +Showing the power of a system... arbitrarily complex training loops ```bash -python gpu_template.py --gpus 2 --distributed_backend ddp2 -``` +# cpu +python autoencoder.py +# gpus (any number) +python autoencoder.py --gpus 2 -# Multi-node example +# dataparallel +python autoencoder.py --gpus 2 --distributed_backend 'dp' +``` +--- +# Multi-node example This demo launches a job using 2 GPUs on 2 different nodes (4 GPUs total). To run this demo do the following: -1. Log into the jumphost node of your SLURM-managed cluster. -2. Create a conda environment with Lightning and a GPU PyTorch version. -3. Choose a script to submit +1. Log into the jumphost node of your SLURM-managed cluster. +2. Create a conda environment with Lightning and a GPU PyTorch version. +3. Choose a script to submit -#### DDP +#### DDP Submit this job to run with DistributedDataParallel (2 nodes, 2 gpus each) ```bash -sbatch ddp_job_submit.sh YourEnv +sbatch submit_ddp_job.sh YourEnv ``` -#### DDP2 +#### DDP2 Submit this job to run with a different implementation of DistributedDataParallel. In this version, each node acts like DataParallel but syncs across nodes like DDP. ```bash -sbatch ddp2_job_submit.sh YourEnv +sbatch submit_ddp2_job.sh YourEnv ``` diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py new file mode 100644 index 00000000000000..6841b8555ef1fc --- /dev/null +++ b/pl_examples/basic_examples/autoencoder.py @@ -0,0 +1,132 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader, random_split + +import pytorch_lightning as pl +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + + +class LitAutoEncoder(pl.LightningModule): + """ + >>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitAutoEncoder( + (encoder): ... + (decoder): ... + ) + """ + + def __init__(self, hidden_dim: int = 64): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(28 * 28, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 3), + ) + self.decoder = nn.Sequential( + nn.Linear(3, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 28 * 28), + ) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + embedding = self.encoder(x) + return embedding + + def training_step(self, batch, batch_idx): + x, y = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + self.log('valid_loss', loss, on_step=True) + + def test_step(self, batch, batch_idx): + x, y = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = F.mse_loss(x_hat, x) + self.log('test_loss', loss, on_step=True) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + + +def cli_main(): + pl.seed_everything(1234) + + # ------------ + # args + # ------------ + parser = ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser.add_argument('--hidden_dim', type=int, default=64) + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # ------------ + # data + # ------------ + dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) + mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) + mnist_train, mnist_val = random_split(dataset, [55000, 5000]) + + train_loader = DataLoader(mnist_train, batch_size=args.batch_size) + val_loader = DataLoader(mnist_val, batch_size=args.batch_size) + test_loader = DataLoader(mnist_test, batch_size=args.batch_size) + + # ------------ + # model + # ------------ + model = LitAutoEncoder(args.hidden_dim) + + # ------------ + # training + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model, train_loader, val_loader) + + # ------------ + # testing + # ------------ + result = trainer.test(test_dataloaders=test_loader) + print(result) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py new file mode 100644 index 00000000000000..1c78d264a86816 --- /dev/null +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -0,0 +1,145 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader, random_split + +import pytorch_lightning as pl +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + + +class Backbone(torch.nn.Module): + """ + >>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Backbone( + (l1): Linear(...) + (l2): Linear(...) + ) + """ + + def __init__(self, hidden_dim=128): + super().__init__() + self.l1 = torch.nn.Linear(28 * 28, hidden_dim) + self.l2 = torch.nn.Linear(hidden_dim, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + +class LitClassifier(pl.LightningModule): + """ + >>> LitClassifier(Backbone()) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitClassifier( + (backbone): ... + ) + """ + + def __init__(self, backbone, learning_rate=1e-3): + super().__init__() + self.save_hyperparameters() + self.backbone = backbone + + def forward(self, x): + # use forward for inference/predictions + embedding = self.backbone(x) + return embedding + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.backbone(x) + loss = F.cross_entropy(y_hat, y) + self.log('train_loss', loss, on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.backbone(x) + loss = F.cross_entropy(y_hat, y) + self.log('valid_loss', loss, on_step=True) + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self.backbone(x) + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) + + def configure_optimizers(self): + # self.hparams available because we called self.save_hyperparameters() + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("LitClassifier") + parser.add_argument('--learning_rate', type=float, default=0.0001) + return parent_parser + + +def cli_main(): + pl.seed_everything(1234) + + # ------------ + # args + # ------------ + parser = ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser.add_argument('--hidden_dim', type=int, default=128) + parser = pl.Trainer.add_argparse_args(parser) + parser = LitClassifier.add_model_specific_args(parser) + args = parser.parse_args() + + # ------------ + # data + # ------------ + dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) + mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) + mnist_train, mnist_val = random_split(dataset, [55000, 5000]) + + train_loader = DataLoader(mnist_train, batch_size=args.batch_size) + val_loader = DataLoader(mnist_val, batch_size=args.batch_size) + test_loader = DataLoader(mnist_test, batch_size=args.batch_size) + + # ------------ + # model + # ------------ + model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate) + + # ------------ + # training + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model, train_loader, val_loader) + + # ------------ + # testing + # ------------ + result = trainer.test(test_dataloaders=test_loader) + print(result) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py new file mode 100644 index 00000000000000..f3d9469144f501 --- /dev/null +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -0,0 +1,226 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +Example script of running the experimental DDP Sequential Plugin. +This script splits a convolutional model onto multiple GPUs, whilst using the internal built in balancer +to balance across your GPUs. + +To run: +python conv_model_sequential_example.py --accelerator ddp --gpus 4 --max_epochs 1 --batch_size 256 --use_rpc_sequential +""" +import math +from argparse import ArgumentParser + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from torchmetrics.functional import accuracy + +import pytorch_lightning as pl +from pl_examples import cli_lightning_logo +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import RPCSequentialPlugin +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE + +if _BOLTS_AVAILABLE: + import pl_bolts + from pl_bolts.transforms.dataset_normalizations import cifar10_normalization + +##################### +# Modules # +##################### + + +class Flatten(nn.Module): + + def forward(self, x): + return x.view(x.size(0), -1) + + +############################### +# LightningModule # +############################### + + +class LitResnet(pl.LightningModule): + """ + >>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitResnet( + (sequential_module): Sequential(...) + ) + """ + + def __init__(self, lr=0.05, batch_size=32, manual_optimization=False): + super().__init__() + + self.save_hyperparameters() + self.sequential_module = nn.Sequential( + # Conv Layer block 1 + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), + nn.ReLU(inplace=False), + nn.MaxPool2d(kernel_size=2, stride=2), + + # Conv Layer block 2 + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), + nn.ReLU(inplace=False), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Dropout2d(p=0.05), + + # Conv Layer block 3 + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), + nn.ReLU(inplace=False), + nn.MaxPool2d(kernel_size=2, stride=2), + Flatten(), + nn.Dropout(p=0.1), + nn.Linear(4096, 1024), + nn.ReLU(inplace=False), + nn.Linear(1024, 512), + nn.ReLU(inplace=False), + nn.Dropout(p=0.1), + nn.Linear(512, 10) + ) + self._example_input_array = torch.randn((1, 3, 32, 32)) + + if manual_optimization: + self.automatic_optimization = False + self.training_step = self.training_step_manual + + def forward(self, x): + out = self.sequential_module(x) + return F.log_softmax(out, dim=-1) + + def training_step_manual(self, batch, batch_idx): + opt = self.optimizers() + + def closure(): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y) + self.manual_backward(loss, opt) + self.log('train_loss', loss, prog_bar=True) + + opt.step(closure=closure) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y) + self.log('Training Loss', loss) + return loss + + def _evaluate(self, batch, batch_idx, stage=None): + x, y = batch + out = self.forward(x) + logits = F.log_softmax(out, dim=-1) + loss = F.nll_loss(logits, y) + preds = torch.argmax(logits, dim=-1) + acc = accuracy(preds, y) + + if stage: + self.log(f'{stage}_loss', loss, prog_bar=True) + self.log(f'{stage}_acc', acc, prog_bar=True) + + return loss, acc + + def validation_step(self, batch, batch_idx): + return self._evaluate(batch, batch_idx, 'val')[0] + + def test_step(self, batch, batch_idx): + loss, acc = self._evaluate(batch, batch_idx, 'test') + self.log_dict({'test_loss': loss, 'test_acc': acc}) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.OneCycleLR( + optimizer, + 0.1, + epochs=self.trainer.max_epochs, + steps_per_epoch=math.ceil(45000 / self.hparams.batch_size) + ), + 'interval': 'step', + } + } + + +################################# +# Instantiate Data Module # +################################# + + +def instantiate_datamodule(args): + train_transforms = torchvision.transforms.Compose([ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + cifar10_normalization(), + ]) + + test_transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + cifar10_normalization(), + ]) + + cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule( + data_dir=args.data_dir, + batch_size=args.batch_size, + train_transforms=train_transforms, + test_transforms=test_transforms, + val_transforms=test_transforms, + ) + + return cifar10_dm + + +if __name__ == "__main__": + cli_lightning_logo() + + assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts" + assert _FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example." + + parser = ArgumentParser(description="Pipe Example") + parser.add_argument("--use_rpc_sequential", action="store_true") + parser.add_argument("--manual_optimization", action="store_true") + parser = Trainer.add_argparse_args(parser) + parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser) + args = parser.parse_args() + + cifar10_dm = instantiate_datamodule(args) + + plugins = None + if args.use_rpc_sequential: + plugins = RPCSequentialPlugin() + + model = LitResnet(batch_size=args.batch_size, manual_optimization=args.manual_optimization) + + trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None) + trainer.fit(model, cifar10_dm) + trainer.test(model, datamodule=cifar10_dm) + + if trainer.accelerator.rpc_enabled: + # Called at the end of trainer to ensure all processes are killed + trainer.training_type_plugin.exit_rpc_process() diff --git a/pl_examples/basic_examples/cpu_template.py b/pl_examples/basic_examples/cpu_template.py deleted file mode 100644 index 5929a07be57272..00000000000000 --- a/pl_examples/basic_examples/cpu_template.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Runs a model on the CPU on a single node. -""" -import os -from argparse import ArgumentParser - -import numpy as np -import torch - -import pytorch_lightning as pl -from pl_examples.models.lightning_template import LightningTemplateModel - -SEED = 2334 -torch.manual_seed(SEED) -np.random.seed(SEED) - - -def main(hparams): - """ - Main training routine specific for this project - :param hparams: - """ - # ------------------------ - # 1 INIT LIGHTNING MODEL - # ------------------------ - model = LightningTemplateModel(hparams) - - # ------------------------ - # 2 INIT TRAINER - # ------------------------ - trainer = pl.Trainer(max_epochs=hparams.epochs, overfit_pct=0.01, early_stop_callback=True) - - # ------------------------ - # 3 START TRAINING - # ------------------------ - trainer.fit(model) - - -if __name__ == '__main__': - # ------------------------ - # TRAINING ARGUMENTS - # ------------------------ - # these are project-wide arguments - root_dir = os.path.dirname(os.path.realpath(__file__)) - parent_parser = ArgumentParser(add_help=False) - - # each LightningModule defines arguments relevant to it - parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir) - hyperparams = parser.parse_args() - - # --------------------- - # RUN TRAINING - # --------------------- - main(hyperparams) diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py new file mode 100644 index 00000000000000..08bf64da252bf7 --- /dev/null +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -0,0 +1,238 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC +from argparse import ArgumentParser +from distutils.version import LooseVersion +from random import shuffle +from warnings import warn + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data import random_split + +import pytorch_lightning as pl +from pl_examples import ( + _DALI_AVAILABLE, + _DATASETS_PATH, + _TORCHVISION_AVAILABLE, + _TORCHVISION_MNIST_AVAILABLE, + cli_lightning_logo, +) + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + +if _DALI_AVAILABLE: + from nvidia.dali import __version__ as dali_version + from nvidia.dali import ops + from nvidia.dali.pipeline import Pipeline + from nvidia.dali.plugin.pytorch import DALIClassificationIterator + + NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0') + if NEW_DALI_API: + from nvidia.dali.plugin.base_iterator import LastBatchPolicy +else: + warn('NVIDIA DALI is not available') + ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC + + +class ExternalMNISTInputIterator(object): + """ + This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches + """ + + def __init__(self, mnist_ds, batch_size): + self.batch_size = batch_size + self.mnist_ds = mnist_ds + self.indices = list(range(len(self.mnist_ds))) + shuffle(self.indices) + + def __iter__(self): + self.i = 0 + self.n = len(self.mnist_ds) + return self + + def __next__(self): + batch = [] + labels = [] + for _ in range(self.batch_size): + index = self.indices[self.i] + img, label = self.mnist_ds[index] + batch.append(img.numpy()) + labels.append(np.array([label], dtype=np.uint8)) + self.i = (self.i + 1) % self.n + return (batch, labels) + + +class ExternalSourcePipeline(Pipeline): + """ + This DALI pipeline class just contains the MNIST iterator + """ + + def __init__(self, batch_size, eii, num_threads, device_id): + super(ExternalSourcePipeline, self).__init__(batch_size, num_threads, device_id, seed=12) + self.source = ops.ExternalSource(source=eii, num_outputs=2) + self.build() + + def define_graph(self): + images, labels = self.source() + return images, labels + + +class DALIClassificationLoader(DALIClassificationIterator): + """ + This class extends DALI's original `DALIClassificationIterator` with the `__len__()` function + so that we can call `len()` on it + """ + + def __init__( + self, + pipelines, + size=-1, + reader_name=None, + auto_reset=False, + fill_last_batch=True, + dynamic_shape=False, + last_batch_padded=False, + ): + if NEW_DALI_API: + last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP + super().__init__( + pipelines, + size, + reader_name, + auto_reset, + dynamic_shape, + last_batch_policy=last_batch_policy, + last_batch_padded=last_batch_padded + ) + else: + super().__init__( + pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded + ) + self._fill_last_batch = fill_last_batch + + def __len__(self): + batch_count = self._size // (self._num_gpus * self.batch_size) + last_batch = 1 if self._fill_last_batch else 1 + return batch_count + last_batch + + +class LitClassifier(pl.LightningModule): + + def __init__(self, hidden_dim=128, learning_rate=1e-3): + super().__init__() + self.save_hyperparameters() + + self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) + self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + def split_batch(self, batch): + return batch[0]["data"], batch[0]["label"].squeeze().long() + + def training_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('valid_loss', loss) + + def test_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("LitClassifier") + parser.add_argument('--hidden_dim', type=int, default=128) + parser.add_argument('--learning_rate', type=float, default=0.0001) + return parent_parser + + +def cli_main(): + if not _DALI_AVAILABLE: + return + + pl.seed_everything(1234) + + # ------------ + # args + # ------------ + parser = ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser = pl.Trainer.add_argparse_args(parser) + parser = LitClassifier.add_model_specific_args(parser) + args = parser.parse_args() + + # ------------ + # data + # ------------ + dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor()) + mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor()) + mnist_train, mnist_val = random_split(dataset, [55000, 5000]) + + eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size) + eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size) + eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size) + + pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0) + train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True) + + pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0) + val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False) + + pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0) + test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False) + + # ------------ + # model + # ------------ + model = LitClassifier(args.hidden_dim, args.learning_rate) + + # ------------ + # training + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model, train_loader, val_loader) + + # ------------ + # testing + # ------------ + trainer.test(test_dataloaders=test_loader) + + +if __name__ == "__main__": + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/gpu_template.py b/pl_examples/basic_examples/gpu_template.py deleted file mode 100644 index c5fa94a3cf1408..00000000000000 --- a/pl_examples/basic_examples/gpu_template.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Runs a model on a single node across multiple gpus. -""" -import os -from argparse import ArgumentParser - -import numpy as np -import torch - -import pytorch_lightning as pl -from pl_examples.models.lightning_template import LightningTemplateModel - -SEED = 2334 -torch.manual_seed(SEED) -np.random.seed(SEED) - - -def main(hparams): - """ - Main training routine specific for this project - :param hparams: - """ - # ------------------------ - # 1 INIT LIGHTNING MODEL - # ------------------------ - model = LightningTemplateModel(hparams) - - # ------------------------ - # 2 INIT TRAINER - # ------------------------ - trainer = pl.Trainer( - max_epochs=hparams.epochs, - gpus=hparams.gpus, - distributed_backend=hparams.distributed_backend, - precision=16 if hparams.use_16bit else 32, - ) - - # ------------------------ - # 3 START TRAINING - # ------------------------ - trainer.fit(model) - - -if __name__ == '__main__': - # ------------------------ - # TRAINING ARGUMENTS - # ------------------------ - # these are project-wide arguments - - root_dir = os.path.dirname(os.path.realpath(__file__)) - parent_parser = ArgumentParser(add_help=False) - - # gpu args - parent_parser.add_argument( - '--gpus', - type=int, - default=2, - help='how many gpus' - ) - parent_parser.add_argument( - '--distributed_backend', - type=str, - default='dp', - help='supports three options dp, ddp, ddp2' - ) - parent_parser.add_argument( - '--use_16bit', - dest='use_16bit', - action='store_true', - help='if true uses 16 bit precision' - ) - - # each LightningModule defines arguments relevant to it - parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir) - hyperparams = parser.parse_args() - - # --------------------- - # RUN TRAINING - # --------------------- - main(hyperparams) diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py new file mode 100644 index 00000000000000..ea64f96c05d7d6 --- /dev/null +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -0,0 +1,146 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import platform +from typing import Optional +from warnings import warn + +from torch.utils.data import DataLoader, random_split + +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE +from pytorch_lightning import LightningDataModule + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + + +class MNISTDataModule(LightningDataModule): + """ + Standard MNIST, train, val, test splits and transforms + + >>> MNISTDataModule() # doctest: +ELLIPSIS + <...mnist_datamodule.MNISTDataModule object at ...> + """ + + name = "mnist" + + def __init__( + self, + data_dir: str = _DATASETS_PATH, + val_split: int = 5000, + num_workers: int = 16, + normalize: bool = False, + seed: int = 42, + batch_size: int = 32, + *args, + **kwargs, + ): + """ + Args: + data_dir: where to save/load the data + val_split: how many of the training images to use for the validation split + num_workers: how many workers to use for loading data + normalize: If true applies image normalize + seed: starting seed for RNG. + batch_size: desired batch size. + """ + super().__init__(*args, **kwargs) + if num_workers and platform.system() == "Windows": + # see: https://stackoverflow.com/a/59680818 + warn( + f"You have requested num_workers={num_workers} on Windows," + " but currently recommended is 0, so we set it for you" + ) + num_workers = 0 + + self.dims = (1, 28, 28) + self.data_dir = data_dir + self.val_split = val_split + self.num_workers = num_workers + self.normalize = normalize + self.seed = seed + self.batch_size = batch_size + self.dataset_train = ... + self.dataset_val = ... + self.test_transforms = self.default_transforms + + @property + def num_classes(self): + return 10 + + def prepare_data(self): + """Saves MNIST files to `data_dir`""" + MNIST(self.data_dir, train=True, download=True) + MNIST(self.data_dir, train=False, download=True) + + def setup(self, stage: Optional[str] = None): + """Split the train and valid dataset""" + extra = dict(transform=self.default_transforms) if self.default_transforms else {} + dataset = MNIST(self.data_dir, train=True, download=False, **extra) + train_length = len(dataset) + self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split]) + + def train_dataloader(self): + """MNIST train set removes a subset to use for validation""" + loader = DataLoader( + self.dataset_train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + drop_last=True, + pin_memory=True, + ) + return loader + + def val_dataloader(self): + """MNIST val set uses a subset of the training set for validation""" + loader = DataLoader( + self.dataset_val, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + drop_last=True, + pin_memory=True, + ) + return loader + + def test_dataloader(self): + """MNIST test set uses the test split""" + extra = dict(transform=self.test_transforms) if self.test_transforms else {} + dataset = MNIST(self.data_dir, train=False, download=False, **extra) + loader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + drop_last=True, + pin_memory=True, + ) + return loader + + @property + def default_transforms(self): + if not _TORCHVISION_AVAILABLE: + return None + if self.normalize: + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) + ]) + else: + mnist_transforms = transform_lib.ToTensor() + + return mnist_transforms diff --git a/pl_examples/basic_examples/multi_node_ddp2_demo.py b/pl_examples/basic_examples/multi_node_ddp2_demo.py deleted file mode 100644 index ca9c986a17f259..00000000000000 --- a/pl_examples/basic_examples/multi_node_ddp2_demo.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Multi-node example (GPU) -""" -import os -from argparse import ArgumentParser - -import numpy as np -import torch - -import pytorch_lightning as pl -from pl_examples.models.lightning_template import LightningTemplateModel - -SEED = 2334 -torch.manual_seed(SEED) -np.random.seed(SEED) - - -def main(hparams): - """Main training routine specific for this project.""" - # ------------------------ - # 1 INIT LIGHTNING MODEL - # ------------------------ - model = LightningTemplateModel(hparams) - - # ------------------------ - # 2 INIT TRAINER - # ------------------------ - trainer = pl.Trainer( - gpus=2, - num_nodes=2, - distributed_backend='ddp2' - ) - - # ------------------------ - # 3 START TRAINING - # ------------------------ - trainer.fit(model) - - -if __name__ == '__main__': - root_dir = os.path.dirname(os.path.realpath(__file__)) - parent_parser = ArgumentParser(add_help=False) - - # each LightningModule defines arguments relevant to it - parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir) - hyperparams = parser.parse_args() - - # --------------------- - # RUN TRAINING - # --------------------- - main(hyperparams) diff --git a/pl_examples/basic_examples/multi_node_ddp_demo.py b/pl_examples/basic_examples/multi_node_ddp_demo.py deleted file mode 100644 index 518a9f39cc938a..00000000000000 --- a/pl_examples/basic_examples/multi_node_ddp_demo.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Multi-node example (GPU) -""" -import os -from argparse import ArgumentParser - -import numpy as np -import torch - -import pytorch_lightning as pl -from pl_examples.models.lightning_template import LightningTemplateModel - -SEED = 2334 -torch.manual_seed(SEED) -np.random.seed(SEED) - - -def main(hparams): - """Main training routine specific for this project.""" - # ------------------------ - # 1 INIT LIGHTNING MODEL - # ------------------------ - model = LightningTemplateModel(hparams) - - # ------------------------ - # 2 INIT TRAINER - # ------------------------ - trainer = pl.Trainer( - gpus=2, - num_nodes=2, - distributed_backend='ddp' - ) - - # ------------------------ - # 3 START TRAINING - # ------------------------ - trainer.fit(model) - - -if __name__ == '__main__': - root_dir = os.path.dirname(os.path.realpath(__file__)) - parent_parser = ArgumentParser(add_help=False) - - # each LightningModule defines arguments relevant to it - parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir) - hyperparams = parser.parse_args() - - # --------------------- - # RUN TRAINING - # --------------------- - main(hyperparams) diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py new file mode 100644 index 00000000000000..ca640a96f9588b --- /dev/null +++ b/pl_examples/basic_examples/profiler_example.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will generate 2 traces: one for `training_step` and one for `validation_step`. +The traces can be visualized in 2 ways: +* With Chrome: + 1. Open Chrome and copy/paste this url: `chrome://tracing/`. + 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. +* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) + 1. pip install tensorboard torch-tb-profiler + 2. tensorboard --logdir={FOLDER} +""" + +import sys +from argparse import ArgumentParser + +import torch +import torchvision +import torchvision.models as models +import torchvision.transforms as T + +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningDataModule, LightningModule, Trainer + +DEFAULT_CMD_LINE = ( + "--max_epochs", + "1", + "--limit_train_batches", + "15", + "--limit_val_batches", + "15", + "--profiler", + "pytorch", + "--gpus", + f"{int(torch.cuda.is_available())}", +) + + +class ModelToProfile(LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + self.criterion = torch.nn.CrossEntropyLoss() + + def training_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("val_loss", loss) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + + +class CIFAR10DataModule(LightningDataModule): + + transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]) + + def train_dataloader(self, *args, **kwargs): + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform) + return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) + + def val_dataloader(self, *args, **kwargs): + valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform) + return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0) + + +def cli_main(): + + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE + args = parser.parse_args(args=cmd_line) + + model = ModelToProfile(models.resnet50(pretrained=True)) + datamodule = CIFAR10DataModule() + trainer = Trainer(**vars(args)) + trainer.fit(model, datamodule=datamodule) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py new file mode 100644 index 00000000000000..3f7079d665ea86 --- /dev/null +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -0,0 +1,114 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from pprint import pprint + +import torch +from torch.nn import functional as F + +import pytorch_lightning as pl +from pl_examples import cli_lightning_logo +from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule + + +class LitClassifier(pl.LightningModule): + """ + >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitClassifier( + (l1): Linear(...) + (l2): Linear(...) + ) + """ + + def __init__(self, hidden_dim=128, learning_rate=1e-3): + super().__init__() + self.save_hyperparameters() + + self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) + self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('valid_loss', loss) + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("LitClassifier") + parser.add_argument('--hidden_dim', type=int, default=128) + parser.add_argument('--learning_rate', type=float, default=0.0001) + return parent_parser + + +def cli_main(): + pl.seed_everything(1234) + + # ------------ + # args + # ------------ + parser = ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + parser = LitClassifier.add_model_specific_args(parser) + parser = MNISTDataModule.add_argparse_args(parser) + args = parser.parse_args() + + # ------------ + # data + # ------------ + dm = MNISTDataModule.from_argparse_args(args) + + # ------------ + # model + # ------------ + model = LitClassifier(args.hidden_dim, args.learning_rate) + + # ------------ + # training + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model, datamodule=dm) + + # ------------ + # testing + # ------------ + result = trainer.test(model, datamodule=dm) + pprint(result) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/submit_ddp2_job.sh b/pl_examples/basic_examples/submit_ddp2_job.sh index 6e433f5fcd5752..026589a604c362 100755 --- a/pl_examples/basic_examples/submit_ddp2_job.sh +++ b/pl_examples/basic_examples/submit_ddp2_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above -srun python3 multi_node_ddp2_demo.py +srun python3 simple_image_classifier.py --accelerator 'ddp2' --gpus 2 --num_nodes 2 --max_epochs 5 diff --git a/pl_examples/basic_examples/submit_ddp_job.sh b/pl_examples/basic_examples/submit_ddp_job.sh index bf53a653686592..b4f5ff0a64d92f 100755 --- a/pl_examples/basic_examples/submit_ddp_job.sh +++ b/pl_examples/basic_examples/submit_ddp_job.sh @@ -24,4 +24,4 @@ source activate $1 # ------------------------- # run script from above -srun python3 multi_node_ddp_demo.py +srun python3 simple_image_classifier.py --accelerator 'ddp' --gpus 2 --num_nodes 2 --max_epochs 5 diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py new file mode 100644 index 00000000000000..4d9a23f48ca5db --- /dev/null +++ b/pl_examples/bug_report_model.py @@ -0,0 +1,156 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -------------------------------------------- +# -------------------------------------------- +# -------------------------------------------- +# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT +# -------------------------------------------- +# -------------------------------------------- +# -------------------------------------------- +import os + +import torch +from torch.utils.data import Dataset + +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningModule, Trainer + + +class RandomDataset(Dataset): + """ + >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS + <...bug_report_model.RandomDataset object at ...> + """ + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + """ + >>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + BoringModel( + (layer): Linear(...) + ) + """ + + def __init__(self): + """ + Testing PL Module + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self.layer(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x['x'] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + +# NOTE: If you are using a cmd line to run your script, +# provide the cmd line as below. +# opt = "--max_epochs 1 --limit_train_batches 1".split(" ") +# parser = ArgumentParser() +# args = parser.parse_args(opt) + + +def test_run(): + + class TestModel(BoringModel): + + def on_train_epoch_start(self) -> None: + print('override any method to prove your bug') + + # fake data + train_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + val_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + test_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + + # model + model = TestModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model, train_data, val_data) + trainer.test(test_dataloaders=test_data) + + +if __name__ == '__main__': + cli_lightning_logo() + test_run() diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 42a0a936d9e340..4e148a18433a61 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -1,176 +1,216 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Computer vision example on Transfer Learning. - This computer vision example illustrates how one could fine-tune a pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see below) is -trained for 15 epochs. The training consists in three stages. From epoch 0 to -4, the feature extractor (the pre-trained network) is frozen except maybe for -the BatchNorm layers (depending on whether `train_bn = True`). The BatchNorm -layers (if `train_bn = True`) and the parameters of the classifier are trained -as a single parameters group with lr = 1e-2. From epoch 5 to 9, the last two -layer groups of the pre-trained network are unfrozen and added to the -optimizer as a new parameter group with lr = 1e-4 (while lr = 1e-3 for the -first parameter group in the optimizer). Eventually, from epoch 10, all the -remaining layer groups of the pre-trained network are unfrozen and added to -the optimizer as a third parameter group. From epoch 10, the parameters of the -pre-trained network are trained with lr = 1e-5 while those of the classifier -are trained with lr = 1e-4. +trained for 15 epochs. + +The training consists of three stages. + +From epoch 0 to 4, the feature extractor (the pre-trained network) is frozen except +maybe for the BatchNorm layers (depending on whether `train_bn = True`). The BatchNorm +layers (if `train_bn = True`) and the parameters of the classifier are trained as a +single parameters group with lr = 1e-2. + +From epoch 5 to 9, the last two layer groups of the pre-trained network are unfrozen +and added to the optimizer as a new parameter group with lr = 1e-4 (while lr = 1e-3 +for the first parameter group in the optimizer). + +Eventually, from epoch 10, all the remaining layer groups of the pre-trained network +are unfrozen and added to the optimizer as a third parameter group. From epoch 10, +the parameters of the pre-trained network are trained with lr = 1e-5 while those of +the classifier is trained with lr = 1e-4. Note: See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html """ - import argparse -from collections import OrderedDict +import logging +import os from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional, Generator, Union +from typing import Union -import pytorch_lightning as pl import torch import torch.nn.functional as F -from pytorch_lightning import _logger as log -from torch import optim +from torch import nn, optim from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -from torchvision import models -from torchvision import transforms +from torchmetrics import Accuracy +from torchvision import models, transforms from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive -BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) -DATA_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip' +import pytorch_lightning as pl +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningDataModule +from pytorch_lightning.callbacks.finetuning import BaseFinetuning +from pytorch_lightning.utilities import rank_zero_info +log = logging.getLogger(__name__) +DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" -# --- Utility functions --- +# --- Finetuning Callback --- -def _make_trainable(module: torch.nn.Module) -> None: - """Unfreezes a given module. +class MilestonesFinetuning(BaseFinetuning): - Args: - module: The module to unfreeze - """ - for param in module.parameters(): - param.requires_grad = True - module.train() + def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False): + self.milestones = milestones + self.train_bn = train_bn + def freeze_before_training(self, pl_module: pl.LightningModule): + self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn) -def _recursive_freeze(module: torch.nn.Module, - train_bn: bool = True) -> None: - """Freezes the layers of a given module. + def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + if epoch == self.milestones[0]: + # unfreeze 5 last layers + self.unfreeze_and_add_param_group( + modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn + ) - Args: - module: The module to freeze - train_bn: If True, leave the BatchNorm layers in training mode - """ - children = list(module.children()) - if not children: - if not (isinstance(module, BN_TYPES) and train_bn): - for param in module.parameters(): - param.requires_grad = False - module.eval() - else: - # Make the BN layers trainable - _make_trainable(module) - else: - for child in children: - _recursive_freeze(module=child, train_bn=train_bn) + elif epoch == self.milestones[1]: + # unfreeze remaing layers + self.unfreeze_and_add_param_group( + modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn + ) -def freeze(module: torch.nn.Module, - n: Optional[int] = None, - train_bn: bool = True) -> None: - """Freezes the layers up to index n (if n is not None). +class CatDogImageDataModule(LightningDataModule): - Args: - module: The module to freeze (at least partially) - n: Max depth at which we stop freezing the layers. If None, all - the layers of the given module will be frozen. - train_bn: If True, leave the BatchNorm layers in training mode - """ - children = list(module.children()) - n_max = len(children) if n is None else int(n) + def __init__( + self, + dl_path: Union[str, Path], + num_workers: int = 0, + batch_size: int = 8, + ): + super().__init__() + + self._dl_path = dl_path + self._num_workers = num_workers + self._batch_size = batch_size - for child in children[:n_max]: - _recursive_freeze(module=child, train_bn=train_bn) + def prepare_data(self): + """Download images and prepare images datasets.""" + download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True) - for child in children[n_max:]: - _make_trainable(module=child) + @property + def data_path(self): + return Path(self._dl_path).joinpath("cats_and_dogs_filtered") + @property + def normalize_transform(self): + return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -def filter_params(module: torch.nn.Module, - train_bn: bool = True) -> Generator: - """Yields the trainable parameters of a given module. + @property + def train_transform(self): + return transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), self.normalize_transform + ]) - Args: - module: A given module - train_bn: If True, leave the BatchNorm layers in training mode + @property + def valid_transform(self): + return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform]) - Returns: - Generator - """ - children = list(module.children()) - if not children: - if not (isinstance(module, BN_TYPES) and train_bn): - for param in module.parameters(): - if param.requires_grad: - yield param - else: - for child in children: - for param in filter_params(module=child, train_bn=train_bn): - yield param - - -def _unfreeze_and_add_param_group(module: torch.nn.Module, - optimizer: Optimizer, - lr: Optional[float] = None, - train_bn: bool = True): - """Unfreezes a module and adds its parameters to an optimizer.""" - _make_trainable(module) - params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr) - optimizer.add_param_group( - {'params': filter_params(module=module, train_bn=train_bn), - 'lr': params_lr / 10., - }) + def create_dataset(self, root, transform): + return ImageFolder(root=root, transform=transform) + + def __dataloader(self, train: bool): + """Train/validation loaders.""" + if train: + dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform) + else: + dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform) + return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train) + + def train_dataloader(self): + log.info("Training data loaded.") + return self.__dataloader(train=True) + + def val_dataloader(self): + log.info("Validation data loaded.") + return self.__dataloader(train=False) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("CatDogImageDataModule") + parser.add_argument( + "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers" + ) + parser.add_argument( + "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size" + ) + return parent_parser # --- Pytorch-lightning module --- class TransferLearningModel(pl.LightningModule): - """Transfer Learning with pre-trained ResNet50. - Args: - hparams: Model hyperparameters - dl_path: Path where the data will be downloaded - """ - def __init__(self, - hparams: argparse.Namespace, - dl_path: Union[str, Path]) -> None: + def __init__( + self, + backbone: str = "resnet50", + train_bn: bool = True, + milestones: tuple = (5, 10), + batch_size: int = 32, + lr: float = 1e-2, + lr_scheduler_gamma: float = 1e-1, + num_workers: int = 6, + **kwargs, + ) -> None: + """ + Args: + dl_path: Path where the data will be downloaded + """ super().__init__() - self.hparams = hparams - self.dl_path = dl_path + self.backbone = backbone + self.train_bn = train_bn + self.milestones = milestones + self.batch_size = batch_size + self.lr = lr + self.lr_scheduler_gamma = lr_scheduler_gamma + self.num_workers = num_workers + self.__build_model() + self.train_acc = Accuracy() + self.valid_acc = Accuracy() + self.save_hyperparameters() + def __build_model(self): """Define model layers & loss.""" # 1. Load pre-trained network: - model_func = getattr(models, self.hparams.backbone) + model_func = getattr(models, self.backbone) backbone = model_func(pretrained=True) _layers = list(backbone.children())[:-1] - self.feature_extractor = torch.nn.Sequential(*_layers) - freeze(module=self.feature_extractor, train_bn=self.hparams.train_bn) + self.feature_extractor = nn.Sequential(*_layers) # 2. Classifier: - _fc_layers = [torch.nn.Linear(2048, 256), - torch.nn.Linear(256, 32), - torch.nn.Linear(32, 1)] - self.fc = torch.nn.Sequential(*_fc_layers) + _fc_layers = [ + nn.Linear(2048, 256), + nn.ReLU(), + nn.Linear(256, 32), + nn.Linear(32, 1), + ] + self.fc = nn.Sequential(*_fc_layers) # 3. Loss: self.loss_func = F.binary_cross_entropy_with_logits @@ -185,256 +225,132 @@ def forward(self, x): # 2. Classifier (returns logits): x = self.fc(x) - return x + return torch.sigmoid(x) - def loss(self, labels, logits): + def loss(self, logits, labels): return self.loss_func(input=logits, target=labels) - def train(self, mode=True): - super().train(mode=mode) - - epoch = self.current_epoch - if epoch < self.hparams.milestones[0] and mode: - # feature extractor is frozen (except for BatchNorm layers) - freeze(module=self.feature_extractor, - train_bn=self.hparams.train_bn) - - elif self.hparams.milestones[0] <= epoch < self.hparams.milestones[1] and mode: - # Unfreeze last two layers of the feature extractor - freeze(module=self.feature_extractor, - n=-2, - train_bn=self.hparams.train_bn) - - def on_epoch_start(self): - """Use `on_epoch_start` to unfreeze layers progressively.""" - optimizer = self.trainer.optimizers[0] - if self.current_epoch == self.hparams.milestones[0]: - _unfreeze_and_add_param_group(module=self.feature_extractor[-2:], - optimizer=optimizer, - train_bn=self.hparams.train_bn) - - elif self.current_epoch == self.hparams.milestones[1]: - _unfreeze_and_add_param_group(module=self.feature_extractor[:-2], - optimizer=optimizer, - train_bn=self.hparams.train_bn) - def training_step(self, batch, batch_idx): - # 1. Forward pass: x, y = batch y_logits = self.forward(x) y_true = y.view((-1, 1)).type_as(x) - y_bin = torch.ge(y_logits, 0) - # 2. Compute loss & accuracy: - train_loss = self.loss(y_true, y_logits) - num_correct = torch.eq(y_bin.view(-1), y_true.view(-1)).sum() + # 2. Compute loss + train_loss = self.loss(y_logits, y_true) - # 3. Outputs: - tqdm_dict = {'train_loss': train_loss} - output = OrderedDict({'loss': train_loss, - 'num_correct': num_correct, - 'log': tqdm_dict, - 'progress_bar': tqdm_dict}) + # 3. Compute accuracy: + self.log("train_acc", self.train_acc(y_logits, y_true.int()), prog_bar=True) - return output - - def training_epoch_end(self, outputs): - """Compute and log training loss and accuracy at the epoch level.""" - - train_loss_mean = torch.stack([output['loss'] - for output in outputs]).mean() - train_acc_mean = torch.stack([output['num_correct'] - for output in outputs]).sum().float() - train_acc_mean /= (len(outputs) * self.hparams.batch_size) - return {'log': {'train_loss': train_loss_mean, - 'train_acc': train_acc_mean, - 'step': self.current_epoch}} + return train_loss def validation_step(self, batch, batch_idx): - # 1. Forward pass: x, y = batch y_logits = self.forward(x) y_true = y.view((-1, 1)).type_as(x) - y_bin = torch.ge(y_logits, 0) - - # 2. Compute loss & accuracy: - val_loss = self.loss(y_true, y_logits) - num_correct = torch.eq(y_bin.view(-1), y_true.view(-1)).sum() - return {'val_loss': val_loss, - 'num_correct': num_correct} + # 2. Compute loss + self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True) - def validation_epoch_end(self, outputs): - """Compute and log validation loss and accuracy at the epoch level.""" - - val_loss_mean = torch.stack([output['val_loss'] - for output in outputs]).mean() - val_acc_mean = torch.stack([output['num_correct'] - for output in outputs]).sum().float() - val_acc_mean /= (len(outputs) * self.hparams.batch_size) - return {'log': {'val_loss': val_loss_mean, - 'val_acc': val_acc_mean, - 'step': self.current_epoch}} + # 3. Compute accuracy: + self.log("val_acc", self.valid_acc(y_logits, y_true.int()), prog_bar=True) def configure_optimizers(self): - optimizer = optim.Adam(filter(lambda p: p.requires_grad, - self.parameters()), - lr=self.hparams.lr) - - scheduler = MultiStepLR(optimizer, - milestones=self.hparams.milestones, - gamma=self.hparams.lr_scheduler_gamma) - + parameters = list(self.parameters()) + trainable_parameters = list(filter(lambda p: p.requires_grad, parameters)) + rank_zero_info( + f"The model will start training with only {len(trainable_parameters)} " + f"trainable parameters out of {len(parameters)}." + ) + optimizer = optim.Adam(trainable_parameters, lr=self.lr) + scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma) return [optimizer], [scheduler] - def prepare_data(self): - """Download images and prepare images datasets.""" - - # 1. Download the images - download_and_extract_archive(url=DATA_URL, - download_root=self.dl_path, - remove_finished=True) - - data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered') - - # 2. Load the data + preprocessing & data augmentation - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - train_dataset = ImageFolder(root=data_path.joinpath('train'), - transform=transforms.Compose([ - transforms.Resize((224, 224)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - - valid_dataset = ImageFolder(root=data_path.joinpath('validation'), - transform=transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - normalize, - ])) - - self.train_dataset = train_dataset - self.valid_dataset = valid_dataset - - def __dataloader(self, train): - """Train/validation loaders.""" - - _dataset = self.train_dataset if train else self.valid_dataset - loader = DataLoader(dataset=_dataset, - batch_size=self.hparams.batch_size, - num_workers=self.hparams.num_workers, - shuffle=True if train else False) - - return loader - - def train_dataloader(self): - log.info('Training data loaded.') - return self.__dataloader(train=True) - - def val_dataloader(self): - log.info('Validation data loaded.') - return self.__dataloader(train=False) - @staticmethod def add_model_specific_args(parent_parser): - parser = argparse.ArgumentParser(parents=[parent_parser]) - parser.add_argument('--backbone', - default='resnet50', - type=str, - metavar='BK', - help='Name (as in ``torchvision.models``) of the feature extractor') - parser.add_argument('--epochs', - default=15, - type=int, - metavar='N', - help='total number of epochs', - dest='nb_epochs') - parser.add_argument('--batch-size', - default=8, - type=int, - metavar='B', - help='batch size', - dest='batch_size') - parser.add_argument('--gpus', - type=int, - default=1, - help='number of gpus to use') - parser.add_argument('--lr', - '--learning-rate', - default=1e-2, - type=float, - metavar='LR', - help='initial learning rate', - dest='lr') - parser.add_argument('--lr-scheduler-gamma', - default=1e-1, - type=float, - metavar='LRG', - help='Factor by which the learning rate is reduced at each milestone', - dest='lr_scheduler_gamma') - parser.add_argument('--num-workers', - default=6, - type=int, - metavar='W', - help='number of CPU workers', - dest='num_workers') - parser.add_argument('--train-bn', - default=True, - type=bool, - metavar='TB', - help='Whether the BatchNorm layers should be trainable', - dest='train_bn') - parser.add_argument('--milestones', - default=[5, 10], - type=list, - metavar='M', - help='List of two epochs milestones') - return parser - - -def main(hparams: argparse.Namespace) -> None: + parser = parent_parser.add_argument_group("TransferLearningModel") + parser.add_argument( + "--backbone", + default="resnet50", + type=str, + metavar="BK", + help="Name (as in ``torchvision.models``) of the feature extractor", + ) + parser.add_argument( + "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs" + ) + parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size") + parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use") + parser.add_argument( + "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr" + ) + parser.add_argument( + "--lr-scheduler-gamma", + default=1e-1, + type=float, + metavar="LRG", + help="Factor by which the learning rate is reduced at each milestone", + dest="lr_scheduler_gamma", + ) + parser.add_argument( + "--train-bn", + default=False, + type=bool, + metavar="TB", + help="Whether the BatchNorm layers should be trainable", + dest="train_bn", + ) + parser.add_argument( + "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones" + ) + return parent_parser + + +def main(args: argparse.Namespace) -> None: """Train the model. Args: - hparams: Model hyper-parameters + args: Model hyper-parameters Note: For the sake of the example, the images dataset will be downloaded to a temporary directory. """ - with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir: + datamodule = CatDogImageDataModule( + dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers + ) + model = TransferLearningModel(**vars(args)) + finetuning_callback = MilestonesFinetuning(milestones=args.milestones) - model = TransferLearningModel(hparams, dl_path=tmp_dir) + trainer = pl.Trainer( + weights_summary=None, + progress_bar_refresh_rate=1, + num_sanity_val_steps=0, + gpus=args.gpus, + max_epochs=args.nb_epochs, + callbacks=[finetuning_callback] + ) - trainer = pl.Trainer( - weights_summary=None, - show_progress_bar=True, - num_sanity_val_steps=0, - gpus=hparams.gpus, - min_epochs=hparams.nb_epochs, - max_epochs=hparams.nb_epochs) - - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) def get_args() -> argparse.Namespace: parent_parser = argparse.ArgumentParser(add_help=False) - parent_parser.add_argument('--root-data-path', - metavar='DIR', - type=str, - default=Path.cwd().as_posix(), - help='Root directory where to download the data', - dest='root_data_path') + parent_parser.add_argument( + "--root-data-path", + metavar="DIR", + type=str, + default=Path.cwd().as_posix(), + help="Root directory where to download the data", + dest="root_data_path", + ) parser = TransferLearningModel.add_model_specific_args(parent_parser) + parser = CatDogImageDataModule.add_argparse_args(parser) return parser.parse_args() -if __name__ == '__main__': - +if __name__ == "__main__": + cli_lightning_logo() main(get_args()) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index cbe21fe2dbab3d..29fcf97de86db8 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ To run this template just do: python generative_adversarial_net.py @@ -7,24 +20,36 @@ tensorboard --logdir default """ import os -from argparse import ArgumentParser -from collections import OrderedDict +from argparse import ArgumentParser, Namespace import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -import torchvision -import torchvision.transforms as transforms +import torch.nn.functional as F # noqa from torch.utils.data import DataLoader -from torchvision.datasets import MNIST -from pytorch_lightning.core import LightningModule +from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo +from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer +if _TORCHVISION_AVAILABLE: + import torchvision + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + class Generator(nn.Module): - def __init__(self, latent_dim, img_shape): + """ + >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Generator( + (model): Sequential(...) + ) + """ + + def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): super().__init__() self.img_shape = img_shape @@ -41,7 +66,7 @@ def block(in_feat, out_feat, normalize=True): *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), - nn.Tanh() + nn.Tanh(), ) def forward(self, z): @@ -51,6 +76,13 @@ def forward(self, z): class Discriminator(nn.Module): + """ + >>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Discriminator( + (model): Sequential(...) + ) + """ + def __init__(self, img_shape): super().__init__() @@ -60,7 +92,6 @@ def __init__(self, img_shape): nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), - nn.Sigmoid(), ) def forward(self, img): @@ -71,58 +102,78 @@ def forward(self, img): class GAN(LightningModule): - - def __init__(self, hparams): + """ + >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + GAN( + (generator): Generator( + (model): Sequential(...) + ) + (discriminator): Discriminator( + (model): Sequential(...) + ) + ) + """ + + def __init__( + self, + img_shape: tuple = (1, 28, 28), + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + latent_dim: int = 100, + ): super().__init__() - self.hparams = hparams - # networks - mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=hparams.latent_dim, img_shape=mnist_shape) - self.discriminator = Discriminator(img_shape=mnist_shape) + self.save_hyperparameters() - # cache for generated images - self.generated_imgs = None - self.last_imgs = None + # networks + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape) + self.discriminator = Discriminator(img_shape=img_shape) + + self.validation_z = torch.randn(8, self.hparams.latent_dim) + + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + + @staticmethod + def add_argparse_args(parent_parser: ArgumentParser, *, use_argument_group=True): + if use_argument_group: + parser = parent_parser.add_argument_group("pl.GAN") + parser_out = parent_parser + else: + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser_out = parser + parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") + parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") + parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient") + parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") + return parser_out def forward(self, z): return self.generator(z) - def adversarial_loss(self, y_hat, y): - return F.binary_cross_entropy(y_hat, y) + @staticmethod + def adversarial_loss(y_hat, y): + return F.binary_cross_entropy_with_logits(y_hat, y) def training_step(self, batch, batch_idx, optimizer_idx): imgs, _ = batch - self.last_imgs = imgs + + # sample noise + z = torch.randn(imgs.shape[0], self.hparams.latent_dim) + z = z.type_as(imgs) # train generator if optimizer_idx == 0: - # sample noise - z = torch.randn(imgs.shape[0], self.hparams.latent_dim) - z = z.type_as(imgs) - - # generate images - self.generated_imgs = self(z) - - # log sampled images - # sample_imgs = self.generated_imgs[:6] - # grid = torchvision.utils.make_grid(sample_imgs) - # self.logger.experiment.add_image('generated_images', grid, 0) - # ground truth result (ie: all fake) # put on GPU because we created this tensor inside training_loop valid = torch.ones(imgs.size(0), 1) valid = valid.type_as(imgs) # adversarial loss is binary cross-entropy - g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) + g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) tqdm_dict = {'g_loss': g_loss} - output = OrderedDict({ - 'loss': g_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output + self.log_dict(tqdm_dict) + return g_loss # train discriminator if optimizer_idx == 1: @@ -136,20 +187,16 @@ def training_step(self, batch, batch_idx, optimizer_idx): # how well can it label as fake? fake = torch.zeros(imgs.size(0), 1) - fake = fake.type_as(fake) + fake = fake.type_as(imgs) - fake_loss = self.adversarial_loss( - self.discriminator(self.generated_imgs.detach()), fake) + fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake) # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 tqdm_dict = {'d_loss': d_loss} - output = OrderedDict({ - 'loss': d_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output + self.log_dict(tqdm_dict) + + return d_loss def configure_optimizers(self): lr = self.hparams.lr @@ -160,50 +207,77 @@ def configure_optimizers(self): opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) return [opt_g, opt_d], [] - def train_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize([0.5], [0.5])]) - dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform) - return DataLoader(dataset, batch_size=self.hparams.batch_size) - def on_epoch_end(self): - z = torch.randn(8, self.hparams.latent_dim) - z = z.type_as(self.last_imgs) + z = self.validation_z.type_as(self.generator.model[0].weight) # log sampled images sample_imgs = self(z) grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch) + self.logger.experiment.add_image('generated_images', grid, self.current_epoch) -def main(hparams): +class MNISTDataModule(LightningDataModule): + """ + >>> MNISTDataModule() # doctest: +ELLIPSIS + <...generative_adversarial_net.MNISTDataModule object at ...> + """ + + def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): + super().__init__() + self.batch_size = batch_size + self.data_path = data_path + self.num_workers = num_workers + + self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + self.dims = (1, 28, 28) + + def prepare_data(self, stage=None): + # Use this method to do things that might write to disk or that need to be done only from a single GPU + # in distributed settings. Like downloading the dataset for the first time. + MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor()) + + def setup(self, stage=None): + # There are also data operations you might want to perform on every GPU, such as applying transforms + # defined explicitly in your datamodule or assigned in init. + self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers) + + +def main(args: Namespace) -> None: # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ - model = GAN(hparams) + model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim) # ------------------------ # 2 INIT TRAINER # ------------------------ - trainer = Trainer() + # If use distubuted training PyTorch recommends to use DistributedDataParallel. + # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel + dm = MNISTDataModule.from_argparse_args(args) + trainer = Trainer.from_argparse_args(args) # ------------------------ # 3 START TRAINING # ------------------------ - trainer.fit(model) + trainer.fit(model, dm) if __name__ == '__main__': + cli_lightning_logo() parser = ArgumentParser() - parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") - parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") - parser.add_argument("--b1", type=float, default=0.5, - help="adam: decay of first order momentum of gradient") - parser.add_argument("--b2", type=float, default=0.999, - help="adam: decay of first order momentum of gradient") - parser.add_argument("--latent_dim", type=int, default=100, - help="dimensionality of the latent space") - - hparams = parser.parse_args() - - main(hparams) + + # Add program level args, if any. + # ------------------------ + # Add LightningDataLoader args + parser = MNISTDataModule.add_argparse_args(parser) + # Add model specific args + parser = GAN.add_argparse_args(parser) + # Add trainer args + parser = Trainer.add_argparse_args(parser) + # Parse all arguments + args = parser.parse_args() + + main(args) diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index c274cec90ddbbb..1b42edfde463bc 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -1,13 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py + +Before you can run this example, you will need to download the ImageNet dataset manually from the +`official website `_ and place it into a folder `path/to/imagenet`. + +Train on ImageNet with default parameters: + +.. code-block: bash + + python imagenet.py --data-path /path/to/imagenet + +or show all options you can change: + +.. code-block: bash + + python imagenet.py --help + """ -import argparse import os -import random -from collections import OrderedDict +from argparse import ArgumentParser, Namespace import torch -import torch.backends.cudnn as cudnn import torch.nn.functional as F import torch.nn.parallel import torch.optim as optim @@ -19,23 +45,46 @@ import torchvision.transforms as transforms import pytorch_lightning as pl +from pl_examples import cli_lightning_logo from pytorch_lightning.core import LightningModule -# pull out resnet names from torchvision models -MODEL_NAMES = sorted( - name for name in models.__dict__ - if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) -) - class ImageNetLightningModel(LightningModule): - def __init__(self, hparams): - """ - TODO: add docstring here - """ + """ + >>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ImageNetLightningModel( + (model): ResNet(...) + ) + """ + # pull out resnet names from torchvision models + MODEL_NAMES = sorted( + name for name in models.__dict__ + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) + ) + + def __init__( + self, + data_path: str, + arch: str = 'resnet18', + pretrained: bool = False, + lr: float = 0.1, + momentum: float = 0.9, + weight_decay: float = 1e-4, + batch_size: int = 4, + workers: int = 2, + **kwargs, + ): super().__init__() - self.hparams = hparams - self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained) + self.save_hyperparameters() + self.arch = arch + self.pretrained = pretrained + self.lr = lr + self.momentum = momentum + self.weight_decay = weight_decay + self.data_path = data_path + self.batch_size = batch_size + self.workers = workers + self.model = models.__dict__[self.arch](pretrained=self.pretrained) def forward(self, x): return self.model(x) @@ -43,57 +92,24 @@ def forward(self, x): def training_step(self, batch, batch_idx): images, target = batch output = self(images) - loss_val = F.cross_entropy(output, target) + loss_train = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) - - tqdm_dict = {'train_loss': loss_val} - output = OrderedDict({ - 'loss': loss_val, - 'acc1': acc1, - 'acc5': acc5, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - - return output + self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True) + self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True) + self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True) + return loss_train def validation_step(self, batch, batch_idx): images, target = batch output = self(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) + self.log('val_loss', loss_val, on_step=True, on_epoch=True) + self.log('val_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True) + self.log('val_acc5', acc5, on_step=True, on_epoch=True) - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc1': acc1, - 'val_acc5': acc5, - }) - - return output - - def validation_epoch_end(self, outputs): - - tqdm_dict = {} - - for metric_name in ["val_loss", "val_acc1", "val_acc5"]: - metric_total = 0 - - for output in outputs: - metric_value = output[metric_name] - - # reduce manually when using dp - if self.trainer.use_dp or self.trainer.use_ddp2: - metric_value = torch.mean(metric_value) - - metric_total += metric_value - - tqdm_dict[metric_name] = metric_total / len(outputs) - - result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]} - return result - - @classmethod - def __accuracy(cls, output, target, topk=(1,)): + @staticmethod + def __accuracy(output, target, topk=(1, )): """Computes the accuracy over the k top predictions for the specified values of k""" with torch.no_grad(): maxk = max(topk) @@ -105,18 +121,13 @@ def __accuracy(cls, output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res def configure_optimizers(self): - optimizer = optim.SGD( - self.parameters(), - lr=self.hparams.lr, - momentum=self.hparams.momentum, - weight_decay=self.hparams.weight_decay - ) - scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1) + optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay) + scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1**(epoch // 30)) return [optimizer], [scheduler] def train_dataloader(self): @@ -125,7 +136,7 @@ def train_dataloader(self): std=[0.229, 0.224, 0.225], ) - train_dir = os.path.join(self.hparams.data_path, 'train') + train_dir = os.path.join(self.data_path, 'train') train_dataset = datasets.ImageFolder( train_dir, transforms.Compose([ @@ -133,19 +144,14 @@ def train_dataloader(self): transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, - ])) - - if self.use_ddp: - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - else: - train_sampler = None + ]) + ) train_loader = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=self.hparams.batch_size, - shuffle=(train_sampler is None), - num_workers=0, - sampler=train_sampler + batch_size=self.batch_size, + shuffle=True, + num_workers=self.workers, ) return train_loader @@ -154,85 +160,120 @@ def val_dataloader(self): mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) - val_dir = os.path.join(self.hparams.data_path, 'val') + val_dir = os.path.join(self.data_path, 'val') val_loader = torch.utils.data.DataLoader( - datasets.ImageFolder(val_dir, transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])), - batch_size=self.hparams.batch_size, + datasets.ImageFolder( + val_dir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) + ), + batch_size=self.batch_size, shuffle=False, - num_workers=0, + num_workers=self.workers, ) return val_loader + def test_dataloader(self): + return self.val_dataloader() + + def test_step(self, *args, **kwargs): + return self.validation_step(*args, **kwargs) + + def test_epoch_end(self, *args, **kwargs): + outputs = self.validation_epoch_end(*args, **kwargs) + + def substitute_val_keys(out): + return {k.replace('val', 'test'): v for k, v in out.items()} + + outputs = { + 'test_loss': outputs['val_loss'], + 'progress_bar': substitute_val_keys(outputs['progress_bar']), + 'log': substitute_val_keys(outputs['log']), + } + return outputs + @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover - parser = argparse.ArgumentParser(parents=[parent_parser]) - parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES, - help='model architecture: ' + - ' | '.join(MODEL_NAMES) + - ' (default: resnet18)') - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('--seed', type=int, default=42, - help='seed for initializing training. ') - parser.add_argument('-b', '--batch-size', default=256, type=int, - metavar='N', - help='mini-batch size (default: 256), this is the total ' - 'batch size of all GPUs on the current node when ' - 'using Data Parallel or Distributed Data Parallel') - parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate', dest='lr') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--pretrained', dest='pretrained', action='store_true', - help='use pre-trained model') - return parser - - -def get_args(): - parent_parser = argparse.ArgumentParser(add_help=False) - parent_parser.add_argument('--data-path', metavar='DIR', type=str, - help='path to dataset') - parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str, - help='path to save output') - parent_parser.add_argument('--gpus', type=int, default=1, - help='how many gpus') - parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), - help='supports three options dp, ddp, ddp2') - parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true', - help='if true uses 16 bit precision') - parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') + parser = parent_parser.add_argument_group("ImageNetLightningModel") + parser.add_argument( + '-a', + '--arch', + metavar='ARCH', + default='resnet18', + choices=ImageNetLightningModel.MODEL_NAMES, + help=('model architecture: ' + ' | '.join(ImageNetLightningModel.MODEL_NAMES) + ' (default: resnet18)') + ) + parser.add_argument( + '-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)' + ) + parser.add_argument( + '-b', + '--batch-size', + default=256, + type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total batch size of all GPUs on the current node' + ' when using Data Parallel or Distributed Data Parallel' + ) + parser.add_argument( + '--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr' + ) + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') + parser.add_argument( + '--wd', + '--weight-decay', + default=1e-4, + type=float, + metavar='W', + help='weight decay (default: 1e-4)', + dest='weight_decay' + ) + parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') + return parent_parser - parser = ImageNetLightningModel.add_model_specific_args(parent_parser) - return parser.parse_args() - - -def main(hparams): - model = ImageNetLightningModel(hparams) - if hparams.seed is not None: - random.seed(hparams.seed) - torch.manual_seed(hparams.seed) - cudnn.deterministic = True - trainer = pl.Trainer( - default_root_dir=hparams.save_path, - gpus=hparams.gpus, - max_epochs=hparams.epochs, - distributed_backend=hparams.distributed_backend, - precision=16 if hparams.use_16bit else 32, - ) - if hparams.evaluate: - trainer.run_evaluation() + +def main(args: Namespace) -> None: + if args.seed is not None: + pl.seed_everything(args.seed) + + if args.accelerator == 'ddp': + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / max(1, args.gpus)) + args.workers = int(args.workers / max(1, args.gpus)) + + model = ImageNetLightningModel(**vars(args)) + trainer = pl.Trainer.from_argparse_args(args) + + if args.evaluate: + trainer.test(model) else: trainer.fit(model) +def run_cli(): + parent_parser = ArgumentParser(add_help=False) + parent_parser = pl.Trainer.add_argparse_args(parent_parser) + parent_parser.add_argument('--data-path', metavar='DIR', type=str, help='path to dataset') + parent_parser.add_argument( + '-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set' + ) + parent_parser.add_argument('--seed', type=int, default=42, help='seed for initializing training.') + parser = ImageNetLightningModel.add_model_specific_args(parent_parser) + parser.set_defaults( + profiler="simple", + deterministic=True, + max_epochs=90, + ) + args = parser.parse_args() + main(args) + + if __name__ == '__main__': - main(get_args()) + cli_lightning_logo() + run_cli() diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 4301957afa58de..4d90faeb45bcfb 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -1,54 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Deep Reinforcement Learning: Deep Q-network (DQN) -This example is based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- -Second-Edition/blob/master/Chapter06/02_dqn_pong.py - The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the classic CartPole environment. -To run the template just run: -python reinforce_learn_Qnet.py +To run the template, just run: +`python reinforce_learn_Qnet.py` -After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up TensorBoard to -see the metrics: +After ~1500 steps, you will see the total_reward hitting the max score of 200. +Open up TensorBoard to see the metrics: -tensorboard --logdir default -""" +`tensorboard --logdir default` -import pytorch_lightning as pl +References +---------- -from typing import Tuple, List +[1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- +Second-Edition/blob/master/Chapter06/02_dqn_pong.py +""" import argparse -from collections import OrderedDict, deque, namedtuple +from collections import deque, namedtuple, OrderedDict +from typing import List, Tuple import gym import numpy as np import torch import torch.nn as nn import torch.optim as optim -from torch.optim import Optimizer +from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torch.utils.data.dataset import IterableDataset +import pytorch_lightning as pl +from pl_examples import cli_lightning_logo + class DQN(nn.Module): """ Simple MLP network - Args: - obs_size: observation/state size of the environment - n_actions: number of discrete actions available in the environment - hidden_size: size of hidden layers + >>> DQN(10, 5) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DQN( + (net): Sequential(...) + ) """ def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128): + """ + Args: + obs_size: observation/state size of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ super(DQN, self).__init__() self.net = nn.Sequential( nn.Linear(obs_size, hidden_size), nn.ReLU(), - nn.Linear(hidden_size, n_actions) + nn.Linear(hidden_size, n_actions), ) def forward(self, x): @@ -56,20 +78,22 @@ def forward(self, x): # Named tuple for storing experience steps gathered in training -Experience = namedtuple( - 'Experience', field_names=['state', 'action', 'reward', - 'done', 'new_state']) +Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'done', 'new_state']) class ReplayBuffer: """ Replay Buffer for storing past experiences allowing the agent to learn from them - Args: - capacity: size of the buffer + >>> ReplayBuffer(5) # doctest: +ELLIPSIS + <...reinforce_learn_Qnet.ReplayBuffer object at ...> """ def __init__(self, capacity: int) -> None: + """ + Args: + capacity: size of the buffer + """ self.buffer = deque(maxlen=capacity) def __len__(self) -> int: @@ -88,8 +112,13 @@ def sample(self, batch_size: int) -> Tuple: indices = np.random.choice(len(self.buffer), batch_size, replace=False) states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices]) - return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32), - np.array(dones, dtype=np.bool), np.array(next_states)) + return ( + np.array(states), + np.array(actions), + np.array(rewards, dtype=np.float32), + np.array(dones, dtype=np.bool), + np.array(next_states), + ) class RLDataset(IterableDataset): @@ -97,12 +126,16 @@ class RLDataset(IterableDataset): Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training - Args: - buffer: replay buffer - sample_size: number of experiences to sample at a time + >>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS + <...reinforce_learn_Qnet.RLDataset object at ...> """ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: + """ + Args: + buffer: replay buffer + sample_size: number of experiences to sample at a time + """ self.buffer = buffer self.sample_size = sample_size @@ -114,14 +147,20 @@ def __iter__(self) -> Tuple: class Agent: """ - Base Agent class handeling the interaction with the environment + Base Agent class handling the interaction with the environment - Args: - env: training environment - replay_buffer: replay buffer storing experiences + >>> env = gym.make("CartPole-v0") + >>> buffer = ReplayBuffer(10) + >>> Agent(env, buffer) # doctest: +ELLIPSIS + <...reinforce_learn_Qnet.Agent object at ...> """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: + """ + Args: + env: training environment + replay_buffer: replay buffer storing experiences + """ self.env = env self.replay_buffer = replay_buffer self.reset() @@ -188,24 +227,58 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') - class DQNLightning(pl.LightningModule): - """ Basic DQN Model """ - - def __init__(self, hparams: argparse.Namespace) -> None: - super().__init__() - self.hparams = hparams + """ Basic DQN Model + + >>> DQNLightning(env="CartPole-v0") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DQNLightning( + (net): DQN( + (net): Sequential(...) + ) + (target_net): DQN( + (net): Sequential(...) + ) + ) + """ - self.env = gym.make(self.hparams.env) + def __init__( + self, + env: str, + replay_size: int = 200, + warm_start_steps: int = 200, + gamma: float = 0.99, + eps_start: float = 1.0, + eps_end: float = 0.01, + eps_last_frame: int = 200, + sync_rate: int = 10, + lr: float = 1e-2, + episode_length: int = 50, + batch_size: int = 4, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.replay_size = replay_size + self.warm_start_steps = warm_start_steps + self.gamma = gamma + self.eps_start = eps_start + self.eps_end = eps_end + self.eps_last_frame = eps_last_frame + self.sync_rate = sync_rate + self.lr = lr + self.episode_length = episode_length + self.batch_size = batch_size + + self.env = gym.make(env) obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n self.net = DQN(obs_size, n_actions) self.target_net = DQN(obs_size, n_actions) - self.buffer = ReplayBuffer(self.hparams.replay_size) + self.buffer = ReplayBuffer(self.replay_size) self.agent = Agent(self.env, self.buffer) self.total_reward = 0 self.episode_reward = 0 - self.populate(self.hparams.warm_start_steps) + self.populate(self.warm_start_steps) def populate(self, steps: int = 1000) -> None: """ @@ -250,14 +323,14 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor next_state_values[dones] = 0.0 next_state_values = next_state_values.detach() - expected_state_action_values = next_state_values * self.hparams.gamma + rewards + expected_state_action_values = next_state_values * self.gamma + rewards return nn.MSELoss()(state_action_values, expected_state_action_values) def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: """ Carries out a single step through the environment to update the replay buffer. - Then calculates loss based on the minibatch recieved + Then calculates loss based on the minibatch received Args: batch: current mini batch of replay data @@ -267,8 +340,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O Training loss and log metrics """ device = self.get_device(batch) - epsilon = max(self.hparams.eps_end, self.hparams.eps_start - - self.global_step + 1 / self.hparams.eps_last_frame) + epsilon = max(self.eps_end, self.eps_start - self.global_step + 1 / self.eps_last_frame) # step through environment with agent reward, done = self.agent.play_step(self.net, epsilon, device) @@ -282,27 +354,30 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O self.episode_reward = 0 # Soft update of target network - if self.global_step % self.hparams.sync_rate == 0: + if self.global_step % self.sync_rate == 0: self.target_net.load_state_dict(self.net.state_dict()) - log = {'total_reward': torch.tensor(self.total_reward).to(device), - 'reward': torch.tensor(reward).to(device), - 'steps': torch.tensor(self.global_step).to(device)} + log = { + 'total_reward': torch.tensor(self.total_reward).to(device), + 'reward': torch.tensor(reward).to(device), + 'steps': torch.tensor(self.global_step).to(device) + } return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log}) def configure_optimizers(self) -> List[Optimizer]: """Initialize Adam optimizer""" - optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr) + optimizer = optim.Adam(self.net.parameters(), lr=self.lr) return [optimizer] def __dataloader(self) -> DataLoader: """Initialize the Replay Buffer dataset used for retrieving experiences""" - dataset = RLDataset(self.buffer, self.hparams.episode_length) - dataloader = DataLoader(dataset=dataset, - batch_size=self.hparams.batch_size, - sampler=None - ) + dataset = RLDataset(self.buffer, self.episode_length) + dataloader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + sampler=None, + ) return dataloader def train_dataloader(self) -> DataLoader: @@ -313,45 +388,47 @@ def get_device(self, batch) -> str: """Retrieve device currently being used by minibatch""" return batch[0].device.index if self.on_gpu else 'cpu' + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = parent_parser.add_argument_group("DQNLightning") + parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") + parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network") + parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer") + parser.add_argument( + "--warm_start_steps", + type=int, + default=1000, + help="how many samples do we use to fill our buffer at the start of training" + ) + parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying") + parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") + parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") + parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") + return parent_parser + -def main(hparams) -> None: - model = DQNLightning(hparams) +def main(args) -> None: + model = DQNLightning(**vars(args)) trainer = pl.Trainer( gpus=1, - distributed_backend='dp', - early_stop_callback=False, - val_check_interval=100 + accelerator='dp', + val_check_interval=100, ) trainer.fit(model) if __name__ == '__main__': + cli_lightning_logo() torch.manual_seed(0) np.random.seed(0) - parser = argparse.ArgumentParser() - parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") - parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") - parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") - parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") - parser.add_argument("--sync_rate", type=int, default=10, - help="how many frames do we update the target network") - parser.add_argument("--replay_size", type=int, default=1000, - help="capacity of the replay buffer") - parser.add_argument("--warm_start_size", type=int, default=1000, - help="how many samples do we use to fill our buffer at the start of training") - parser.add_argument("--eps_last_frame", type=int, default=1000, - help="what frame should epsilon stop decaying") - parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") - parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") - parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") - parser.add_argument("--max_episode_reward", type=int, default=200, - help="max episode reward in the environment") - parser.add_argument("--warm_start_steps", type=int, default=1000, - help="max episode reward in the environment") - + parser = argparse.ArgumentParser(add_help=False) + parser = DQNLightning.add_model_specific_args(parser) args = parser.parse_args() main(args) diff --git a/pl_examples/domain_templates/reinforce_learn_ppo.py b/pl_examples/domain_templates/reinforce_learn_ppo.py new file mode 100644 index 00000000000000..68ecc3fb22db02 --- /dev/null +++ b/pl_examples/domain_templates/reinforce_learn_ppo.py @@ -0,0 +1,490 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch Lightning implementation of Proximal Policy Optimization (PPO) + +Paper authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov + +The example implements PPO compatible to work with any continous or discrete action-space environments via OpenAI Gym. + +To run the template, just run: +`python reinforce_learn_ppo.py` + +References +---------- +[1] https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py +[2] https://github.com/openai/spinningup +[3] https://github.com/sid-sundrani/ppo_lightning +""" +import argparse +from typing import Callable, Iterable, List, Tuple + +import gym +import torch +from torch import nn +from torch.distributions import Categorical, Normal +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader, IterableDataset + +import pytorch_lightning as pl +from pl_examples import cli_lightning_logo + + +def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): + """ + Simple Multi-Layer Perceptron network + """ + network = nn.Sequential( + nn.Linear(input_shape[0], hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, n_actions), + ) + + return network + + +class ActorCategorical(nn.Module): + """ + Policy network, for discrete action spaces, which returns a distribution + and an action given an observation + """ + + def __init__(self, actor_net): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + """ + super().__init__() + + self.actor_net = actor_net + + def forward(self, states): + logits = self.actor_net(states) + pi = Categorical(logits=logits) + actions = pi.sample() + + return pi, actions + + def get_log_prob(self, pi: Categorical, actions: torch.Tensor): + """ + Takes in a distribution and actions and returns log prob of actions under the distribution + + Args: + pi: torch distribution + actions: actions taken by distribution + + Returns: + log probability of the acition under pi + """ + return pi.log_prob(actions) + + +class ActorContinous(nn.Module): + """ + Policy network, for continous action spaces, which returns a distribution + and an action given an observation + """ + + def __init__(self, actor_net, act_dim): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + """ + super().__init__() + self.actor_net = actor_net + log_std = -0.5 * torch.ones(act_dim, dtype=torch.float) + self.log_std = nn.Parameter(log_std) + + def forward(self, states): + mu = self.actor_net(states) + std = torch.exp(self.log_std) + pi = Normal(loc=mu, scale=std) + actions = pi.sample() + + return pi, actions + + def get_log_prob(self, pi: Normal, actions: torch.Tensor): + """ + Takes in a distribution and actions and returns log prob of actions under the distribution + + Args: + pi: torch distribution + actions: actions taken by distribution + + Returns: + log probability of the acition under pi + """ + return pi.log_prob(actions).sum(axis=-1) + + +class ExperienceSourceDataset(IterableDataset): + """ + Implementation from PyTorch Lightning Bolts: + https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/datamodules/experience_source.py + + Basic experience source dataset. Takes a generate_batch function that returns an iterator. + The logic for the experience source and how the batch is generated is defined the Lightning model itself + """ + + def __init__(self, generate_batch: Callable): + self.generate_batch = generate_batch + + def __iter__(self) -> Iterable: + iterator = self.generate_batch() + return iterator + + +class PPOLightning(pl.LightningModule): + """ + PyTorch Lightning implementation of PPO. + + Example: + model = PPOLightning("CartPole-v0") + Train: + trainer = Trainer() + trainer.fit(model) + """ + + def __init__( + self, + env: str, + gamma: float = 0.99, + lam: float = 0.95, + lr_actor: float = 3e-4, + lr_critic: float = 1e-3, + max_episode_len: float = 200, + batch_size: int = 512, + steps_per_epoch: int = 2048, + nb_optim_iters: int = 4, + clip_ratio: float = 0.2, + **kwargs, + ) -> None: + """ + Args: + env: gym environment tag + gamma: discount factor + lam: advantage discount factor (lambda in the paper) + lr_actor: learning rate of actor network + lr_critic: learning rate of critic network + max_episode_len: maximum number interactions (actions) in an episode + batch_size: batch_size when training network- can simulate number of policy updates performed per epoch + steps_per_epoch: how many action-state pairs to rollout for trajectory collection per epoch + nb_optim_iters: how many steps of gradient descent to perform on each batch + clip_ratio: hyperparameter for clipping in the policy objective + """ + super().__init__() + + # Hyperparameters + self.lr_actor = lr_actor + self.lr_critic = lr_critic + self.steps_per_epoch = steps_per_epoch + self.nb_optim_iters = nb_optim_iters + self.batch_size = batch_size + self.gamma = gamma + self.lam = lam + self.max_episode_len = max_episode_len + self.clip_ratio = clip_ratio + self.save_hyperparameters() + + self.env = gym.make(env) + # value network + self.critic = create_mlp(self.env.observation_space.shape, 1) + # policy network (agent) + if isinstance(self.env.action_space, gym.spaces.box.Box): + act_dim = self.env.action_space.shape[0] + actor_mlp = create_mlp(self.env.observation_space.shape, act_dim) + self.actor = ActorContinous(actor_mlp, act_dim) + elif isinstance(self.env.action_space, gym.spaces.discrete.Discrete): + actor_mlp = create_mlp(self.env.observation_space.shape, self.env.action_space.n) + self.actor = ActorCategorical(actor_mlp) + else: + raise NotImplementedError( + 'Env action space should be of type Box (continous) or Discrete (categorical).' + f' Got type: {type(self.env.action_space)}' + ) + + self.batch_states = [] + self.batch_actions = [] + self.batch_adv = [] + self.batch_qvals = [] + self.batch_logp = [] + + self.ep_rewards = [] + self.ep_values = [] + self.epoch_rewards = [] + + self.episode_step = 0 + self.avg_ep_reward = 0 + self.avg_ep_len = 0 + self.avg_reward = 0 + + self.state = torch.FloatTensor(self.env.reset()) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Passes in a state x through the network and returns the policy and a sampled action + + Args: + x: environment state + + Returns: + Tuple of policy and action + """ + pi, action = self.actor(x) + value = self.critic(x) + + return pi, action, value + + def discount_rewards(self, rewards: List[float], discount: float) -> List[float]: + """Calculate the discounted rewards of all rewards in list + + Args: + rewards: list of rewards/advantages + + Returns: + list of discounted rewards/advantages + """ + assert isinstance(rewards[0], float) + + cumul_reward = [] + sum_r = 0.0 + + for r in reversed(rewards): + sum_r = (sum_r * discount) + r + cumul_reward.append(sum_r) + + return list(reversed(cumul_reward)) + + def calc_advantage(self, rewards: List[float], values: List[float], last_value: float) -> List[float]: + """Calculate the advantage given rewards, state values, and the last value of episode + + Args: + rewards: list of episode rewards + values: list of state values from critic + last_value: value of last state of episode + + Returns: + list of advantages + """ + rews = rewards + [last_value] + vals = values + [last_value] + # GAE + delta = [rews[i] + self.gamma * vals[i + 1] - vals[i] for i in range(len(rews) - 1)] + adv = self.discount_rewards(delta, self.gamma * self.lam) + + return adv + + def generate_trajectory_samples(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Contains the logic for generating trajectory data to train policy and value network + Yield: + Tuple of Lists containing tensors for states, actions, log probs, qvals and advantage + """ + + for step in range(self.steps_per_epoch): + self.state = self.state.to(device=self.device) + + with torch.no_grad(): + pi, action, value = self(self.state) + log_prob = self.actor.get_log_prob(pi, action) + + next_state, reward, done, _ = self.env.step(action.cpu().numpy()) + + self.episode_step += 1 + + self.batch_states.append(self.state) + self.batch_actions.append(action) + self.batch_logp.append(log_prob) + + self.ep_rewards.append(reward) + self.ep_values.append(value.item()) + + self.state = torch.FloatTensor(next_state) + + epoch_end = step == (self.steps_per_epoch - 1) + terminal = len(self.ep_rewards) == self.max_episode_len + + if epoch_end or done or terminal: + # if trajectory ends abtruptly, boostrap value of next state + if (terminal or epoch_end) and not done: + self.state = self.state.to(device=self.device) + with torch.no_grad(): + _, _, value = self(self.state) + last_value = value.item() + steps_before_cutoff = self.episode_step + else: + last_value = 0 + steps_before_cutoff = 0 + + # discounted cumulative reward + self.batch_qvals += self.discount_rewards(self.ep_rewards + [last_value], self.gamma)[:-1] + # advantage + self.batch_adv += self.calc_advantage(self.ep_rewards, self.ep_values, last_value) + # logs + self.epoch_rewards.append(sum(self.ep_rewards)) + # reset params + self.ep_rewards = [] + self.ep_values = [] + self.episode_step = 0 + self.state = torch.FloatTensor(self.env.reset()) + + if epoch_end: + train_data = zip( + self.batch_states, self.batch_actions, self.batch_logp, self.batch_qvals, self.batch_adv + ) + + for state, action, logp_old, qval, adv in train_data: + yield state, action, logp_old, qval, adv + + self.batch_states.clear() + self.batch_actions.clear() + self.batch_adv.clear() + self.batch_logp.clear() + self.batch_qvals.clear() + + # logging + self.avg_reward = sum(self.epoch_rewards) / self.steps_per_epoch + + # if epoch ended abruptly, exlude last cut-short episode to prevent stats skewness + epoch_rewards = self.epoch_rewards + if not done: + epoch_rewards = epoch_rewards[:-1] + + total_epoch_reward = sum(epoch_rewards) + nb_episodes = len(epoch_rewards) + + self.avg_ep_reward = total_epoch_reward / nb_episodes + self.avg_ep_len = (self.steps_per_epoch - steps_before_cutoff) / nb_episodes + + self.epoch_rewards.clear() + + def actor_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor: + pi, _ = self.actor(state) + logp = self.actor.get_log_prob(pi, action) + ratio = torch.exp(logp - logp_old) + clip_adv = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv + loss_actor = -(torch.min(ratio * adv, clip_adv)).mean() + return loss_actor + + def critic_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor: + value = self.critic(state) + loss_critic = (qval - value).pow(2).mean() + return loss_critic + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx, optimizer_idx): + """ + Carries out a single update to actor and critic network from a batch of replay buffer. + + Args: + batch: batch of replay buffer/trajectory data + batch_idx: not used + optimizer_idx: idx that controls optimizing actor or critic network + + Returns: + loss + """ + state, action, old_logp, qval, adv = batch + + # normalize advantages + adv = (adv - adv.mean()) / adv.std() + + self.log("avg_ep_len", self.avg_ep_len, prog_bar=True, on_step=False, on_epoch=True) + self.log("avg_ep_reward", self.avg_ep_reward, prog_bar=True, on_step=False, on_epoch=True) + self.log("avg_reward", self.avg_reward, prog_bar=True, on_step=False, on_epoch=True) + + if optimizer_idx == 0: + loss_actor = self.actor_loss(state, action, old_logp, qval, adv) + self.log('loss_actor', loss_actor, on_step=False, on_epoch=True, prog_bar=True, logger=True) + + return loss_actor + + elif optimizer_idx == 1: + loss_critic = self.critic_loss(state, action, old_logp, qval, adv) + self.log('loss_critic', loss_critic, on_step=False, on_epoch=True, prog_bar=False, logger=True) + + return loss_critic + + def configure_optimizers(self) -> List[Optimizer]: + """ Initialize Adam optimizer""" + optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.lr_actor) + optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=self.lr_critic) + + return optimizer_actor, optimizer_critic + + def optimizer_step(self, *args, **kwargs): + """ + Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic + for each data sample. + """ + for _ in range(self.nb_optim_iters): + super().optimizer_step(*args, **kwargs) + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + dataset = ExperienceSourceDataset(self.generate_trajectory_samples) + dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size) + return dataloader + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = parent_parser.add_argument_group("PPOLightning") + parser.add_argument("--env", type=str, default="CartPole-v0") + parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + parser.add_argument("--lam", type=float, default=0.95, help="advantage discount factor") + parser.add_argument("--lr_actor", type=float, default=3e-4, help="learning rate of actor network") + parser.add_argument("--lr_critic", type=float, default=1e-3, help="learning rate of critic network") + parser.add_argument("--max_episode_len", type=int, default=1000, help="capacity of the replay buffer") + parser.add_argument("--batch_size", type=int, default=512, help="batch_size when training network") + parser.add_argument( + "--steps_per_epoch", + type=int, + default=2048, + help="how many action-state pairs to rollout for trajectory collection per epoch" + ) + parser.add_argument( + "--nb_optim_iters", type=int, default=4, help="how many steps of gradient descent to perform on each batch" + ) + parser.add_argument( + "--clip_ratio", type=float, default=0.2, help="hyperparameter for clipping in the policy objective" + ) + + return parent_parser + + +def main(args) -> None: + model = PPOLightning(**vars(args)) + + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model) + + +if __name__ == '__main__': + cli_lightning_logo() + pl.seed_everything(0) + + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = pl.Trainer.add_argparse_args(parent_parser) + + parser = PPOLightning.add_model_specific_args(parent_parser) + args = parser.parse_args() + + main(args) diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 4604b6454db981..1ae10d40a4e537 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -1,5 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os -from argparse import ArgumentParser +import random +from argparse import ArgumentParser, Namespace import numpy as np import torch @@ -7,16 +22,29 @@ import torchvision.transforms as transforms from PIL import Image from torch.utils.data import DataLoader, Dataset -import random import pytorch_lightning as pl -from pl_examples.models.unet import UNet +from pl_examples import cli_lightning_logo +from pl_examples.domain_templates.unet import UNet from pytorch_lightning.loggers import WandbLogger DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) +def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)): + """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded.""" + path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH) + path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH) + for p_dir in (path_dir_images, path_dir_masks): + os.makedirs(p_dir, exist_ok=True) + for i in range(3): + path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png') + Image.new('RGB', image_dims).save(path_img) + path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png') + Image.new('L', image_dims).save(path_mask) + + class KITTI(Dataset): """ Class for KITTI Semantic Segmentation Benchmark dataset @@ -38,6 +66,12 @@ class KITTI(Dataset): In the `get_item` function, images and masks are resized to the given `img_size`, masks are encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only (mask does not usually require transforms, but they can be implemented in a similar way). + + >>> from pl_examples import _DATASETS_PATH + >>> dataset_path = os.path.join(_DATASETS_PATH, "Kitti") + >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512)) + >>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + <...semantic_segmentation.KITTI object at ...> """ IMAGE_PATH = os.path.join('training', 'image_2') MASK_PATH = os.path.join('training', 'semantic') @@ -126,20 +160,49 @@ class SegModel(pl.LightningModule): It uses the FCN ResNet50 model as an example. Adam optimizer is used along with Cosine Annealing learning rate scheduler. + + >>> from pl_examples import _DATASETS_PATH + >>> dataset_path = os.path.join(_DATASETS_PATH, "Kitti") + >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512)) + >>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + SegModel( + (net): UNet( + (layers): ModuleList( + (0): DoubleConv(...) + (1): Down(...) + (2): Down(...) + (3): Up(...) + (4): Up(...) + (5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + ) """ - def __init__(self, hparams): - super().__init__() - self.hparams = hparams - self.data_path = hparams.data_path - self.batch_size = hparams.batch_size - self.learning_rate = hparams.lr - self.net = UNet(num_classes=19, num_layers=hparams.num_layers, - features_start=hparams.features_start, bilinear=hparams.bilinear) + def __init__( + self, + data_path: str, + batch_size: int = 4, + lr: float = 1e-3, + num_layers: int = 3, + features_start: int = 64, + bilinear: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.data_path = data_path + self.batch_size = batch_size + self.lr = lr + self.num_layers = num_layers + self.features_start = features_start + self.bilinear = bilinear + + self.net = UNet( + num_classes=19, num_layers=self.num_layers, features_start=self.features_start, bilinear=self.bilinear + ) self.transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], - std=[0.32064945, 0.32098866, 0.32325324]) + transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) ]) self.trainset = KITTI(self.data_path, split='train', transform=self.transform) self.validset = KITTI(self.data_path, split='valid', transform=self.transform) @@ -152,9 +215,9 @@ def training_step(self, batch, batch_nb): img = img.float() mask = mask.long() out = self(img) - loss_val = F.cross_entropy(out, mask, ignore_index=250) - log_dict = {'train_loss': loss_val} - return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict} + loss = F.cross_entropy(out, mask, ignore_index=250) + log_dict = {'train_loss': loss} + return {'loss': loss, 'log': log_dict, 'progress_bar': log_dict} def validation_step(self, batch, batch_idx): img, mask = batch @@ -165,7 +228,7 @@ def validation_step(self, batch, batch_idx): return {'val_loss': loss_val} def validation_epoch_end(self, outputs): - loss_val = sum(output['val_loss'] for output in outputs) / len(outputs) + loss_val = torch.stack([x['val_loss'] for x in outputs]).mean() log_dict = {'val_loss': loss_val} return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict} @@ -180,12 +243,28 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False) - -def main(hparams): + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = parent_parser.add_argument_group("SegModel") + parser.add_argument("--data_path", type=str, help="path where dataset is stored") + parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") + parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") + parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") + parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") + parser.add_argument( + "--bilinear", + action='store_true', + default=False, + help="whether to use bilinear interpolation or transposed" + ) + return parent_parser + + +def main(hparams: Namespace): # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ - model = SegModel(hparams) + model = SegModel(**vars(hparams)) # ------------------------ # 2 SET LOGGER @@ -200,14 +279,7 @@ def main(hparams): # ------------------------ # 3 INIT TRAINER # ------------------------ - trainer = pl.Trainer( - gpus=hparams.gpus, - logger=logger, - max_epochs=hparams.epochs, - accumulate_grad_batches=hparams.grad_batches, - distributed_backend=hparams.distributed_backend, - precision=16 if hparams.use_amp else 32, - ) + trainer = pl.Trainer.from_argparse_args(hparams) # ------------------------ # 5 START TRAINING @@ -216,22 +288,9 @@ def main(hparams): if __name__ == '__main__': - parser = ArgumentParser() - parser.add_argument("--data_path", type=str, help="path where dataset is stored") - parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs") - parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), - help='supports three options dp, ddp, ddp2') - parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision') - parser.add_argument("--batch_size", type=int, default=4, help="size of the batches") - parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") - parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") - parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") - parser.add_argument("--bilinear", action='store_true', default=False, - help="whether to use bilinear interpolation or transposed") - parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate") - parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train") - parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases") - + cli_lightning_logo() + parser = ArgumentParser(add_help=False) + parser = SegModel.add_model_specific_args(parser) hparams = parser.parse_args() main(hparams) diff --git a/pl_examples/models/unet.py b/pl_examples/domain_templates/unet.py similarity index 53% rename from pl_examples/models/unet.py rename to pl_examples/domain_templates/unet.py index 5e85802bfe695f..f083ae434bd33f 100644 --- a/pl_examples/models/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,20 +22,33 @@ class UNet(nn.Module): Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation Link - https://arxiv.org/abs/1505.04597 - Parameters: - num_classes: Number of output classes required (default 19 for KITTI dataset) - num_layers: Number of layers in each side of U-net - features_start: Number of features in first layer - bilinear: Whether to use bilinear interpolation or transposed - convolutions for upsampling. + >>> UNet(num_classes=2, num_layers=3) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + UNet( + (layers): ModuleList( + (0): DoubleConv(...) + (1): Down(...) + (2): Down(...) + (3): Up(...) + (4): Up(...) + (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + ) """ def __init__( - self, num_classes: int = 19, - num_layers: int = 5, - features_start: int = 64, - bilinear: bool = False + self, + num_classes: int = 19, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False, ): + """ + Args: + num_classes: Number of output classes required (default 19 for KITTI dataset) + num_layers: Number of layers in each side of U-net + features_start: Number of features in first layer + bilinear: Whether to use bilinear interpolation or transposed convolutions for upsampling. + """ super().__init__() self.num_layers = num_layers @@ -33,7 +60,7 @@ def __init__( feats *= 2 for _ in range(num_layers - 1): - layers.append(Up(feats, feats // 2), bilinear) + layers.append(Up(feats, feats // 2, bilinear)) feats //= 2 layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) @@ -55,6 +82,11 @@ class DoubleConv(nn.Module): """ Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2 + + >>> DoubleConv(4, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DoubleConv( + (net): Sequential(...) + ) """ def __init__(self, in_ch: int, out_ch: int): @@ -65,7 +97,7 @@ def __init__(self, in_ch: int, out_ch: int): nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) def forward(self, x): @@ -75,14 +107,21 @@ def forward(self, x): class Down(nn.Module): """ Combination of MaxPool2d and DoubleConv in series + + >>> Down(4, 8) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Down( + (net): Sequential( + (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (1): DoubleConv( + (net): Sequential(...) + ) + ) + ) """ def __init__(self, in_ch: int, out_ch: int): super().__init__() - self.net = nn.Sequential( - nn.MaxPool2d(kernel_size=2, stride=2), - DoubleConv(in_ch, out_ch) - ) + self.net = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch)) def forward(self, x): return self.net(x) @@ -93,13 +132,24 @@ class Up(nn.Module): Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map from contracting path, followed by double 3x3 convolution. + + >>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Up( + (upsample): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2)) + (conv): DoubleConv( + (net): Sequential(...) + ) + ) """ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): super().__init__() self.upsample = None if bilinear: - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(in_ch, in_ch // 2, kernel_size=1), + ) else: self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2) diff --git a/pl_examples/models/lightning_template.py b/pl_examples/models/lightning_template.py deleted file mode 100644 index 13b3bc67a912b4..00000000000000 --- a/pl_examples/models/lightning_template.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -Example template for defining a system. -""" -import os -from argparse import ArgumentParser -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as transforms -from torch import optim -from torch.utils.data import DataLoader -from torchvision.datasets import MNIST - -from pytorch_lightning import _logger as log -from pytorch_lightning.core import LightningModule - - -class LightningTemplateModel(LightningModule): - """ - Sample model to show how to define a template. - - Example: - - >>> # define simple Net for MNIST dataset - >>> params = dict( - ... drop_prob=0.2, - ... batch_size=2, - ... in_features=28 * 28, - ... learning_rate=0.001 * 8, - ... optimizer_name='adam', - ... data_root='./datasets', - ... out_features=10, - ... hidden_dim=1000, - ... ) - >>> from argparse import Namespace - >>> hparams = Namespace(**params) - >>> model = LightningTemplateModel(hparams) - """ - - def __init__(self, hparams): - """ - Pass in hyperparameters as a `argparse.Namespace` or a `dict` to the model. - """ - # init superclass - super().__init__() - self.hparams = hparams - self.c_d1 = nn.Linear(in_features=self.hparams.in_features, - out_features=self.hparams.hidden_dim) - self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) - self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) - - self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, - out_features=self.hparams.out_features) - - def forward(self, x): - """ - No special modification required for Lightning, define it as you normally would - in the `nn.Module` in vanilla PyTorch. - """ - x = self.c_d1(x.view(x.size(0), -1)) - x = torch.tanh(x) - x = self.c_d1_bn(x) - x = self.c_d1_drop(x) - x = self.c_d2(x) - return x - - def training_step(self, batch, batch_idx): - """ - Lightning calls this inside the training loop with the data from the training dataloader - passed in as `batch`. - """ - # forward pass - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - tensorboard_logs = {'train_loss': loss} - return {'loss': loss, 'log': tensorboard_logs} - - def validation_step(self, batch, batch_idx): - """ - Lightning calls this inside the validation loop with the data from the validation dataloader - passed in as `batch`. - """ - x, y = batch - y_hat = self(x) - val_loss = F.cross_entropy(y_hat, y) - labels_hat = torch.argmax(y_hat, dim=1) - n_correct_pred = torch.sum(y == labels_hat).item() - return {'val_loss': val_loss, "n_correct_pred": n_correct_pred, "n_pred": len(x)} - - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - test_loss = F.cross_entropy(y_hat, y) - labels_hat = torch.argmax(y_hat, dim=1) - n_correct_pred = torch.sum(y == labels_hat).item() - return {'test_loss': test_loss, "n_correct_pred": n_correct_pred, "n_pred": len(x)} - - def validation_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs. - :param outputs: list of individual outputs of each validation step. - """ - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - val_acc = sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs) - tensorboard_logs = {'val_loss': avg_loss, 'val_acc': val_acc} - return {'val_loss': avg_loss, 'log': tensorboard_logs} - - def test_epoch_end(self, outputs): - avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() - test_acc = sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs) - tensorboard_logs = {'test_loss': avg_loss, 'test_acc': test_acc} - return {'test_loss': avg_loss, 'log': tensorboard_logs} - - # --------------------- - # TRAINING SETUP - # --------------------- - def configure_optimizers(self): - """ - Return whatever optimizers and learning rate schedulers you want here. - At least one optimizer is required. - """ - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) - return [optimizer], [scheduler] - - def prepare_data(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - self.mnist_train = MNIST(self.hparams.data_root, train=True, download=True, transform=transform) - self.mnist_test = MNIST(self.hparams.data_root, train=False, download=True, transform=transform) - - def train_dataloader(self): - log.info('Training data loader called.') - return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size, num_workers=4) - - def val_dataloader(self): - log.info('Validation data loader called.') - return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=4) - - def test_dataloader(self): - log.info('Test data loader called.') - return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=4) - - @staticmethod - def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover - """ - Parameters you define here will be available to your model through `self.hparams`. - """ - parser = ArgumentParser(parents=[parent_parser]) - - # param overwrites - # parser.set_defaults(gradient_clip_val=5.0) - - # network params - parser.add_argument('--in_features', default=28 * 28, type=int) - parser.add_argument('--out_features', default=10, type=int) - # use 500 for CPU, 50000 for GPU to see speed difference - parser.add_argument('--hidden_dim', default=50000, type=int) - parser.add_argument('--drop_prob', default=0.2, type=float) - parser.add_argument('--learning_rate', default=0.001, type=float) - - # data - parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str) - - # training params (opt) - parser.add_argument('--epochs', default=20, type=int) - parser.add_argument('--optimizer_name', default='adam', type=str) - parser.add_argument('--batch_size', default=64, type=int) - return parser diff --git a/pl_examples/requirements.txt b/pl_examples/requirements.txt deleted file mode 100644 index 24506bbba7964b..00000000000000 --- a/pl_examples/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torchvision>=0.4.0 -gym>=0.17.0 \ No newline at end of file diff --git a/pl_examples/run_ddp-examples.sh b/pl_examples/run_ddp-examples.sh new file mode 100644 index 00000000000000..6cc36364e397dd --- /dev/null +++ b/pl_examples/run_ddp-examples.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +ARGS_EXTRA_DDP=" --gpus 2 --accelerator ddp" +ARGS_EXTRA_AMP=" --precision 16" + +python pl_examples/basic_examples/simple_image_classifier.py $@ ${ARGS_EXTRA_DDP} +python pl_examples/basic_examples/simple_image_classifier.py $@ ${ARGS_EXTRA_DDP} ${ARGS_EXTRA_AMP} + +python pl_examples/basic_examples/backbone_image_classifier.py $@ ${ARGS_EXTRA_DDP} +python pl_examples/basic_examples/backbone_image_classifier.py $@ ${ARGS_EXTRA_DDP} ${ARGS_EXTRA_AMP} + +python pl_examples/basic_examples/autoencoder.py $@ ${ARGS_EXTRA_DDP} +python pl_examples/basic_examples/autoencoder.py $@ ${ARGS_EXTRA_DDP} ${ARGS_EXTRA_AMP} diff --git a/pl_examples/run_examples-args.sh b/pl_examples/run_examples-args.sh new file mode 100644 index 00000000000000..352869538cb18f --- /dev/null +++ b/pl_examples/run_examples-args.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +echo $@ + +full_path=$(realpath $0) +echo $full_path + +dir_path=$(dirname $full_path) +echo $dir_path + +python ${dir_path}/basic_examples/simple_image_classifier.py $@ + +python ${dir_path}/basic_examples/backbone_image_classifier.py $@ + +python ${dir_path}/basic_examples/autoencoder.py $@ diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py new file mode 100644 index 00000000000000..b930957a26346b --- /dev/null +++ b/pl_examples/test_examples.py @@ -0,0 +1,93 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import platform +from unittest import mock + +import pytest +import torch + +from pl_examples import _DALI_AVAILABLE + +ARGS_DEFAULT = """ +--default_root_dir %(tmpdir)s \ +--max_epochs 1 \ +--batch_size 32 \ +--limit_train_batches 2 \ +--limit_val_batches 2 \ +""" + +ARGS_GPU = ARGS_DEFAULT + """ +--gpus 1 \ +""" + +ARGS_DP = ARGS_DEFAULT + """ +--gpus 2 \ +--accelerator dp \ +""" + +ARGS_AMP = """ +--precision 16 \ +""" + + +@pytest.mark.parametrize( + 'import_cli', [ + 'pl_examples.basic_examples.simple_image_classifier', + 'pl_examples.basic_examples.backbone_image_classifier', + 'pl_examples.basic_examples.autoencoder', + ] +) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.parametrize('cli_args', [ARGS_DP, ARGS_DP + ARGS_AMP]) +def test_examples_dp(tmpdir, import_cli, cli_args): + + module = importlib.import_module(import_cli) + # update the temp dir + cli_args = cli_args % {'tmpdir': tmpdir} + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + module.cli_main() + + +@pytest.mark.parametrize( + 'import_cli', [ + 'pl_examples.basic_examples.simple_image_classifier', + 'pl_examples.basic_examples.backbone_image_classifier', + 'pl_examples.basic_examples.autoencoder', + ] +) +@pytest.mark.parametrize('cli_args', [ARGS_DEFAULT]) +def test_examples_cpu(tmpdir, import_cli, cli_args): + + module = importlib.import_module(import_cli) + # update the temp dir + cli_args = cli_args % {'tmpdir': tmpdir} + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + module.cli_main() + + +@pytest.mark.skipif(not _DALI_AVAILABLE, reason="Nvidia DALI required") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.') +@pytest.mark.parametrize('cli_args', [ARGS_GPU]) +def test_examples_mnist_dali(tmpdir, cli_args): + from pl_examples.basic_examples.dali_image_classifier import cli_main + + # update the temp dir + cli_args = cli_args % {'tmpdir': tmpdir} + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + cli_main() diff --git a/pyproject.toml b/pyproject.toml index 4c3ee6e11f24c9..e8a3213f2b738d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,15 @@ requires = [ "wheel", ] -[tool.autopep8] -max_line_length = 120 -ignore = ["W504", "W504", "E402", "E731", "C40", "E741", "F40", "F841"] +[tool.isort] +known_first_party = [ + "benchmarks", + "docs", + "pl_examples", + "pytorch_lightning", + "tests", +] +profile = "black" +line_length = 120 +force_sort_within_sections = "False" +order_by_type = "False" diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 705bdfc64157ef..b9660475bf2f7d 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -1,72 +1,44 @@ """Root package info.""" -__version__ = '0.7.5' -__author__ = 'William Falcon et al.' -__author_email__ = 'waf2107@columbia.edu' -__license__ = 'Apache-2.0' -__copyright__ = 'Copyright (c) 2018-2020, %s.' % __author__ -__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' -# this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." \ - " Scale your models. Write less boilerplate." -__long_docs__ = """ -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. - It's more of a style-guide than a framework. - -In Lightning, you organize your code into 3 distinct categories: - -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc. this goes in Callbacks). - -Although your research/production project might start simple, once you add things like GPU AND TPU training, - 16-bit precision, etc, you end up spending more time engineering than researching. - Lightning automates AND rigorously tests those parts for you. - -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. - -Documentation -------------- -- https://pytorch-lightning.readthedocs.io/en/latest -- https://pytorch-lightning.readthedocs.io/en/stable -""" - -import logging as python_logging - -_logger = python_logging.getLogger("lightning") -python_logging.basicConfig(level=python_logging.INFO) - -try: - # This variable is injected in the __builtins__ by the build - # process. It used to enable importing subpackages of skimage when - # the binaries are not built - __LIGHTNING_SETUP__ -except NameError: - __LIGHTNING_SETUP__ = False - -if __LIGHTNING_SETUP__: - import sys # pragma: no-cover - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: - from pytorch_lightning.core import LightningModule - from pytorch_lightning.trainer import Trainer - from pytorch_lightning.callbacks import Callback - from pytorch_lightning.core import data_loader - - __all__ = [ - 'Trainer', - 'LightningModule', - 'Callback', - 'data_loader' - ] - - # necessary for regular bolts imports. Skip exception since bolts is not always installed - try: - from pytorch_lightning import bolts - except ImportError: - pass - # __call__ = __all__ +import logging +import os + +from pytorch_lightning.info import ( # noqa: F401 + __author__, + __author_email__, + __copyright__, + __docs__, + __homepage__, + __license__, + __version__, +) + +_root_logger = logging.getLogger() +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +# if root logger has handlers, propagate messages up and let root logger process them +if not _root_logger.hasHandlers(): + _logger.addHandler(logging.StreamHandler()) + _logger.propagate = False + +_PACKAGE_ROOT = os.path.dirname(__file__) +_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) + +from pytorch_lightning import metrics # noqa: E402 +from pytorch_lightning.callbacks import Callback # noqa: E402 +from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 +from pytorch_lightning.trainer import Trainer # noqa: E402 +from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 + +__all__ = [ + 'Trainer', + 'LightningDataModule', + 'LightningModule', + 'Callback', + 'seed_everything', + 'metrics', +] # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__) diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py new file mode 100644 index 00000000000000..05e15fe1f17678 --- /dev/null +++ b/pytorch_lightning/accelerators/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.accelerators.accelerator import Accelerator # noqa F401 +from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa F401 +from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa F401 +from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa F401 diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py new file mode 100644 index 00000000000000..7d16d91e3bf824 --- /dev/null +++ b/pytorch_lightning/accelerators/accelerator.py @@ -0,0 +1,513 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union + +import torch +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin +from pytorch_lightning.plugins.training_type import TrainingTypePlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.enums import AMPType, LightningEnum + +if TYPE_CHECKING: + from torch.cuda.amp import GradScaler + + from pytorch_lightning.trainer.trainer import Trainer + +_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None] + + +class Accelerator(object): + """ + The Accelerator Base Class. + An Accelerator is meant to deal with one type of Hardware. + + Currently there are accelerators for: + - CPU + - GPU + - TPU + + Each Accelerator gets two plugins upon initialization: + One to handle differences from the training routine and one to handle different precisions. + + """ + + def __init__( + self, + precision_plugin: PrecisionPlugin, + training_type_plugin: TrainingTypePlugin, + ) -> None: + """ + + Args: + precision_plugin: the plugin to handle precision-specific parts + training_type_plugin: the plugin to handle different training routines + """ + self.precision_plugin = precision_plugin + self.training_type_plugin = training_type_plugin + + self.optimizers: Sequence = [] + self.lr_schedulers: Sequence = [] + self.optimizer_frequencies: Sequence = [] + + def connect(self, model: LightningModule) -> None: + """Transfers ownership of the model to this plugin""" + self.training_type_plugin.connect(model) + + def setup_environment(self) -> None: + """ + Setup any processes or distributed connections. + This is called before the LightningModule/DataModule setup hook + which allows the user to access the accelerator environment before setup is complete. + """ + self.training_type_plugin.setup_environment() + + def setup(self, trainer: 'Trainer', model: LightningModule) -> None: + """ + Setup plugins for the trainer fit and creates optimizers. + Args: + trainer: the trainer instance + model: the LightningModule + """ + self.setup_training_type_plugin(self.training_type_plugin, model) + if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.setup_precision_plugin(self.precision_plugin) + + def start_training(self, trainer: 'Trainer') -> None: + self.training_type_plugin.start_training(trainer) + + def start_evaluating(self, trainer: 'Trainer') -> None: + self.training_type_plugin.start_evaluating(trainer) + + def start_predicting(self, trainer: 'Trainer') -> None: + self.training_type_plugin.start_predicting(trainer) + + def pre_dispatch(self, trainer: 'Trainer') -> None: + """Hook to do something before the training/evaluation/prediction starts.""" + self.training_type_plugin.pre_dispatch() + if self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) + self.precision_plugin.pre_dispatch() + + def post_dispatch(self, trainer: 'Trainer') -> None: + """Hook to do something before the training/evaluation/prediction starts.""" + self.training_type_plugin.post_dispatch() + self.precision_plugin.post_dispatch() + + @property + def model(self) -> torch.nn.Module: + """Returns the model. This can also be a wrapped LightningModule. + For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module` + + """ + return self.training_type_plugin.model + + @model.setter + def model(self, new_model: torch.nn.Module) -> None: + self.training_type_plugin.model = new_model + + @property + def lightning_module(self) -> LightningModule: + """Returns the pure LightningModule. + To get the potentially wrapped model use :attr:`Accelerator.model` + + """ + return self.training_type_plugin.lightning_module + + @property + def root_device(self) -> torch.device: + return self.training_type_plugin.root_device + + def teardown(self) -> None: + """This method is called to teardown the training process. + It is the right place to release memory and free other ressources. + """ + pass + + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: + """Moves the batch to the correct device. + The returned batch is of the same type as the input batch, just having all tensors on the correct device. + + Args: + batch: The batch of samples to move to the correct device + device: The target device + """ + model = self.lightning_module + + if model is not None: + return model._apply_batch_transfer_handler(batch, device) + + return move_data_to_device(batch, device) + + def on_train_start(self) -> None: + """Hook to do something upon the training start""" + pass + + def training_step( + self, + args: List[Union[Any, int]], + ) -> _STEP_OUTPUT_TYPE: + """The actual training step. + + Args: + args: the arguments for the models training step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): Integer displaying index of this batch + optimizer_idx (int): When using multiple optimizers, this argument will also be present. + hiddens(:class:`~torch.Tensor`): Passed in if + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. + + """ + args[0] = self.to_device(args[0]) + + with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context(): + return self.training_type_plugin.training_step(*args) + + def post_training_step(self) -> None: + self.training_type_plugin.post_training_step() + + def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + """The actual validation step. + + Args: + args: the arguments for the models validation step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple val dataloaders used) + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context(): + return self.training_type_plugin.validation_step(*args) + + def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + """The actual test step. + + Args: + args: the arguments for the models test step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple test dataloaders used). + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): + return self.training_type_plugin.test_step(*args) + + def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + """The actual predict step. + + Args: + args: the arguments for the models predict step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple predict dataloaders used). + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): + return self.training_type_plugin.predict_step(*args) + + def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: + """A hook to do something at the end of the training step + + Args: + output: the output of the training step + """ + return self.training_type_plugin.training_step_end(output) + + def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: + """A hook to do something at the end of the test step + + Args: + output: the output of the test step + """ + return self.training_type_plugin.test_step_end(output) + + def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: + """A hook to do something at the end of the validation step + + Args: + output: the output of the validation step + """ + return self.training_type_plugin.validation_step_end(output) + + def backward( + self, + closure_loss: torch.Tensor, + optimizer: Optimizer, + optimizer_idx: int, + should_accumulate: bool, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """Forwards backward-calls to the precision plugin. + + Args: + closure_loss: a tensor holding the loss value to backpropagate + should_accumulate: whether to accumulate gradients + """ + self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, optimizer_idx) + + output = self.precision_plugin.backward( + self.lightning_module, closure_loss, optimizer, optimizer_idx, should_accumulate, *args, **kwargs + ) + + self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, optimizer_idx) + + return output + + def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None: + """performs the actual optimizer step. + + Args: + optimizer: the optimizer performing the step + opt_idx: index of the current optimizer + lambda_closure: closure calculating the loss value + + """ + make_optimizer_step = self.precision_plugin.pre_optimizer_step( + self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs + ) + if make_optimizer_step: + self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs) + self.precision_plugin.post_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs) + + def run_optimizer_step( + self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any + ) -> None: + self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs) + + def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: + """Zeros all model parameter's gradients""" + model_ref = self.lightning_module + model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + """clips all the optimizer parameters to the given value""" + + self.precision_plugin.clip_gradients(optimizer, clip_val) + + def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None: + """Hook to do something on the end of an training epoch + + Args: + outputs: the outputs of the training steps + """ + pass + + def on_train_end(self) -> None: + """Hook to do something at the end of the training""" + pass + + def setup_optimizers(self, trainer: 'Trainer') -> None: + """creates optimizers and schedulers + + Args: + trainer: the Trainer, these optimizers should be connected to + model: the model to be optimized by the created optimizers + """ + if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING): + return + optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( + trainer=trainer, model=self.lightning_module + ) + self.optimizers = optimizers + self.lr_schedulers = lr_schedulers + self.optimizer_frequencies = optimizer_frequencies + + def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """Attaches the training type plugin to the accelerator.""" + plugin.setup(model) + + def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: + """Attaches the precision plugin to the accelerator""" + model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers) + self.model = model + self.optimizers = optimizers + self.schedulers = schedulers + + def to_device(self, batch: Any) -> Any: + """Pushes the batch to the root device""" + # Todo (tchaton) Better fix + is_dict = isinstance(batch, dict) + if is_dict: + batch = [batch] + batch = self.batch_to_device(batch, self.root_device) + return batch[0] if is_dict else batch + + @property + def amp_backend(self) -> Optional[LightningEnum]: + if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + return AMPType.APEX + elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + return AMPType.NATIVE + return None + + @property + def precision(self) -> Union[str, int]: + return self.precision_plugin.precision + + @property + def scaler(self) -> Optional['GradScaler']: + + return getattr(self.precision_plugin, 'scaler', None) + + @property + def rpc_enabled(self) -> bool: + return self.training_type_plugin.rpc_enabled + + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]: + """ + Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom + plugins. + """ + return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer) + + def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]: + return self.training_type_plugin.on_save(checkpoint) + + def barrier(self, name: Optional[str] = None) -> None: + self.training_type_plugin.barrier(name=name) + + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed. + + Args: + obj: Object to broadcast to all process, usually a tensor or collection of tensors. + src: The source rank of which the object will be broadcast from + """ + return self.training_type_plugin.broadcast(obj, src) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes. + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + Return: + A tensor of shape (world_size, batch, ...) + """ + return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) + + def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Wraps the dataloader if necessary + + Args: + dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ + return self.training_type_plugin.process_dataloader(dataloader) + + @property + def results(self) -> Any: + """ + The results of the last run will be cached within the training type plugin. + In distributed training, we make sure to transfer the results to the appropriate master process. + """ + return self.training_type_plugin.results + + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: + """ + Provide hook to create modules in a distributed aware context. This is useful for when we'd like to + shard the model instantly - useful for extremely large models. Can save memory and + initialization time. + + Returns: Model parallel context. + """ + with self.training_type_plugin.model_sharded_context(): + yield + + # todo: remove in v1.5 + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """ + Attaches the training type plugin to the accelerator. + Also transfers ownership of the model to this plugin + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn( + 'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) + self.setup_training_type_plugin(plugin, model) + + # todo: remove in v1.5 + def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: + """Attaches the precision plugin to the accelerator + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn( + 'Accelerator method `connect_precision_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) + self.setup_precision_plugin(plugin) + + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ + self.training_type_plugin.save_checkpoint(checkpoint, filepath) + + @property + def call_configure_sharded_model_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return self.training_type_plugin.call_configure_sharded_model_hook + + @call_configure_sharded_model_hook.setter + def call_configure_sharded_model_hook(self, mode: bool) -> None: + self.training_type_plugin.call_configure_sharded_model_hook = mode + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + """ + Override to delay setting optimizers and schedulers till after dispatch. + This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. + However this may break certain precision plugins such as APEX which require optimizers to be set. + Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + """ + return self.training_type_plugin.setup_optimizers_in_pre_dispatch diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py new file mode 100644 index 00000000000000..22ea8f1e1b7aa9 --- /dev/null +++ b/pytorch_lightning/accelerators/cpu.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if TYPE_CHECKING: + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.trainer.trainer import Trainer + + +class CPUAccelerator(Accelerator): + + def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None: + """ + Raises: + MisconfigurationException: + If AMP is used with CPU, or if the selected device is not CPU. + """ + if isinstance(self.precision_plugin, MixedPrecisionPlugin): + raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option") + + if "cpu" not in str(self.root_device): + raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead") + + return super().setup(trainer, model) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py new file mode 100644 index 00000000000000..c23960e4fd9e3f --- /dev/null +++ b/pytorch_lightning/accelerators/gpu.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any, TYPE_CHECKING + +import torch + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.plugins import DataParallelPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if TYPE_CHECKING: + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.trainer.trainer import Trainer + +_log = logging.getLogger(__name__) + + +class GPUAccelerator(Accelerator): + + def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None: + """ + Raises: + MisconfigurationException: + If the selected device is not GPU. + """ + if "cuda" not in str(self.root_device): + raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") + self.set_nvidia_flags() + torch.cuda.set_device(self.root_device) + return super().setup(trainer, model) + + def on_train_start(self) -> None: + # clear cache before training + # use context because of: + # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 + with torch.cuda.device(self.root_device): + torch.cuda.empty_cache() + + def on_train_end(self) -> None: + # clean up memory + self.model.cpu() + with torch.cuda.device(self.root_device): + torch.cuda.empty_cache() + + @staticmethod + def set_nvidia_flags() -> None: + # set the correct cuda visible devices (using pci order) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) + devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) + _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]") + + def to_device(self, batch: Any) -> Any: + # no need to transfer batch to device in DP mode + # TODO: Add support to allow batch transfer to device in Lightning for DP mode. + if not isinstance(self.training_type_plugin, DataParallelPlugin): + batch = super().to_device(batch) + + return batch diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py new file mode 100644 index 00000000000000..35a475e3e790de --- /dev/null +++ b/pytorch_lightning/accelerators/tpu.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin +from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin +from pytorch_lightning.utilities import _XLA_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _XLA_AVAILABLE: + import torch_xla.core.xla_model as xm + from torch_xla._patched_functions import clip_grad_norm_ + + xla_clip_grad_norm_ = clip_grad_norm_ + +if TYPE_CHECKING: + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.trainer.trainer import Trainer + + +class TPUAccelerator(Accelerator): + + def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None: + """ + Raises: + MisconfigurationException: + If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training. + """ + if isinstance(self.precision_plugin, MixedPrecisionPlugin): + raise MisconfigurationException( + "amp + tpu is not supported. " + "Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin" + ) + + if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): + raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.") + return super().setup(trainer, model) + + def run_optimizer_step( + self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any + ) -> None: + xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs}) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + # todo: Add support for backward with all_gather + if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: + return xm.all_gather(tensor).view(-1, *tensor.shape) + return tensor + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): + + model = self.lightning_module + parameters = model.parameters() + + grad_clip_val = float(clip_val) + if grad_clip_val <= 0: + return + + max_norm = grad_clip_val + + xla_clip_grad_norm_(parameters, max_norm, norm_type) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 7e8e0ce5bcfef3..fb61ad81aee283 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -1,16 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning +from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler +from pytorch_lightning.callbacks.lambda_function import LambdaCallback +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.callbacks.lr_logger import LearningRateLogger -from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar +from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.pruning import ModelPruning +from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging __all__ = [ + 'BackboneFinetuning', + 'BaseFinetuning', 'Callback', 'EarlyStopping', - 'ModelCheckpoint', + 'GPUStatsMonitor', 'GradientAccumulationScheduler', - 'LearningRateLogger', - 'ProgressBarBase', + 'LambdaCallback', + 'LearningRateMonitor', + 'ModelCheckpoint', + 'ModelPruning', 'ProgressBar', + 'ProgressBarBase', + 'QuantizationAwareTraining', + 'StochasticWeightAveraging', ] diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 50ea061df615ec..768e4ebca30ee8 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -1,87 +1,211 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. r""" -Callback Base -============= - Abstract base class used to build new callbacks. """ import abc +from typing import Any, Dict, List, Optional + +from pytorch_lightning.core.lightning import LightningModule class Callback(abc.ABC): r""" Abstract base class used to build new callbacks. + + Subclass this class and override any of the relevant hooks """ - def on_init_start(self, trainer): + def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None: + """Called before configure sharded model""" + + def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None: + """Called before accelerator is being setup""" + pass + + def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + """Called when fit, validate, test, predict, or tune begins""" + pass + + def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + """Called when fit, validate, test, predict, or tune ends""" + pass + + def on_init_start(self, trainer) -> None: """Called when the trainer initialization begins, model has not yet been set.""" pass - def on_init_end(self, trainer): + def on_init_end(self, trainer) -> None: """Called when the trainer initialization ends, model has not yet been set.""" pass - def on_sanity_check_start(self, trainer, pl_module): + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + """Called when fit begins""" + pass + + def on_fit_end(self, trainer, pl_module: LightningModule) -> None: + """Called when fit ends""" + pass + + def on_sanity_check_start(self, trainer, pl_module: LightningModule) -> None: """Called when the validation sanity check starts.""" pass - def on_sanity_check_end(self, trainer, pl_module): + def on_sanity_check_end(self, trainer, pl_module: LightningModule) -> None: """Called when the validation sanity check ends.""" pass - def on_epoch_start(self, trainer, pl_module): - """Called when the epoch begins.""" + def on_train_batch_start( + self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called when the train batch begins.""" + pass + + def on_train_batch_end( + self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + """Called when the train batch ends.""" + pass + + def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: + """Called when the train epoch begins.""" + pass + + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + """Called when the train epoch ends.""" + pass + + def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: + """Called when the val epoch begins.""" + pass + + def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + """Called when the val epoch ends.""" pass - def on_epoch_end(self, trainer, pl_module): - """Called when the epoch ends.""" + def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: + """Called when the test epoch begins.""" pass - def on_batch_start(self, trainer, pl_module): + def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: + """Called when the test epoch ends.""" + pass + + def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: + """Called when either of train/val/test epoch begins.""" + pass + + def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: + """Called when either of train/val/test epoch ends.""" + pass + + def on_batch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch begins.""" pass - def on_validation_batch_start(self, trainer, pl_module): + def on_validation_batch_start( + self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: """Called when the validation batch begins.""" pass - def on_validation_batch_end(self, trainer, pl_module): + def on_validation_batch_end( + self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: """Called when the validation batch ends.""" pass - def on_test_batch_start(self, trainer, pl_module): + def on_test_batch_start( + self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: """Called when the test batch begins.""" pass - def on_test_batch_end(self, trainer, pl_module): + def on_test_batch_end( + self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: """Called when the test batch ends.""" pass - def on_batch_end(self, trainer, pl_module): + def on_batch_end(self, trainer, pl_module: LightningModule) -> None: """Called when the training batch ends.""" pass - def on_train_start(self, trainer, pl_module): + def on_train_start(self, trainer, pl_module: LightningModule) -> None: """Called when the train begins.""" pass - def on_train_end(self, trainer, pl_module): + def on_train_end(self, trainer, pl_module: LightningModule) -> None: """Called when the train ends.""" pass - def on_validation_start(self, trainer, pl_module): + def on_pretrain_routine_start(self, trainer, pl_module: LightningModule) -> None: + """Called when the pretrain routine begins.""" + pass + + def on_pretrain_routine_end(self, trainer, pl_module: LightningModule) -> None: + """Called when the pretrain routine ends.""" + pass + + def on_validation_start(self, trainer, pl_module: LightningModule) -> None: """Called when the validation loop begins.""" pass - def on_validation_end(self, trainer, pl_module): + def on_validation_end(self, trainer, pl_module: LightningModule) -> None: """Called when the validation loop ends.""" pass - def on_test_start(self, trainer, pl_module): + def on_test_start(self, trainer, pl_module: LightningModule) -> None: """Called when the test begins.""" pass - def on_test_end(self, trainer, pl_module): + def on_test_end(self, trainer, pl_module: LightningModule) -> None: """Called when the test ends.""" pass + + def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None: + """Called when the training is interrupted by ``KeyboardInterrupt``.""" + pass + + def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict: + """ + Called when saving a model checkpoint, use to persist state. + + Args: + trainer: the current Trainer instance. + pl_module: the current LightningModule instance. + checkpoint: the checkpoint dictionary that will be saved. + + Returns: + The callback state. + """ + pass + + def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + """Called when loading a model checkpoint, use to reload state. + + Args: + callback_state: the callback state returned by ``on_save_checkpoint``. + """ + pass + + def on_after_backward(self, trainer, pl_module: LightningModule) -> None: + """Called after ``loss.backward()`` and before optimizers do anything.""" + pass + + def on_before_zero_grad(self, trainer, pl_module: LightningModule, optimizer) -> None: + """Called after ``optimizer.step()`` and before ``optimizer.zero_grad()``.""" + pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 100c317172044e..24ebcdf8073573 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -1,93 +1,111 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. r""" Early Stopping -============== +^^^^^^^^^^^^^^ -Stop training when a monitored quantity has stopped improving. +Monitor a metric and stop training when it stops improving. """ +from typing import Any, Dict import numpy as np import torch -from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn - -torch_inf = torch.tensor(np.Inf) +from pytorch_lightning.utilities.exceptions import MisconfigurationException class EarlyStopping(Callback): r""" + Monitor a metric and stop training when it stops improving. Args: - monitor: quantity to be monitored. Default: ``'val_loss'``. - min_delta: minimum change in the monitored quantity - to qualify as an improvement, i.e. an absolute - change of less than `min_delta`, will count as no - improvement. Default: ``0``. - patience: number of epochs with no improvement - after which training will be stopped. Default: ``0``. - verbose: verbosity mode. Default: ``False``. - mode: one of {auto, min, max}. In `min` mode, - training will stop when the quantity - monitored has stopped decreasing; in `max` - mode it will stop when the quantity - monitored has stopped increasing; in `auto` - mode, the direction is automatically inferred - from the name of the monitored quantity. Default: ``'auto'``. - strict: whether to crash the training if `monitor` is - not found in the metrics. Default: ``True``. + monitor: quantity to be monitored. + min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute + change of less than `min_delta`, will count as no improvement. + patience: number of validation checks with no improvement + after which training will be stopped. Under the default configuration, one validation check happens after + every training epoch. However, the frequency of validation can be modified by setting various parameters on + the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``. + + .. note:: + + It must be noted that the patience parameter counts the number of validation checks with + no improvement, and not the number of training epochs. Therefore, with parameters + ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will perform at least 40 training + epochs before being stopped. + + verbose: verbosity mode. + mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the quantity + monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity + monitored has stopped increasing. + strict: whether to crash the training if `monitor` is not found in the validation metrics. + + Raises: + MisconfigurationException: + If ``mode`` is none of ``"min"`` or ``"max"``. + RuntimeError: + If the metric ``monitor`` is not available. Example:: >>> from pytorch_lightning import Trainer >>> from pytorch_lightning.callbacks import EarlyStopping >>> early_stopping = EarlyStopping('val_loss') - >>> trainer = Trainer(early_stop_callback=early_stopping) + >>> trainer = Trainer(callbacks=[early_stopping]) """ mode_dict = { 'min': torch.lt, 'max': torch.gt, } - def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3, - verbose: bool = False, mode: str = 'auto', strict: bool = True): + def __init__( + self, + monitor: str = 'early_stop_on', + min_delta: float = 0.0, + patience: int = 3, + verbose: bool = False, + mode: str = 'min', + strict: bool = True, + ): super().__init__() self.monitor = monitor self.patience = patience self.verbose = verbose self.strict = strict self.min_delta = min_delta - self.wait = 0 + self.wait_count = 0 self.stopped_epoch = 0 self.mode = mode - if mode not in self.mode_dict: - if self.verbose > 0: - log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') - self.mode = 'auto' - - if self.mode == 'auto': - if self.monitor == 'acc': - self.mode = 'max' - else: - self.mode = 'min' - if self.verbose > 0: - log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.') + if self.mode not in self.mode_dict: + raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + torch_inf = torch.tensor(np.Inf) + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf def _validate_condition_metric(self, logs): - """ - Checks that the condition metric for early stopping is good - :param logs: - :return: - """ monitor_val = logs.get(self.monitor) - error_msg = (f'Early stopping conditioned on metric `{self.monitor}`' - f' which is not available. Either add `{self.monitor}` to the return of ' - f' validation_epoch end or modify your EarlyStopping callback to use any of the ' - f'following: `{"`, `".join(list(logs.keys()))}`') + + error_msg = ( + f'Early stopping conditioned on metric `{self.monitor}` which is not available.' + ' Pass in or modify your `EarlyStopping` callback to use any of the following:' + f' `{"`, `".join(list(logs.keys()))}`' + ) if monitor_val is None: if self.strict: @@ -103,36 +121,54 @@ def _validate_condition_metric(self, logs): def monitor_op(self): return self.mode_dict[self.mode] - def on_train_start(self, trainer, pl_module): - # Allow instances to be re-used - self.wait = 0 - self.stopped_epoch = 0 - self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + return { + 'wait_count': self.wait_count, + 'stopped_epoch': self.stopped_epoch, + 'best_score': self.best_score, + 'patience': self.patience + } + + def on_load_checkpoint(self, callback_state: Dict[str, Any]): + self.wait_count = callback_state['wait_count'] + self.stopped_epoch = callback_state['stopped_epoch'] + self.best_score = callback_state['best_score'] + self.patience = callback_state['patience'] - def on_epoch_end(self, trainer, pl_module): + def on_validation_end(self, trainer, pl_module): + from pytorch_lightning.trainer.states import TrainerState + if trainer.state != TrainerState.FITTING or trainer.sanity_checking: + return + + self._run_early_stopping_check(trainer) + + def _run_early_stopping_check(self, trainer): + """ + Checks whether the early stopping condition is met + and if so tells the trainer to stop the training. + """ logs = trainer.callback_metrics - stop_training = False - if not self._validate_condition_metric(logs): - return stop_training + + if ( + trainer.fast_dev_run # disable early_stopping with fast_dev_run + or not self._validate_condition_metric(logs) # short circuit if metric not present + ): + return # short circuit if metric not present current = logs.get(self.monitor) - if not isinstance(current, torch.Tensor): - current = torch.tensor(current) - if self.monitor_op(current - self.min_delta, self.best): - self.best = current - self.wait = 0 + # when in dev debugging + trainer.dev_debugger.track_early_stopping_history(self, current) + + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 else: - self.wait += 1 - if self.wait >= self.patience: - self.stopped_epoch = trainer.current_epoch - stop_training = True - self.on_train_end(trainer, pl_module) + self.wait_count += 1 - return stop_training + if self.wait_count >= self.patience: + self.stopped_epoch = trainer.current_epoch + trainer.should_stop = True - def on_train_end(self, trainer, pl_module): - if self.stopped_epoch > 0 and self.verbose > 0: - rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,' - ' but will start from "0" in v0.8.0.', DeprecationWarning) - log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping') + # stop every ddp process if any world process decides to stop + trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py new file mode 100644 index 00000000000000..b25e5e06e8b86c --- /dev/null +++ b/pytorch_lightning/callbacks/finetuning.py @@ -0,0 +1,367 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Finetuning Callback +^^^^^^^^^^^^^^^^^^^^ +Freeze and unfreeze models for finetuning purposes +""" +import logging +from typing import Callable, Generator, Iterable, List, Optional, Union + +import torch +from torch.nn import Module +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential +from torch.optim.optimizer import Optimizer + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +log = logging.getLogger(__name__) + + +def multiplicative(epoch): + return 2 + + +class BaseFinetuning(Callback): + r""" + + This class implements the base logic for writing your own Finetuning Callback. + + Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic. + + ``freeze_before_training``: This method is called before ``configure_optimizers`` + and should be used to freeze any modules parameters. + + ``finetune_function``: This method is called on every train epoch start and should be used to + ``unfreeze`` any parameters. Those parameters needs to be added in a new ``param_group`` + within the optimizer. + + .. note:: Make sure to filter the parameters based on ``requires_grad``. + + Example:: + + class MyModel(LightningModule) + + ... + + def configure_optimizer(self): + # Make sure to filter the parameters based on `requires_grad` + return Adam(filter(lambda p: p.requires_grad, self.parameters)) + + class FeatureExtractorFreezeUnfreeze(BaseFinetuning): + + def __init__(self, unfreeze_at_epoch=10) + self._unfreeze_at_epoch = unfreeze_at_epoch + + def freeze_before_training(self, pl_module): + # freeze any module you want + # Here, we are freezing ``feature_extractor`` + self.freeze(pl_module.feature_extractor) + + def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): + # When `current_epoch` is 10, feature_extractor will start training. + if current_epoch == self._unfreeze_at_epoch: + self.unfreeze_and_add_param_group( + module=pl_module.feature_extractor, + optimizer=optimizer, + train_bn=True, + ) + """ + + @staticmethod + def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: + """ + This function is used to flatten a module or an iterable of modules into a list of its modules. + + Args: + modules: A given module or an iterable of modules + + Returns: + List of modules + """ + if isinstance(modules, Iterable): + _modules = [] + for m in modules: + _modules.extend(BaseFinetuning.flatten_modules(m)) + + else: + _modules = modules.modules() + + return list( + filter( + lambda m: not isinstance(m, (Container, Sequential, ModuleDict, ModuleList, LightningModule)), _modules + ) + ) + + @staticmethod + def filter_params( + modules: Union[Module, Iterable[Union[Module, Iterable]]], + train_bn: bool = True, + requires_grad: bool = True + ) -> Generator: + """Yields the `requires_grad` parameters of a given module or list of modules. + + Args: + modules: A given module or an iterable of modules + train_bn: Whether to train BatchNorm module + requires_grad: Whether to create a generator for trainable or non-trainable parameters. + + Returns: + Generator + """ + modules = BaseFinetuning.flatten_modules(modules) + for mod in modules: + if isinstance(mod, _BatchNorm) and not train_bn: + continue + for param in mod.parameters(): + if param.requires_grad == requires_grad: + yield param + + @staticmethod + def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None: + """ + Unfreezes the parameters of the provided modules + + Args: + modules: A given module or an iterable of modules + """ + modules = BaseFinetuning.flatten_modules(modules) + for module in modules: + for param in module.parameters(): + param.requires_grad = True + + @staticmethod + def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None: + """ + Freezes the parameters of the provided modules + + Args: + modules: A given module or an iterable of modules + train_bn: If True, leave the BatchNorm layers in training mode + + Returns: + None + """ + modules = BaseFinetuning.flatten_modules(modules) + for mod in modules: + if isinstance(mod, _BatchNorm) and train_bn: + BaseFinetuning.make_trainable(mod) + else: + for param in mod.parameters(): + param.requires_grad = False + + @staticmethod + def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: + """ + This function is used to exclude any parameter which already exists in + this optimizer + + Args: + optimizer: Optimizer used for parameter exclusion + params: Iterable of parameters used to check against the provided optimizer + + Returns: + List of parameters not contained in this optimizer param groups + """ + out_params = [] + removed_params = [] + for param in params: + if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]): + out_params.append(param) + else: + removed_params.append(param) + + if removed_params: + rank_zero_warn( + "The provided params to be freezed already exist within another group of this optimizer." + " Those parameters will be skipped.\n" + "HINT: Did you init your optimizer in `configure_optimizer` as such:\n" + f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ", UserWarning + ) + return out_params + + @staticmethod + def unfreeze_and_add_param_group( + modules: Union[Module, Iterable[Union[Module, Iterable]]], + optimizer: Optimizer, + lr: Optional[float] = None, + initial_denom_lr: float = 10., + train_bn: bool = True, + ) -> None: + """ + Unfreezes a module and adds its parameters to an optimizer. + + Args: + + modules: A module or iterable of modules to unfreeze. + Their parameters will be added to an optimizer as a new param group. + + optimizer: The provided optimizer will receive new parameters and will add them to + `add_param_group` + + lr: Learning rate for the new param group. + + initial_denom_lr: If no lr is provided, the learning from the first param group will be used + and divided by initial_denom_lr. + + train_bn: Whether to train the BatchNormalization layers. + + Returns: + None + """ + BaseFinetuning.make_trainable(modules) + params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr) + denom_lr = initial_denom_lr if lr is None else 1. + params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True) + params = BaseFinetuning.filter_on_optimizer(optimizer, params) + if params: + optimizer.add_param_group({ + 'params': params, + 'lr': params_lr / denom_lr, + }) + + def on_before_accelerator_backend_setup(self, trainer, pl_module): + self.freeze_before_training(pl_module) + + def on_train_epoch_start(self, trainer, pl_module): + """Called when the epoch begins.""" + for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): + self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) + + def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + """ + Override to add your unfreeze logic + """ + raise NotImplementedError + + def freeze_before_training(self, pl_module: LightningModule): + """ + Override to add your freeze logic + """ + raise NotImplementedError + + +class BackboneFinetuning(BaseFinetuning): + r""" + + Finetune a backbone model based on a learning rate user-defined scheduling. + When the backbone learning rate reaches the current model learning rate + and ``should_align`` is set to True, it will align with it for the rest of the training. + + Args: + + unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed. + + lambda_func: Scheduling function for increasing backbone learning rate. + + backbone_initial_ratio_lr: + Used to scale down the backbone learning rate compared to rest of model + + backbone_initial_lr: Optional, Inital learning rate for the backbone. + By default, we will use current_learning / backbone_initial_ratio_lr + + should_align: Wheter to align with current learning rate when backbone learning + reaches it. + + initial_denom_lr: When unfreezing the backbone, the intial learning rate will + current_learning_rate / initial_denom_lr. + + train_bn: Wheter to make Batch Normalization trainable. + + verbose: Display current learning rate for model and backbone + + round: Precision for displaying learning rate + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import BackboneFinetuning + >>> multiplicative = lambda epoch: 1.5 + >>> backbone_finetuning = BackboneFinetuning(200, multiplicative) + >>> trainer = Trainer(callbacks=[backbone_finetuning]) + + """ + + def __init__( + self, + unfreeze_backbone_at_epoch: int = 10, + lambda_func: Callable = multiplicative, + backbone_initial_ratio_lr: float = 10e-2, + backbone_initial_lr: Optional[float] = None, + should_align: bool = True, + initial_denom_lr: float = 10., + train_bn: bool = True, + verbose: bool = False, + round: int = 12, + ): + self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch + self.backbone_initial_lr = backbone_initial_lr + self.lambda_func = lambda_func + self.backbone_initial_ratio_lr = backbone_initial_ratio_lr + self.should_align = should_align + self.initial_denom_lr = initial_denom_lr + self.train_bn = train_bn + self.round = round + self.verbose = verbose + + def on_fit_start(self, trainer, pl_module): + """ + Raises: + MisconfigurationException: + If LightningModule has no nn.Module `backbone` attribute. + """ + if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module): + return + raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute") + + def freeze_before_training(self, pl_module: LightningModule): + self.freeze(pl_module.backbone) + + def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + """Called when the epoch begins.""" + + if epoch == self.unfreeze_backbone_at_epoch: + current_lr = optimizer.param_groups[0]['lr'] + initial_backbone_lr = self.backbone_initial_lr if self.backbone_initial_lr is not None \ + else current_lr * self.backbone_initial_ratio_lr + self.previous_backbone_lr = initial_backbone_lr + self.unfreeze_and_add_param_group( + pl_module.backbone, + optimizer, + initial_backbone_lr, + train_bn=self.train_bn, + initial_denom_lr=self.initial_denom_lr + ) + if self.verbose: + log.info( + f"Current lr: {round(current_lr, self.round)}, " + f"Backbone lr: {round(initial_backbone_lr, self.round)}" + ) + + elif epoch > self.unfreeze_backbone_at_epoch: + current_lr = optimizer.param_groups[0]['lr'] + next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr + next_current_backbone_lr = current_lr if (self.should_align and next_current_backbone_lr > current_lr) \ + else next_current_backbone_lr + optimizer.param_groups[-1]["lr"] = next_current_backbone_lr + self.previous_backbone_lr = next_current_backbone_lr + if self.verbose: + log.info( + f"Current lr: {round(current_lr, self.round)}, " + f"Backbone lr: {round(next_current_backbone_lr, self.round)}" + ) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py new file mode 100644 index 00000000000000..ace69b02348428 --- /dev/null +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -0,0 +1,214 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +GPU Stats Monitor +================= + +Monitor and logs GPU stats during training. + +""" + +import os +import shutil +import subprocess +import time +from typing import Dict, List, Tuple + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import DeviceType, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import AttributeDict + + +class GPUStatsMonitor(Callback): + r""" + Automatically monitors and logs GPU stats during training stage. ``GPUStatsMonitor`` + is a callback and in order to use it you need to assign a logger in the ``Trainer``. + + Args: + memory_utilization: Set to ``True`` to monitor used, free and percentage of memory + utilization at the start and end of each step. Default: ``True``. + gpu_utilization: Set to ``True`` to monitor percentage of GPU utilization + at the start and end of each step. Default: ``True``. + intra_step_time: Set to ``True`` to monitor the time of each step. Default: ``False``. + inter_step_time: Set to ``True`` to monitor the time between the end of one step + and the start of the next step. Default: ``False``. + fan_speed: Set to ``True`` to monitor percentage of fan speed. Default: ``False``. + temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius. + Default: ``False``. + + Raises: + MisconfigurationException: + If NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import GPUStatsMonitor + >>> gpu_stats = GPUStatsMonitor() # doctest: +SKIP + >>> trainer = Trainer(callbacks=[gpu_stats]) # doctest: +SKIP + + GPU stats are mainly based on `nvidia-smi --query-gpu` command. The description of the queries is as follows: + + - **fan.speed** – The fan speed value is the percent of maximum speed that the device's fan is currently + intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed. + If the fan is physically blocked and unable to spin, this output will not match the actual fan speed. + Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure. + - **memory.used** – Total memory allocated by active contexts. + - **memory.free** – Total free memory. + - **utilization.gpu** – Percent of time over the past sample period during which one or more kernels was + executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product. + - **utilization.memory** – Percent of time over the past sample period during which global (device) memory was + being read or written. The sample period may be between 1 second and 1/6 second depending on the product. + - **temperature.gpu** – Core GPU temperature, in degrees C. + - **temperature.memory** – HBM memory temperature, in degrees C. + + """ + + def __init__( + self, + memory_utilization: bool = True, + gpu_utilization: bool = True, + intra_step_time: bool = False, + inter_step_time: bool = False, + fan_speed: bool = False, + temperature: bool = False + ): + super().__init__() + + if shutil.which('nvidia-smi') is None: + raise MisconfigurationException( + 'Cannot use GPUStatsMonitor callback because NVIDIA driver is not installed.' + ) + + self._log_stats = AttributeDict({ + 'memory_utilization': memory_utilization, + 'gpu_utilization': gpu_utilization, + 'intra_step_time': intra_step_time, + 'inter_step_time': inter_step_time, + 'fan_speed': fan_speed, + 'temperature': temperature + }) + + def on_train_start(self, trainer, *args, **kwargs): + if not trainer.logger: + raise MisconfigurationException('Cannot use GPUStatsMonitor callback with Trainer that has no logger.') + + if trainer._device_type != DeviceType.GPU: + raise MisconfigurationException( + 'You are using GPUStatsMonitor but are not running on GPU' + f' since gpus attribute in Trainer is set to {trainer.gpus}.' + ) + + self._gpu_ids = ','.join(map(str, trainer.data_parallel_device_ids)) + + def on_train_epoch_start(self, *args, **kwargs): + self._snap_intra_step_time = None + self._snap_inter_step_time = None + + @rank_zero_only + def on_train_batch_start(self, trainer, *args, **kwargs): + if self._log_stats.intra_step_time: + self._snap_intra_step_time = time.time() + + if not self._should_log(trainer): + return + + gpu_stat_keys = self._get_gpu_stat_keys() + gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) + logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) + + if self._log_stats.inter_step_time and self._snap_inter_step_time: + # First log at beginning of second step + logs['batch_time/inter_step (ms)'] = (time.time() - self._snap_inter_step_time) * 1000 + + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_batch_end(self, trainer, *args, **kwargs): + if self._log_stats.inter_step_time: + self._snap_inter_step_time = time.time() + + if not self._should_log(trainer): + return + + gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys() + gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys]) + logs = self._parse_gpu_stats(self._gpu_ids, gpu_stats, gpu_stat_keys) + + if self._log_stats.intra_step_time and self._snap_intra_step_time: + logs['batch_time/intra_step (ms)'] = (time.time() - self._snap_intra_step_time) * 1000 + + trainer.logger.log_metrics(logs, step=trainer.global_step) + + def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]: + """Run nvidia-smi to get the gpu stats""" + gpu_query = ','.join(queries) + format = 'csv,nounits,noheader' + result = subprocess.run( + [shutil.which('nvidia-smi'), f'--query-gpu={gpu_query}', f'--format={format}', f'--id={self._gpu_ids}'], + encoding="utf-8", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + check=True + ) + + def _to_float(x: str) -> float: + try: + return float(x) + except ValueError: + return 0. + + stats = result.stdout.strip().split(os.linesep) + stats = [[_to_float(x) for x in s.split(', ')] for s in stats] + return stats + + @staticmethod + def _parse_gpu_stats(gpu_ids: str, stats: List[List[float]], keys: List[Tuple[str, str]]) -> Dict[str, float]: + """Parse the gpu stats into a loggable dict""" + logs = {} + for i, gpu_id in enumerate(gpu_ids.split(',')): + for j, (x, unit) in enumerate(keys): + logs[f'gpu_id: {gpu_id}/{x} ({unit})'] = stats[i][j] + return logs + + def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: + """Get the GPU stats keys""" + stat_keys = [] + + if self._log_stats.gpu_utilization: + stat_keys.append(('utilization.gpu', '%')) + + if self._log_stats.memory_utilization: + stat_keys.extend([('memory.used', 'MB'), ('memory.free', 'MB'), ('utilization.memory', '%')]) + + return stat_keys + + def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]: + """Get the device stats keys""" + stat_keys = [] + + if self._log_stats.fan_speed: + stat_keys.append(('fan.speed', '%')) + + if self._log_stats.temperature: + stat_keys.extend([('temperature.gpu', '°C'), ('temperature.memory', '°C')]) + + return stat_keys + + @staticmethod + def _should_log(trainer) -> bool: + should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) + + return should_log diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index bc1cd79e96f633..b1885087f4da0b 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -1,13 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. r""" Gradient Accumulator ==================== Change gradient accumulation factor according to scheduling. +Trainer also calls ``optimizer.step()`` for the last indivisible step number. """ +from typing import Dict + from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_warn class GradientAccumulationScheduler(Callback): @@ -17,9 +32,12 @@ class GradientAccumulationScheduler(Callback): Args: scheduling: scheduling in format {epoch: accumulation_factor} - .. warning:: - Epochs indexing starts from "1" until v0.6.x, - but will start from "0" in v0.8.0. + Raises: + TypeError: + If ``scheduling`` is an empty ``dict``, + or not all keys and values of ``scheduling`` are integers. + IndexError: + If ``minimal_epoch`` is less than 0. Example:: @@ -34,7 +52,7 @@ class GradientAccumulationScheduler(Callback): >>> trainer = Trainer(accumulate_grad_batches={5: 2}) """ - def __init__(self, scheduling: dict): + def __init__(self, scheduling: Dict[int, int]): super().__init__() if not scheduling: # empty dict error @@ -45,20 +63,19 @@ def __init__(self, scheduling: dict): raise TypeError("All epoches and accumulation factor must be integers") minimal_epoch = min(scheduling.keys()) - # rank_zero_warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,' - # ' but will start from "0" in v0.8.0.', DeprecationWarning) - if minimal_epoch < 1: + if minimal_epoch < 0: raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct") - if minimal_epoch != 1: # if user didnt define first epoch accumulation factor - scheduling.update({1: 1}) + if minimal_epoch != 0: # if user didnt define first epoch accumulation factor + scheduling.update({0: 1}) self.scheduling = scheduling self.epochs = sorted(scheduling.keys()) - def on_epoch_start(self, trainer, pl_module): - # indexing epochs from 1 (until v0.6.x) - # In v0.8.0, ` + 1` should be removed. - epoch = trainer.current_epoch + 1 + def going_to_accumulate_grad_batches(self): + return any([v > 1 for v in self.scheduling.values()]) + + def on_train_epoch_start(self, trainer, pl_module): + epoch = trainer.current_epoch for i in reversed(range(len(self.epochs))): if epoch >= self.epochs[i]: trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py new file mode 100644 index 00000000000000..a7485814b1b17d --- /dev/null +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -0,0 +1,160 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Lambda Callback +^^^^^^^^^^^^^^^ + +Create a simple callback on the fly using lambda functions. + +""" + +from typing import Callable, Optional + +from pytorch_lightning.callbacks.base import Callback + + +class LambdaCallback(Callback): + r""" + Create a simple callback on the fly using lambda functions. + + Args: + **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback` + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LambdaCallback + >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))]) + """ + + def __init__( + self, + on_before_accelerator_backend_setup: Optional[Callable] = None, + setup: Optional[Callable] = None, + on_configure_sharded_model: Optional[Callable] = None, + teardown: Optional[Callable] = None, + on_init_start: Optional[Callable] = None, + on_init_end: Optional[Callable] = None, + on_fit_start: Optional[Callable] = None, + on_fit_end: Optional[Callable] = None, + on_sanity_check_start: Optional[Callable] = None, + on_sanity_check_end: Optional[Callable] = None, + on_train_batch_start: Optional[Callable] = None, + on_train_batch_end: Optional[Callable] = None, + on_train_epoch_start: Optional[Callable] = None, + on_train_epoch_end: Optional[Callable] = None, + on_validation_epoch_start: Optional[Callable] = None, + on_validation_epoch_end: Optional[Callable] = None, + on_test_epoch_start: Optional[Callable] = None, + on_test_epoch_end: Optional[Callable] = None, + on_epoch_start: Optional[Callable] = None, + on_epoch_end: Optional[Callable] = None, + on_batch_start: Optional[Callable] = None, + on_validation_batch_start: Optional[Callable] = None, + on_validation_batch_end: Optional[Callable] = None, + on_test_batch_start: Optional[Callable] = None, + on_test_batch_end: Optional[Callable] = None, + on_batch_end: Optional[Callable] = None, + on_train_start: Optional[Callable] = None, + on_train_end: Optional[Callable] = None, + on_pretrain_routine_start: Optional[Callable] = None, + on_pretrain_routine_end: Optional[Callable] = None, + on_validation_start: Optional[Callable] = None, + on_validation_end: Optional[Callable] = None, + on_test_start: Optional[Callable] = None, + on_test_end: Optional[Callable] = None, + on_keyboard_interrupt: Optional[Callable] = None, + on_save_checkpoint: Optional[Callable] = None, + on_load_checkpoint: Optional[Callable] = None, + on_after_backward: Optional[Callable] = None, + on_before_zero_grad: Optional[Callable] = None, + ): + if on_before_accelerator_backend_setup is not None: + self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup + if setup is not None: + self.setup = setup + if on_configure_sharded_model is not None: + self.on_configure_sharded_model = on_configure_sharded_model + if teardown is not None: + self.teardown = teardown + if on_init_start is not None: + self.on_init_start = on_init_start + if on_init_end is not None: + self.on_init_end = on_init_end + if on_fit_start is not None: + self.on_fit_start = on_fit_start + if on_fit_end is not None: + self.on_fit_end = on_fit_end + if on_sanity_check_start is not None: + self.on_sanity_check_start = on_sanity_check_start + if on_sanity_check_end is not None: + self.on_sanity_check_end = on_sanity_check_end + if on_train_batch_start is not None: + self.on_train_batch_start = on_train_batch_start + if on_train_batch_end is not None: + self.on_train_batch_end = on_train_batch_end + if on_train_epoch_start is not None: + self.on_train_epoch_start = on_train_epoch_start + if on_train_epoch_end is not None: + self.on_train_epoch_end = on_train_epoch_end + if on_validation_epoch_start is not None: + self.on_validation_epoch_start = on_validation_epoch_start + if on_validation_epoch_end is not None: + self.on_validation_epoch_end = on_validation_epoch_end + if on_test_epoch_start is not None: + self.on_test_epoch_start = on_test_epoch_start + if on_test_epoch_end is not None: + self.on_test_epoch_end = on_test_epoch_end + if on_epoch_start is not None: + self.on_epoch_start = on_epoch_start + if on_epoch_end is not None: + self.on_epoch_end = on_epoch_end + if on_batch_start is not None: + self.on_batch_start = on_batch_start + if on_validation_batch_start is not None: + self.on_validation_batch_start = on_validation_batch_start + if on_validation_batch_end is not None: + self.on_validation_batch_end = on_validation_batch_end + if on_test_batch_start is not None: + self.on_test_batch_start = on_test_batch_start + if on_test_batch_end is not None: + self.on_test_batch_end = on_test_batch_end + if on_batch_end is not None: + self.on_batch_end = on_batch_end + if on_train_start is not None: + self.on_train_start = on_train_start + if on_train_end is not None: + self.on_train_end = on_train_end + if on_pretrain_routine_start is not None: + self.on_pretrain_routine_start = on_pretrain_routine_start + if on_pretrain_routine_end is not None: + self.on_pretrain_routine_end = on_pretrain_routine_end + if on_validation_start is not None: + self.on_validation_start = on_validation_start + if on_validation_end is not None: + self.on_validation_end = on_validation_end + if on_test_start is not None: + self.on_test_start = on_test_start + if on_test_end is not None: + self.on_test_end = on_test_end + if on_keyboard_interrupt is not None: + self.on_keyboard_interrupt = on_keyboard_interrupt + if on_save_checkpoint is not None: + self.on_save_checkpoint = on_save_checkpoint + if on_load_checkpoint is not None: + self.on_load_checkpoint = on_load_checkpoint + if on_after_backward is not None: + self.on_after_backward = on_after_backward + if on_before_zero_grad is not None: + self.on_before_zero_grad = on_before_zero_grad diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py deleted file mode 100755 index 6ad68905bc3417..00000000000000 --- a/pytorch_lightning/callbacks/lr_logger.py +++ /dev/null @@ -1,118 +0,0 @@ -r""" - -Logging of learning rates -========================= - -Log learning rate for lr schedulers during training - -""" - -from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class LearningRateLogger(Callback): - r""" - Automatically logs learning rate for learning rate schedulers during training. - - Example:: - - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.callbacks import LearningRateLogger - >>> lr_logger = LearningRateLogger() - >>> trainer = Trainer(callbacks=[lr_logger]) - - Logging names are automatically determined based on optimizer class name. - In case of multiple optimizers of same type, they will be named `Adam`, - `Adam-1` etc. If a optimizer has multiple parameter groups they will - be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a - `name` keyword in the construction of the learning rate schdulers - - Example:: - - def configure_optimizer(self): - optimizer = torch.optim.Adam(...) - lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...) - 'name': 'my_logging_name'} - return [optimizer], [lr_scheduler] - """ - def __init__(self): - self.lrs = None - self.lr_sch_names = [] - - def on_train_start(self, trainer, pl_module): - """ Called before training, determines unique names for all lr - schedulers in the case of multiple of the same type or in - the case of multiple parameter groups - """ - if trainer.lr_schedulers == []: - raise MisconfigurationException( - 'Cannot use LearningRateLogger callback with models that have no' - ' learning rate schedulers. Please see documentation for' - ' `configure_optimizers` method.') - - if not trainer.logger: - raise MisconfigurationException( - 'Cannot use LearningRateLogger callback with Trainer that has no logger.') - - # Find names for schedulers - names = self._find_names(trainer.lr_schedulers) - - # Initialize for storing values - self.lrs = dict.fromkeys(names, []) - - def on_batch_start(self, trainer, pl_module): - latest_stat = self._extract_lr(trainer, 'step') - if trainer.logger and latest_stat: - trainer.logger.log_metrics(latest_stat, step=trainer.global_step) - - def on_epoch_start(self, trainer, pl_module): - latest_stat = self._extract_lr(trainer, 'epoch') - if trainer.logger and latest_stat: - trainer.logger.log_metrics(latest_stat, step=trainer.global_step) - - def _extract_lr(self, trainer, interval): - """ Extracts learning rates for lr schedulers and saves information - into dict structure. """ - latest_stat = {} - for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): - if scheduler['interval'] == interval: - param_groups = scheduler['scheduler'].optimizer.param_groups - if len(param_groups) != 1: - for i, pg in enumerate(param_groups): - lr, key = pg['lr'], f'{name}/{i + 1}' - self.lrs[key].append(lr) - latest_stat[key] = lr - else: - self.lrs[name].append(param_groups[0]['lr']) - latest_stat[name] = param_groups[0]['lr'] - return latest_stat - - def _find_names(self, lr_schedulers): - # Create uniqe names in the case we have multiple of the same learning - # rate schduler + multiple parameter groups - names = [] - for scheduler in lr_schedulers: - sch = scheduler['scheduler'] - if 'name' in scheduler: - name = scheduler['name'] - else: - opt_name = 'lr-' + sch.optimizer.__class__.__name__ - i, name = 1, opt_name - # Multiple schduler of the same type - while True: - if name not in names: - break - i, name = i + 1, f'{opt_name}-{i}' - - # Multiple param groups for the same schduler - param_groups = sch.optimizer.param_groups - if len(param_groups) != 1: - for i, pg in enumerate(param_groups): - temp = name + '/pg' + str(i + 1) - names.append(temp) - else: - names.append(name) - - self.lr_sch_names.append(name) - return names diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py new file mode 100644 index 00000000000000..7530bfaa9d21e5 --- /dev/null +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -0,0 +1,207 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" + +Learning Rate Monitor +===================== + +Monitor and logs learning rate for lr schedulers during training. + +""" + +from typing import Dict, List, Optional + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class LearningRateMonitor(Callback): + r""" + Automatically monitor and logs learning rate for learning rate schedulers during training. + + Args: + logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers + at the same interval, set to ``None`` to log at individual interval + according to the ``interval`` key of each scheduler. Defaults to ``None``. + log_momentum: option to also log the momentum values of the optimizer, if the optimizer + has the ``momentum`` or ``betas`` attribute. Defaults to ``False``. + + Raises: + MisconfigurationException: + If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LearningRateMonitor + >>> lr_monitor = LearningRateMonitor(logging_interval='step') + >>> trainer = Trainer(callbacks=[lr_monitor]) + + Logging names are automatically determined based on optimizer class name. + In case of multiple optimizers of same type, they will be named ``Adam``, + ``Adam-1`` etc. If a optimizer has multiple parameter groups they will + be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a + ``name`` keyword in the construction of the learning rate schdulers + + Example:: + + def configure_optimizer(self): + optimizer = torch.optim.Adam(...) + lr_scheduler = { + 'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...) + 'name': 'my_logging_name' + } + return [optimizer], [lr_scheduler] + + """ + + def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False): + if logging_interval not in (None, 'step', 'epoch'): + raise MisconfigurationException('logging_interval should be `step` or `epoch` or `None`.') + + self.logging_interval = logging_interval + self.log_momentum = log_momentum + self.lrs = None + self.lr_sch_names = [] + + def on_train_start(self, trainer, *args, **kwargs): + """ + Called before training, determines unique names for all lr + schedulers in the case of multiple of the same type or in + the case of multiple parameter groups + + Raises: + MisconfigurationException: + If ``Trainer`` has no ``logger``. + """ + if not trainer.logger: + raise MisconfigurationException( + 'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.' + ) + + if not trainer.lr_schedulers: + rank_zero_warn( + 'You are using `LearningRateMonitor` callback with models that' + ' have no learning rate schedulers. Please see documentation' + ' for `configure_optimizers` method.', RuntimeWarning + ) + + if self.log_momentum: + + def _check_no_key(key): + return any(key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers) + + if _check_no_key('momentum') and _check_no_key('betas'): + rank_zero_warn( + "You have set log_momentum=True, but some optimizers do not" + " have momentum. This will log a value 0 for the momentum.", RuntimeWarning + ) + + # Find names for schedulers + names = self._find_names(trainer.lr_schedulers) + + # Initialize for storing values + self.lrs = {name: [] for name in names} + self.last_momentum_values = {name + "-momentum": None for name in names} + + def on_train_batch_start(self, trainer, *args, **kwargs): + if not self._should_log(trainer): + return + + if self.logging_interval != 'epoch': + interval = 'step' if self.logging_interval is None else 'any' + latest_stat = self._extract_stats(trainer, interval) + + if latest_stat: + trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + + def on_train_epoch_start(self, trainer, *args, **kwargs): + if self.logging_interval != 'step': + interval = 'epoch' if self.logging_interval is None else 'any' + latest_stat = self._extract_stats(trainer, interval) + + if latest_stat: + trainer.logger.log_metrics(latest_stat, step=trainer.global_step) + + def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: + latest_stat = {} + + for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): + if scheduler['interval'] == interval or interval == 'any': + opt = scheduler['scheduler'].optimizer + param_groups = opt.param_groups + use_betas = 'betas' in opt.defaults + + for i, pg in enumerate(param_groups): + suffix = f'/pg{i + 1}' if len(param_groups) > 1 else '' + lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}') + latest_stat.update(lr) + momentum = self._extract_momentum( + param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas + ) + latest_stat.update(momentum) + + return latest_stat + + def _extract_lr(self, param_group, name: str) -> Dict[str, float]: + lr = param_group.get('lr') + self.lrs[name].append(lr) + return {name: lr} + + def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]: + if not self.log_momentum: + return {} + + momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0) + self.last_momentum_values[name] = momentum + return {name: momentum} + + def _find_names(self, lr_schedulers) -> List[str]: + # Create uniqe names in the case we have multiple of the same learning + # rate schduler + multiple parameter groups + names = [] + for scheduler in lr_schedulers: + sch = scheduler['scheduler'] + if scheduler['name'] is not None: + name = scheduler['name'] + else: + opt_name = 'lr-' + sch.optimizer.__class__.__name__ + i, name = 1, opt_name + + # Multiple schduler of the same type + while True: + if name not in names: + break + i, name = i + 1, f'{opt_name}-{i}' + + # Multiple param groups for the same schduler + param_groups = sch.optimizer.param_groups + + if len(param_groups) != 1: + for i, pg in enumerate(param_groups): + temp = f'{name}/pg{i + 1}' + names.append(temp) + else: + names.append(name) + + self.lr_sch_names.append(name) + + return names + + @staticmethod + def _should_log(trainer) -> bool: + should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop) + + return should_log diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 71f893e080fdf3..5f0318e7ac8d1f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Model Checkpointing =================== @@ -5,265 +18,713 @@ Automatically save model checkpoints during training. """ - +import logging import os import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Union import numpy as np -from typing import Optional - import torch -from pytorch_lightning import _logger as log +import yaml + from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import WarningCache + +log = logging.getLogger(__name__) +warning_cache = WarningCache() class ModelCheckpoint(Callback): r""" - Save the model after every epoch. + Save the model after every epoch by monitoring a quantity. + + After training finishes, use :attr:`best_model_path` to retrieve the path to the + best checkpoint file and :attr:`best_model_score` to retrieve its score. Args: - filepath: path to save the model file. - Can contain named formatting options to be auto-filled. + dirpath: directory to save the model file. Example:: # custom path - # saves a file like: my/path/epoch_0.ckpt - >>> checkpoint_callback = ModelCheckpoint('my/path/') + # saves a file like: my/path/epoch=0-step=10.ckpt + >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + + By default, dirpath is ``None`` and will be set at runtime to the location + specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments, + and if the Trainer uses a logger, the path will also contain logger name and version. + + filename: checkpoint filename. Can contain named formatting options to be auto-filled. + + Example:: # save any arbitrary metrics like `val_loss`, etc. in name - # saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt + # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt >>> checkpoint_callback = ModelCheckpoint( - ... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}' + ... dirpath='my/path', + ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}' ... ) - Can also be set to `None`, then it will be set to default location - during trainer construction. - - monitor: quantity to monitor. + By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``. + monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch. verbose: verbosity mode. Default: ``False``. - save_top_k: if `save_top_k == k`, + save_last: When ``True``, always saves the model at the end of the epoch to + a file `last.ckpt`. Default: ``None``. + save_top_k: if ``save_top_k == k``, the best k models according to the quantity monitored will be saved. if ``save_top_k == 0``, no models are saved. if ``save_top_k == -1``, all models are saved. - Please note that the monitors are checked every `period` epochs. + Please note that the monitors are checked every ``period`` epochs. if ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, the name of the saved file will be - appended with a version count starting with `v0`. - mode: one of {auto, min, max}. - If ``save_top_k != 0``, the decision - to overwrite the current save file is made - based on either the maximization or the - minimization of the monitored quantity. For `val_acc`, - this should be `max`, for `val_loss` this should - be `min`, etc. In `auto` mode, the direction is - automatically inferred from the name of the monitored quantity. + appended with a version count starting with ``v1``. + mode: one of {min, max}. + If ``save_top_k != 0``, the decision to overwrite the current save file is made + based on either the maximization or the minimization of the monitored quantity. + For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc. save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). + every_n_train_steps: Number of training steps between checkpoints. + If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training + To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative. + This must be mutually exclusive with ``every_n_val_epochs``. + every_n_val_epochs: Number of validation epochs between checkpoints. + If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end + To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_train_steps``. + Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and + ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` + will only save checkpoints at epochs 0 < E <= N + where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E. period: Interval (number of epochs) between checkpoints. + .. warning:: + This argument has been deprecated in v1.3 and will be removed in v1.5. + + Use ``every_n_val_epochs`` instead. + + Note: + For extra customization, ModelCheckpoint includes the following attributes: + + - ``CHECKPOINT_JOIN_CHAR = "-"`` + - ``CHECKPOINT_NAME_LAST = "last"`` + - ``FILE_EXTENSION = ".ckpt"`` + - ``STARTING_VERSION = 1`` + + For example, you can change the default last checkpoint name by doing + ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"`` + + Raises: + MisconfigurationException: + If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``, + if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or + if ``mode`` is none of ``"min"`` or ``"max"``. + ValueError: + If ``trainer.save_checkpoint`` is ``None``. + Example:: >>> from pytorch_lightning import Trainer >>> from pytorch_lightning.callbacks import ModelCheckpoint - # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min - >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/') - >>> trainer = Trainer(checkpoint_callback=checkpoint_callback) + # saves checkpoints to 'my/path/' at every epoch + >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + >>> trainer = Trainer(callbacks=[checkpoint_callback]) # save epoch and val_loss in name - # saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt + # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt >>> checkpoint_callback = ModelCheckpoint( - ... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}' + ... monitor='val_loss', + ... dirpath='my/path/', + ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' ... ) + # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard + # or Neptune, due to the presence of characters like '=' or '/') + # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt + >>> checkpoint_callback = ModelCheckpoint( + ... monitor='val/loss', + ... dirpath='my/path/', + ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', + ... auto_insert_metric_name=False + ... ) + + # retrieve the best checkpoint after training + checkpoint_callback = ModelCheckpoint(dirpath='my/path/') + trainer = Trainer(callbacks=[checkpoint_callback]) + model = ... + trainer.fit(model) + checkpoint_callback.best_model_path + """ - def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False, - save_top_k: int = 1, save_weights_only: bool = False, - mode: str = 'auto', period: int = 1, prefix: str = ''): + CHECKPOINT_JOIN_CHAR = "-" + CHECKPOINT_NAME_LAST = "last" + FILE_EXTENSION = ".ckpt" + STARTING_VERSION = 1 + + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + monitor: Optional[str] = None, + verbose: bool = False, + save_last: Optional[bool] = None, + save_top_k: Optional[int] = None, + save_weights_only: bool = False, + mode: str = "min", + auto_insert_metric_name: bool = True, + every_n_train_steps: Optional[int] = None, + every_n_val_epochs: Optional[int] = None, + period: Optional[int] = None, + ): super().__init__() - if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: - rank_zero_warn( - f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." - "All files in this directory will be deleted when a checkpoint is saved!" - ) - self._rank = 0 - self.monitor = monitor self.verbose = verbose - if filepath is None: # will be determined by trainer at runtime - self.dirpath, self.filename = None, None - else: - if os.path.isdir(filepath): - self.dirpath, self.filename = filepath, '{epoch}' - else: - self.dirpath, self.filename = os.path.split(filepath) - os.makedirs(self.dirpath, exist_ok=True) + self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = period - self.epoch_last_check = None - self.prefix = prefix + self.auto_insert_metric_name = auto_insert_metric_name + self._last_global_step_saved = -1 + self.current_score = None self.best_k_models = {} - # {filename: monitor} - self.kth_best_model = '' - self.best = 0 + self.kth_best_model_path = "" + self.best_model_score = None + self.best_model_path = "" + self.last_model_path = "" self.save_function = None + self.__init_monitor_mode(monitor, mode) + self.__init_ckpt_dir(dirpath, filename, save_top_k) + self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) + self.__validate_init_configuration() + + def on_pretrain_routine_start(self, trainer, pl_module): + """ + When pretrain routine starts we build the ckpt dir on the fly + """ + self.__resolve_ckpt_dir(trainer) + self.save_function = trainer.save_checkpoint + + def on_train_batch_end(self, trainer, *args, **kwargs) -> None: + """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ + if self._should_skip_saving_checkpoint(trainer): + return + step = trainer.global_step + skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) + if skip_batch: + return + self.save_checkpoint(trainer) + + def on_validation_end(self, trainer, *args, **kwargs) -> None: + """ + checkpoints can be saved at the end of the val loop + """ + skip = ( + self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 + or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 + ) + if skip: + return + self.save_checkpoint(trainer) + + def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + return { + "monitor": self.monitor, + "best_model_score": self.best_model_score, + "best_model_path": self.best_model_path, + "current_score": self.current_score, + "dirpath": self.dirpath + } + + def on_load_checkpoint(self, callback_state: Dict[str, Any]): + self.best_model_score = callback_state["best_model_score"] + self.best_model_path = callback_state["best_model_path"] + + def save_checkpoint(self, trainer, unused: Optional = None): + """ + Performs the main logic around saving a checkpoint. + This method runs on all ranks, it is the responsibility of `self.save_function` + to handle correct behaviour in distributed training, i.e., saving only on rank 0. + """ + if unused is not None: + rank_zero_deprecation( + "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter" + " has been removed. Support for the old signature will be removed in v1.5" + ) + + global_step = trainer.global_step + + self._add_backward_monitor_support(trainer) + self._validate_monitor_key(trainer) + + # track epoch when ckpt was last checked + self._last_global_step_saved = global_step + + # what can be monitored + monitor_candidates = self._monitor_candidates(trainer) + + # callback supports multiple simultaneous modes + # here we call each mode sequentially + # Mode 1: save the top k checkpoints + self._save_top_k_checkpoint(trainer, monitor_candidates) + # Mode 2: save monitor=None checkpoints + self._save_none_monitor_checkpoint(trainer, monitor_candidates) + # Mode 3: save last checkpoints + self._save_last_checkpoint(trainer, monitor_candidates) + + def _should_skip_saving_checkpoint(self, trainer) -> bool: + from pytorch_lightning.trainer.states import TrainerState + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check + or self._last_global_step_saved == trainer.global_step # already saved at the last step + ) + + def __validate_init_configuration(self): + if self.save_top_k is not None and self.save_top_k < -1: + raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self._every_n_train_steps < 0: + raise MisconfigurationException( + f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0' + ) + if self._every_n_val_epochs < 0: + raise MisconfigurationException( + f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0' + ) + if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0: + raise MisconfigurationException( + f'Invalid values for every_n_train_steps={self._every_n_train_steps}' + ' and every_n_val_epochs={self._every_n_val_epochs}.' + ' Both cannot be enabled at the same time.' + ) + if self.monitor is None: + # None: save last epoch, -1: save all epochs, 0: nothing is saved + if self.save_top_k not in (None, -1, 0): + raise MisconfigurationException( + f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid' + ' configuration. No quantity for top_k to track.' + ) + if self.save_last: + rank_zero_warn( + 'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.' + ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).' + ) + if self.save_top_k == -1 and self.save_last: + rank_zero_info( + 'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)' + ' will duplicate the last checkpoint saved.' + ) + + def __init_ckpt_dir(self, dirpath, filename, save_top_k): + + self._fs = get_filesystem(str(dirpath) if dirpath else '') + + if ( + save_top_k is not None and save_top_k > 0 and dirpath is not None and self._fs.isdir(dirpath) + and len(self._fs.ls(dirpath)) > 0 + ): + rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") + + if dirpath and self._fs.protocol == 'file': + dirpath = os.path.realpath(dirpath) + + self.dirpath: Union[str, None] = dirpath or None + self.filename = filename or None + + def __init_monitor_mode(self, monitor, mode): torch_inf = torch.tensor(np.Inf) mode_dict = { - 'min': (torch_inf, 'min'), - 'max': (-torch_inf, 'max'), - 'auto': (-torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') - else (torch_inf, 'min'), + "min": (torch_inf, "min"), + "max": (-torch_inf, "max"), } if mode not in mode_dict: - rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, ' - f'fallback to auto mode.', RuntimeWarning) - mode = 'auto' + raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}") self.kth_value, self.mode = mode_dict[mode] - def _del_model(self, filepath): - if os.path.isfile(filepath): - os.remove(filepath) + def __init_triggers( + self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + ) -> None: + + # Default to running once after each validation epoch if neither + # every_n_train_steps nor every_n_val_epochs is set + if every_n_train_steps is None and every_n_val_epochs is None: + self._every_n_val_epochs = 1 + self._every_n_train_steps = 0 + log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + else: + self._every_n_val_epochs = every_n_val_epochs or 0 + self._every_n_train_steps = every_n_train_steps or 0 + + # period takes precedence over every_n_val_epochs for backwards compatibility + if period is not None: + rank_zero_deprecation( + 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.' + ) + self._every_n_val_epochs = period + + self._period = self._every_n_val_epochs + + @property + def period(self) -> Optional[int]: + rank_zero_deprecation( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.' + ) + return self._period + + @period.setter + def period(self, value: Optional[int]) -> None: + rank_zero_deprecation( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.' + ) + self._period = value + + @rank_zero_only + def _del_model(self, filepath: str): + if self._fs.exists(filepath): + self._fs.rm(filepath) + log.debug(f"Removed checkpoint: {filepath}") + + def _save_model(self, trainer, filepath: str): + if trainer.training_type_plugin.rpc_enabled: + # RPCPlugin manages saving all model states + # TODO: the rpc plugin should wrap trainer.save_checkpoint + # instead of us having to do it here manually + trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath) + else: + self._do_save(trainer, filepath) + + def _do_save(self, trainer, filepath: str): + # in debugging, track when we save checkpoints + trainer.dev_debugger.track_checkpointing_history(filepath) - def _save_model(self, filepath): # make paths - os.makedirs(os.path.dirname(filepath), exist_ok=True) + if trainer.is_global_zero: + self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) - # delegate the saving to the model + # delegate the saving to the trainer if self.save_function is not None: - self.save_function(filepath) + self.save_function(filepath, self.save_weights_only) else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, current): + def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: + if current is None: + return False + + if self.save_top_k == -1: + return True + less_than_k_models = len(self.best_k_models) < self.save_top_k if less_than_k_models: return True if not isinstance(current, torch.Tensor): + rank_zero_warn( + f"{current} is supposed to be a `torch.Tensor`. Saving checkpoint may not work correctly." + f" HINT: check the value of {self.monitor} in your validation loop", + RuntimeWarning, + ) current = torch.tensor(current) - monitor_op = { - "min": torch.lt, - "max": torch.gt, - }[self.mode] + monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] + should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) + + # If using multiple devices, make sure all processes are unanimous on the decision. + should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + + return should_update_best_and_save + + @classmethod + def _format_checkpoint_name( + cls, + filename: Optional[str], + epoch: int, + step: int, + metrics: Dict[str, Any], + prefix: str = "", + auto_insert_metric_name: bool = True + ) -> str: + if not filename: + # filename is not set, use default name + filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}" + + # check and parse user passed keys in the string + groups = re.findall(r"(\{.*?)[:\}]", filename) + if len(groups) >= 0: + metrics.update({"epoch": epoch, 'step': step}) + for group in groups: + name = group[1:] + + if auto_insert_metric_name: + filename = filename.replace(group, name + "={" + name) + + if name not in metrics: + metrics[name] = 0 + filename = filename.format(**metrics) + + if prefix: + filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) - return monitor_op(current, self.best_k_models[self.kth_best_model]) + return filename - def format_checkpoint_name(self, epoch, metrics, ver=None): + def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None) -> str: """Generate a filename according to the defined template. Example:: >>> tmpdir = os.path.dirname(__file__) - >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}')) - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}') + >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={})) 'epoch=0.ckpt' - >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}')) - >>> os.path.basename(ckpt.format_checkpoint_name(5, {})) + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}') + >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={})) 'epoch=005.ckpt' - >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}')) - >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456))) + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) 'epoch=2-val_loss=0.12.ckpt' - >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}')) - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, + ... filename='epoch={epoch}-validation_loss={val_loss:.2f}', + ... auto_insert_metric_name=False) + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) + 'epoch=2-validation_loss=0.12.ckpt' + >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') + >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' + >>> ckpt = ModelCheckpoint(filename='{step}') + >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {})) + 'step=0.ckpt' + + """ + filename = self._format_checkpoint_name( + self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name + ) + + if ver is not None: + filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) + + ckpt_name = f"{filename}{self.FILE_EXTENSION}" + return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name + + def __resolve_ckpt_dir(self, trainer): """ - # check if user passed in keys to the string - groups = re.findall(r'(\{.*?)[:\}]', self.filename) + Determines model checkpoint save directory at runtime. References attributes from the + trainer's logger to determine where to save checkpoints. + The base path for saving weights is set in this priority: - if len(groups) == 0: - # default name - filename = f'{self.prefix}_ckpt_epoch_{epoch}' + 1. Checkpoint callback's path (if passed in) + 2. The default_root_dir from trainer if trainer has no logger + 3. The weights_save_path from trainer, if user provides it + 4. User provided weights_saved_path + + The base path gets extended with logger name and version (if these are available) + and subfolder "checkpoints". + """ + # Todo: required argument `pl_module` is not used + if self.dirpath is not None: + return # short circuit + + if trainer.logger is not None: + if trainer.weights_save_path != trainer.default_root_dir: + # the user has changed weights_save_path, it overrides anything + save_dir = trainer.weights_save_path + else: + save_dir = trainer.logger.save_dir or trainer.default_root_dir + + version = ( + trainer.logger.version + if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}" + ) + + version, name = trainer.training_type_plugin.broadcast((version, trainer.logger.name)) + + ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") else: - metrics['epoch'] = epoch - filename = self.filename - for tmp in groups: - name = tmp[1:] - filename = filename.replace(tmp, name + '={' + name) - if name not in metrics: - metrics[name] = 0 - filename = filename.format(**metrics) - str_ver = f'_v{ver}' if ver is not None else '' - filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt') + ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") + + self.dirpath = ckpt_path + + if not trainer.fast_dev_run and trainer.is_global_zero: + self._fs.makedirs(self.dirpath, exist_ok=True) + + def _add_backward_monitor_support(self, trainer): + metrics = trainer.logger_connector.callback_metrics + deprecation_warning = False + + if self.monitor is None and 'val_loss' in metrics: + self.monitor = 'val_loss' + deprecation_warning = True + + if self.save_top_k is None and self.monitor is not None: + # TODO: Remove `Optional` from `save_top_k` when this is deleted in v1.4 + self.save_top_k = 1 + + if deprecation_warning: + warning_cache.warn( + "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2" + " and will be removed in v1.4. Please, create your own `mc = ModelCheckpoint(monitor='your_monitor')`" + " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning + ) + + def _validate_monitor_key(self, trainer): + metrics = trainer.logger_connector.callback_metrics + + # validate metric + if self.monitor is not None and not self._is_valid_monitor_key(metrics): + m = ( + f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" + f" {list(metrics.keys())}. " + f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?" + ) + raise MisconfigurationException(m) + + def _get_metric_interpolated_filepath_name( + self, + monitor_candidates: Dict[str, Any], + epoch: int, + step: int, + trainer, + del_filepath: Optional[str] = None, + ) -> str: + filepath = self.format_checkpoint_name(epoch, step, monitor_candidates) + + version_cnt = self.STARTING_VERSION + while self.file_exists(filepath, trainer) and filepath != del_filepath: + filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version_cnt) + version_cnt += 1 + return filepath - @rank_zero_only - def on_validation_end(self, trainer, pl_module): - # only run on main process - if trainer.proc_rank != 0: + def _monitor_candidates(self, trainer): + monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics) + monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch) + return monitor_candidates + + def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): + if not self.save_last: return - metrics = trainer.callback_metrics - epoch = trainer.current_epoch - if self.save_top_k == 0: - # no models are saved + filepath = self._format_checkpoint_name( + self.CHECKPOINT_NAME_LAST, + trainer.current_epoch, + trainer.global_step, + monitor_candidates, + ) + filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}") + + self._save_model(trainer, filepath) + + if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero: + self._del_model(self.last_model_path) + + self.last_model_path = filepath + + def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): + if self.monitor is None or self.save_top_k == 0: return - if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period: - # skipping in this term + + current = monitor_candidates.get(self.monitor) + epoch = monitor_candidates.get("epoch") + step = monitor_candidates.get("step") + + if self.check_monitor_top_k(trainer, current): + self._update_best_and_save(current, epoch, step, trainer, monitor_candidates) + elif self.verbose: + rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}") + + def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]): + if self.monitor is not None or self.save_top_k == 0: return - self.epoch_last_check = epoch + filepath = self._get_metric_interpolated_filepath_name( + monitor_candidates, + trainer.current_epoch, + trainer.global_step, + trainer, + ) + self._save_model(trainer, filepath) - filepath = self.format_checkpoint_name(epoch, metrics) - version_cnt = 0 - while os.path.isfile(filepath): - filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt) - # this epoch called before - version_cnt += 1 + if ( + self.save_top_k is None and self.best_model_path and self.best_model_path != filepath + and trainer.is_global_zero + ): + self._del_model(self.best_model_path) - if self.save_top_k != -1: - current = metrics.get(self.monitor) + self.best_model_path = filepath - if current is None: - rank_zero_warn( - f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning - ) - elif self.check_monitor_top_k(current): - self._do_check_save(filepath, current, epoch) - elif self.verbose > 0: - log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}') + def _is_valid_monitor_key(self, metrics): + return self.monitor in metrics or len(metrics) == 0 - else: - if self.verbose > 0: - log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') - self._save_model(filepath) + def _update_best_and_save( + self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any] + ): + k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k + + del_filepath = None + if len(self.best_k_models) == k and k > 0: + del_filepath = self.kth_best_model_path + self.best_k_models.pop(del_filepath) - def _do_check_save(self, filepath, current, epoch): - # remove kth + # do not save nan, replace with +/- inf + if isinstance(current, torch.Tensor) and torch.isnan(current): + current = torch.tensor(float('inf' if self.mode == "min" else '-inf')) - del_list = [] - if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0: - delpath = self.kth_best_model - self.best_k_models.pop(self.kth_best_model) - del_list.append(delpath) + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, step, trainer, del_filepath) + # save the current score + self.current_score = current self.best_k_models[filepath] = current - if len(self.best_k_models) == self.save_top_k: + + if len(self.best_k_models) == k: # monitor dict has reached k elements - _op = max if self.mode == 'min' else min - self.kth_best_model = _op(self.best_k_models, - key=self.best_k_models.get) - self.kth_value = self.best_k_models[self.kth_best_model] - - _op = min if self.mode == 'min' else max - self.best = _op(self.best_k_models.values()) - - if self.verbose > 0: - log.info( - f'\nEpoch {epoch:05d}: {self.monitor} reached' - f' {current:0.5f} (best {self.best:0.5f}), saving model to' - f' {filepath} as top {self.save_top_k}') - self._save_model(filepath) - - for cur_path in del_list: - if cur_path != filepath: - self._del_model(cur_path) + _op = max if self.mode == "min" else min + self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.kth_value = self.best_k_models[self.kth_best_model_path] + + _op = min if self.mode == "min" else max + self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) + self.best_model_score = self.best_k_models[self.best_model_path] + + if self.verbose: + rank_zero_info( + f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" + f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' + ) + self._save_model(trainer, filepath) + + if del_filepath is not None and filepath != del_filepath: + self._del_model(del_filepath) + + def to_yaml(self, filepath: Optional[Union[str, Path]] = None): + """ + Saves the `best_k_models` dict containing the checkpoint + paths with the corresponding scores to a YAML file. + """ + best_k = {k: v.item() for k, v in self.best_k_models.items()} + if filepath is None: + filepath = os.path.join(self.dirpath, "best_k_models.yaml") + with self._fs.open(filepath, "w") as fp: + yaml.dump(best_k, fp) + + def file_exists(self, filepath: Union[str, Path], trainer) -> bool: + """ + Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing + the internal state to diverge between ranks. + """ + exists = self._fs.exists(filepath) + return trainer.training_type_plugin.broadcast(exists) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index a770c6c9d95e7d..7dc4202530d04a 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Progress Bars ============= @@ -5,12 +18,46 @@ Use or override one of the progress bar callbacks. """ +import importlib +import io +import os import sys -from tqdm.auto import tqdm +# check if ipywidgets is installed before importing tqdm.auto +# to ensure it won't fail and a progress bar is displayed +from typing import Optional, Union + +if importlib.util.find_spec('ipywidgets') is not None: + from tqdm.auto import tqdm as _tqdm +else: + from tqdm import tqdm as _tqdm from pytorch_lightning.callbacks import Callback +_PAD_SIZE = 5 + + +class tqdm(_tqdm): + """ + Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering + """ + + @staticmethod + def format_num(n) -> str: + """ Add additional padding to the formatted numbers """ + should_be_padded = isinstance(n, (float, str)) + if not isinstance(n, str): + n = _tqdm.format_num(n) + if should_be_padded and 'e' not in n: + if '.' not in n and len(n) < _PAD_SIZE: + try: + _ = float(n) + except ValueError: + return n + n += '.' + n += "0" * (_PAD_SIZE - len(n)) + return n + class ProgressBarBase(Callback): r""" @@ -29,8 +76,8 @@ def __init__(self): def disable(self): self.enable = False - def on_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) # don't forget this :) + def on_train_batch_end(self, trainer, pl_module, outputs): + super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :) percent = (self.train_batch_idx / self.total_train_batches) * 100 sys.stdout.flush() sys.stdout.write(f'{percent:.01f} percent complete \r') @@ -39,12 +86,14 @@ def on_batch_end(self, trainer, pl_module): trainer = Trainer(callbacks=[bar]) """ + def __init__(self): self._trainer = None self._train_batch_idx = 0 self._val_batch_idx = 0 self._test_batch_idx = 0 + self._predict_batch_idx = 0 @property def trainer(self): @@ -74,6 +123,14 @@ def test_batch_idx(self) -> int: """ return self._test_batch_idx + @property + def predict_batch_idx(self) -> int: + """ + The current batch index being processed during predicting. + Use this to update your progress bar. + """ + return self._predict_batch_idx + @property def total_train_batches(self) -> int: """ @@ -81,37 +138,38 @@ def total_train_batches(self) -> int: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training dataloader is of infinite size. """ - total_train_batches = 1 if self.trainer.fast_dev_run else self.trainer.num_training_batches - return total_train_batches + return self.trainer.num_training_batches @property def total_val_batches(self) -> int: """ - The total number of training batches during validation, which may change from epoch to epoch. + The total number of validation batches during validation, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. """ - trainer = self.trainer total_val_batches = 0 - if trainer.fast_dev_run: - total_val_batches = len(trainer.val_dataloaders) - elif not self.trainer.disable_validation: - is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0 - total_val_batches = trainer.num_val_batches if is_val_epoch else 0 + if not self.trainer.disable_validation: + is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0 + total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 return total_val_batches @property def total_test_batches(self) -> int: """ - The total number of training batches during testing, which may change from epoch to epoch. + The total number of testing batches during testing, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """ - if self.trainer.fast_dev_run: - total_test_batches = len(self.trainer.test_dataloaders) - else: - total_test_batches = self.trainer.num_test_batches - return total_test_batches + return sum(self.trainer.num_test_batches) + + @property + def total_predict_batches(self) -> int: + """ + The total number of predicting batches during testing, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + predict dataloader is of infinite size. + """ + return sum(self.trainer.num_predict_batches) def disable(self): """ @@ -125,35 +183,47 @@ def enable(self): """ You should provide a way to enable the progress bar. The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training - routines like the `learning rate finder `_ to temporarily enable and - disable the main progress bar. + routines like the :ref:`learning rate finder ` + to temporarily enable and disable the main progress bar. """ raise NotImplementedError + def print(self, *args, **kwargs): + """ + You should provide a way to print without breaking the progress bar. + """ + print(*args, **kwargs) + def on_init_end(self, trainer): self._trainer = trainer def on_train_start(self, trainer, pl_module): self._train_batch_idx = trainer.batch_idx - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._train_batch_idx += 1 def on_validation_start(self, trainer, pl_module): self._val_batch_idx = 0 - def on_validation_batch_end(self, trainer, pl_module): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._val_batch_idx += 1 def on_test_start(self, trainer, pl_module): self._test_batch_idx = 0 - def on_test_batch_end(self, trainer, pl_module): + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._test_batch_idx += 1 + def on_predict_start(self, trainer, pl_module): + self._predict_batch_idx = 0 + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._predict_batch_idx += 1 + class ProgressBar(ProgressBarBase): r""" @@ -202,6 +272,7 @@ def init_validation_tqdm(self): :class:`~pytorch_lightning.trainer.trainer.Trainer`. """ + def __init__(self, refresh_rate: int = 1, process_position: int = 0): super().__init__() self._refresh_rate = refresh_rate @@ -267,11 +338,27 @@ def init_train_tqdm(self) -> tqdm: ) return bar + def init_predict_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for predicting. """ + bar = tqdm( + desc='Predicting', + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + # The main progress bar doesn't exist in `trainer.validate()` + has_main_bar = self.main_progress_bar is not None bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -282,7 +369,7 @@ def init_validation_tqdm(self) -> tqdm: def init_test_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for testing. """ bar = tqdm( - desc='Testing', + desc="Testing", position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -294,7 +381,6 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): @@ -306,39 +392,43 @@ def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) self.main_progress_bar = self.init_train_tqdm() - def on_epoch_start(self, trainer, pl_module): - super().on_epoch_start(trainer, pl_module) + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches - if total_train_batches != float('inf') and not trainer.fast_dev_run: + if total_train_batches != float('inf'): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(convert_inf(total_batches)) - self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') + reset(self.main_progress_bar, total_batches) + self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') - def on_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) - if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: - self.main_progress_bar.update(self.refresh_rate) - self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): + self._update_bar(self.main_progress_bar) + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - self.val_progress_bar = self.init_validation_tqdm() - self.val_progress_bar.total = convert_inf(self.total_val_batches) + if trainer.sanity_checking: + reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) + else: + self._update_bar(self.main_progress_bar) # fill up remaining + self.val_progress_bar = self.init_validation_tqdm() + reset(self.val_progress_bar, self.total_val_batches) - def on_validation_batch_end(self, trainer, pl_module): - super().on_validation_batch_end(trainer, pl_module) - if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0: - self.val_progress_bar.update(self.refresh_rate) - self.main_progress_bar.update(self.refresh_rate) + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.val_batch_idx, self.total_val_batches): + self._update_bar(self.val_progress_bar) + self._update_bar(self.main_progress_bar) def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): @@ -350,18 +440,68 @@ def on_test_start(self, trainer, pl_module): self.test_progress_bar = self.init_test_tqdm() self.test_progress_bar.total = convert_inf(self.total_test_batches) - def on_test_batch_end(self, trainer, pl_module): - super().on_test_batch_end(trainer, pl_module) - if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0: - self.test_progress_bar.update(self.refresh_rate) + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.test_batch_idx, self.total_test_batches): + self._update_bar(self.test_progress_bar) def on_test_end(self, trainer, pl_module): super().on_test_end(trainer, pl_module) self.test_progress_bar.close() + def on_predict_start(self, trainer, pl_module): + super().on_predict_start(trainer, pl_module) + self.predict_progress_bar = self.init_predict_tqdm() + self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.predict_batch_idx, self.total_predict_batches): + self._update_bar(self.predict_progress_bar) -def convert_inf(x): + def on_predict_end(self, trainer, pl_module): + self.predict_progress_bar.close() + + def print( + self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False + ): + active_progress_bar = None + + if not self.main_progress_bar.disable: + active_progress_bar = self.main_progress_bar + elif not self.val_progress_bar.disable: + active_progress_bar = self.val_progress_bar + elif not self.test_progress_bar.disable: + active_progress_bar = self.test_progress_bar + + if active_progress_bar is not None: + s = sep.join(map(str, args)) + active_progress_bar.write(s, end=end, file=file, nolock=nolock) + + def _should_update(self, current, total): + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + + def _update_bar(self, bar: Optional[tqdm]) -> None: + """ Updates the bar by the refresh rate without overshooting. """ + if bar is None: + return + if bar.total is not None: + delta = min(self.refresh_rate, bar.total - bar.n) + else: + # infinite / unknown size + delta = self.refresh_rate + if delta > 0: + bar.update(delta) + + +def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: """ The tqdm doesn't support inf values. We have to convert it to None. """ if x == float('inf'): return None return x + + +def reset(bar: tqdm, total: Optional[int] = None) -> None: + """ Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """ + if not bar.disable: + bar.reset(total=convert_inf(total)) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py new file mode 100644 index 00000000000000..3f82ab35654035 --- /dev/null +++ b/pytorch_lightning/callbacks/pruning.py @@ -0,0 +1,456 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +ModelPruning +^^^^^^^^^^^^ +""" +import inspect +import logging +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.utils.prune as pytorch_prune +from torch import nn + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +log = logging.getLogger(__name__) + +_PYTORCH_PRUNING_FUNCTIONS = { + "ln_structured": pytorch_prune.ln_structured, + "l1_unstructured": pytorch_prune.l1_unstructured, + "random_structured": pytorch_prune.random_structured, + "random_unstructured": pytorch_prune.random_unstructured, +} + +_PYTORCH_PRUNING_METHOD = { + "ln_structured": pytorch_prune.LnStructured, + "l1_unstructured": pytorch_prune.L1Unstructured, + "random_structured": pytorch_prune.RandomStructured, + "random_unstructured": pytorch_prune.RandomUnstructured, +} + +_PARAM_TUPLE = Tuple[nn.Module, str] +_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]] +_MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) + + +class ModelPruning(Callback): + PARAMETER_NAMES = ("weight", "bias") + + def __init__( + self, + pruning_fn: Union[Callable, str], + parameters_to_prune: Optional[_PARAM_LIST] = None, + parameter_names: Optional[List[str]] = None, + use_global_unstructured: bool = True, + amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, + apply_pruning: Union[bool, Callable[[int], bool]] = True, + make_pruning_permanent: bool = True, + use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, + resample_parameters: bool = False, + pruning_dim: Optional[int] = None, + pruning_norm: Optional[int] = None, + verbose: int = 0, + ) -> None: + """ + Model pruning Callback, using PyTorch's prune utilities. + This callback is responsible of pruning networks parameters during training. + + To learn more about pruning with PyTorch, please take a look at + `this tutorial `_. + + .. warning:: ``ModelPruning`` is in beta and subject to change. + + .. code-block:: python + + parameters_to_prune = [ + (model.mlp_1, "weight"), + (model.mlp_2, "weight") + ] + + trainer = Trainer(callbacks=[ + ModelPruning( + pruning_fn='l1_unstructured', + parameters_to_prune=parameters_to_prune, + amount=0.01, + use_global_unstructured=True, + ) + ]) + + When ``parameters_to_prune`` is ``None``, ``parameters_to_prune`` will contain all parameters from the model. + The user can override ``filter_parameters_to_prune`` to filter any ``nn.Module`` to be pruned. + + Args: + + pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod`` subclass. + Can also be string e.g. `"l1_unstructured"`. See pytorch docs for more details. + + parameters_to_prune: List of tuples ``(nn.Module, "parameter_name_string")``. + + parameter_names: List of parameter names to be pruned from the nn.Module. + Can either be ``"weight"`` or ``"bias"``. + + use_global_unstructured: Whether to apply pruning globally on the model. + If ``parameters_to_prune`` is provided, global unstructured will be restricted on them. + + amount: Quantity of parameters to prune: + + - ``float``. Between 0.0 and 1.0. Represents the fraction of parameters to prune. + - ``int``. Represents the absolute number of parameters to prune. + - ``Callable``. For dynamic values. Will be called every epoch. Should return a value. + + apply_pruning: Whether to apply pruning. + + - ``bool``. Always apply it or not. + - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. + + make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks + when training ends or the model is saved. + + use_lottery_ticket_hypothesis: See `The lottery ticket hypothesis `_: + + - ``bool``. Whether to apply it or not. + - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. + + resample_parameters: Used with ``use_lottery_ticket_hypothesis``. If True, the model parameters will + be resampled, otherwise, the exact original parameters will be used. + + pruning_dim: If you are using a structured pruning method you need to specify the dimension. + + pruning_norm: If you are using ``ln_structured`` you need to specify the norm. + + verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity + + Raises: + MisconfigurationException: + If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``, + if the provided ``pruning_fn`` is not supported, + if ``pruning_dim`` is not provided when ``"unstructured"``, + if ``pruning_norm`` is not provided when ``"ln_structured"``, + if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or + if ``amount`` is none of ``int``, ``float`` and ``Callable``. + """ + + self._use_global_unstructured = use_global_unstructured + self._parameters_to_prune = parameters_to_prune + self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis + self._resample_parameters = resample_parameters + self._parameter_names = parameter_names or self.PARAMETER_NAMES + self._global_kwargs = {} + self._original_layers = None + self._pruning_fn_name = None + + for name in self._parameter_names: + if name not in self.PARAMETER_NAMES: + raise MisconfigurationException( + f"The provided `parameter_names` name: {name} isn't in {self.PARAMETER_NAMES}" + ) + + if isinstance(pruning_fn, str): + pruning_kwargs = {} + pruning_fn = pruning_fn.lower() + if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS: + raise MisconfigurationException( + f"The provided `pruning_fn` {pruning_fn} isn't available in PyTorch's" + f" built-in functions: {list(_PYTORCH_PRUNING_FUNCTIONS.keys())} " + ) + if pruning_fn.endswith("_structured"): + if pruning_dim is None: + raise MisconfigurationException( + "When requesting `structured` pruning, the `pruning_dim` should be provided." + ) + if pruning_fn == "ln_structured": + if pruning_norm is None: + raise MisconfigurationException( + "When requesting `ln_structured` pruning, the `pruning_norm` should be provided." + ) + pruning_kwargs["n"] = pruning_norm + pruning_kwargs["dim"] = pruning_dim + pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs) + elif self._is_pruning_method(pruning_fn): + if not use_global_unstructured: + raise MisconfigurationException( + "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`." + ) + else: + raise MisconfigurationException( + f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}" + f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}." + " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance" + ) + + if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": + raise MisconfigurationException( + 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' + f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. " + ) + + self.pruning_fn = pruning_fn + self._apply_pruning = apply_pruning + self._make_pruning_permanent = make_pruning_permanent + + if not isinstance(amount, (int, float, Callable)): + raise MisconfigurationException( + "`amount` should be provided and be either an int, a float or Callable function." + ) + + self.amount = amount + + if verbose not in (0, 1, 2): + raise MisconfigurationException("`verbose` must be any of (0, 1, 2)") + + self._verbose = verbose + + def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]: + """ + This function can be overridden to control which module to prune. + """ + return parameters_to_prune + + def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]: + """ + This function takes `pruning_fn`, a function name. + + IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` + ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. + + """ + if self._use_global_unstructured: + pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn] + self._global_kwargs = kwargs + else: + pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn] + # save the function __name__ now because partial does not include it + # and there are issues setting the attribute manually in ddp. + self._pruning_fn_name = pruning_fn.__name__ + if self._use_global_unstructured: + return pruning_fn + return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs) + + @staticmethod + def _wrap_pruning_fn(pruning_fn, **kwargs): + return partial(pruning_fn, **kwargs) + + def make_pruning_permanent(self, pl_module: LightningModule): + """ + Removes pruning buffers from any pruned modules + + Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180 + """ + for _, module in pl_module.named_modules(): + for k in list(module._forward_pre_hooks): + hook = module._forward_pre_hooks[k] + if isinstance(hook, pytorch_prune.BasePruningMethod): + hook.remove(module) + del module._forward_pre_hooks[k] + + def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str): + trained = getattr(module, tensor_name) + orig = getattr(orig_module, tensor_name) + if trained is None or orig is None: + return + trained.data = orig.data.to(trained.device) + + def apply_lottery_ticket_hypothesis(self): + r""" + Lottery ticket hypothesis algorithm (see page 2 of the paper): + + 1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`). + 2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`. + 3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`. + 4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`. + + This function implements the step 4. + + The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` + """ # noqa: E501 + + def copy_param(new, old, name: str) -> None: + dst = getattr(new, name) + src = getattr(old, name) + if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor): + return + dst.data = src.data.to(dst.device) + + for d in self._original_layers.values(): + copy, names = d["data"], d["names"] + if self._resample_parameters and hasattr(copy, "reset_parameters"): + copy = deepcopy(copy) # keep the original parameters + copy.reset_parameters() + for i, name in names: + new, new_name = self._parameters_to_prune[i] + copy_param(new, copy, name) + + def _apply_local_pruning(self, amount: float): + for module, name in self._parameters_to_prune: + self.pruning_fn(module, name=name, amount=amount) + + def _resolve_global_kwargs(self, amount: float): + self._global_kwargs["amount"] = amount + params = set(inspect.signature(self.pruning_fn).parameters) + params.discard("self") + return {k: v for k, v in self._global_kwargs.items() if k in params} + + def _apply_global_pruning(self, amount: float): + pytorch_prune.global_unstructured( + self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount) + ) + + @staticmethod + def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: + attr = f"{name}_mask" + if not hasattr(module, attr): + return 0, 1 + mask = getattr(module, attr) + return (mask == 0).sum().item(), mask.numel() + + def apply_pruning(self, amount: Union[int, float]): + """ Applies pruning to ``parameters_to_prune``. """ + if self._verbose: + prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] + + if self._use_global_unstructured: + self._apply_global_pruning(amount) + else: + self._apply_local_pruning(amount) + + if self._verbose: + curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] + self._log_sparsity_stats(prev_stats, curr_stats, amount=amount) + + @rank_zero_only + def _log_sparsity_stats( + self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 + ): + total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) + prev_total_zeros = sum(zeros for zeros, _ in prev) + curr_total_zeros = sum(zeros for zeros, _ in curr) + log.info( + f"Applied `{self._pruning_fn_name}`. Pruned:" + f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->" + f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})" + ) + if self._verbose == 2: + for i, (module, name) in enumerate(self._parameters_to_prune): + prev_mask_zeros, prev_mask_size = prev[i] + curr_mask_zeros, curr_mask_size = curr[i] + log.info( + f"Applied `{self._pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:" + f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->" + f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" + ) + + def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule): + parameters_to_prune = self.sanitize_parameters_to_prune( + pl_module, self._parameters_to_prune, parameter_names=self._parameter_names + ) + + self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune) + + if self._use_lottery_ticket_hypothesis: + # group modules by id. Each entry has a copy of the initial data + # and a list of the associated parameter names to prune + self._original_layers = {} + for i, (module, name) in enumerate(self._parameters_to_prune): + id_ = id(module) + self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []}) + self._original_layers[id_]["names"].append((i, name)) + + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs): + current_epoch = trainer.current_epoch + prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning + amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount + if not prune or not amount: + return + self.apply_pruning(amount) + + if ( + self._use_lottery_ticket_hypothesis(current_epoch) + if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis + ): + self.apply_lottery_ticket_hypothesis() + + def on_train_end(self, trainer, pl_module: LightningModule): + if self._make_pruning_permanent: + rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.") + self.make_pruning_permanent(pl_module) + + def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]): + if self._make_pruning_permanent: + rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.") + prev_device = pl_module.device + # prune a copy so training can continue with the same buffers + copy = deepcopy(pl_module.to("cpu")) + self.make_pruning_permanent(copy) + checkpoint["state_dict"] = copy.state_dict() + pl_module.to(prev_device) + + @staticmethod + def sanitize_parameters_to_prune( + pl_module: LightningModule, + parameters_to_prune: Optional[_PARAM_LIST] = None, + parameter_names: Optional[List[str]] = None, + ) -> _PARAM_LIST: + """ + This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. + If ``parameters_to_prune is None``, it will be generated with all parameters of the model. + + Raises: + MisconfigurationException: + If ``parameters_to_prune`` doesn't exist in the model, or + if ``parameters_to_prune`` is neither a list of tuple nor ``None``. + """ + parameters = parameter_names or ModelPruning.PARAMETER_NAMES + + current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)] + + if parameters_to_prune is None: + parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)] + elif ( + isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0 + and all(len(p) == 2 for p in parameters_to_prune) + and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune) + ): + missing_modules, missing_parameters = [], [] + for module, name in parameters_to_prune: + if module not in current_modules: + missing_modules.append(module) + continue + if not hasattr(module, name): + missing_parameters.append(name) + + if missing_modules or missing_parameters: + raise MisconfigurationException( + "Some provided `parameters_to_tune` don't exist in the model." + f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}" + ) + else: + raise MisconfigurationException( + "The provided `parameters_to_prune` should either be list of tuple" + " with 2 elements: (nn.Module, parameter_name_to_prune) or None" + ) + + return parameters_to_prune + + @staticmethod + def _is_pruning_method(method: Any) -> bool: + if not inspect.isclass(method): + return False + return issubclass(method, pytorch_prune.BasePruningMethod) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py new file mode 100644 index 00000000000000..2b6064e232da78 --- /dev/null +++ b/pytorch_lightning/callbacks/quantization.py @@ -0,0 +1,215 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Quantization +^^^^^^^^^^^^ + +""" +import functools +from typing import Any, Callable, Optional, Sequence, Union + +import torch +from torch.quantization import QConfig + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_LOWER_EQUAL_1_4 + + +def wrap_qat_forward_context( + quant_cb, + model: pl.core.LightningModule, + func: Callable, + trigger_condition: Optional[Union[Callable, int]] = None +) -> Callable: + """ + Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility + Moreover this version has the (de)quantization conditional as it may not be needed for the training all the time + """ + # todo: consider using registering hook before/after forward + @functools.wraps(func) + def wrapper(data) -> Any: + _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer) + _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition + _quant_run = trigger_condition is None or _is_func_true or _is_count_true + # apply custom trigger + if _quant_run: + quant_cb._forward_calls += 1 + data = model.quant(data) + data = func(data) + # apply custom trigger + if _quant_run: + data = model.dequant(data) + return data + + return wrapper + + +def wrap_quantize_forward_context(model: pl.core.LightningModule, func: Callable) -> Callable: + """ + Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility + """ + # todo: consider using registering hook before/after forward + @functools.wraps(func) + def wrapper(data) -> Any: + data = model.quant(data) + data = func(data) + data = model.dequant(data) + return data + + return wrapper + + +def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool: + """recursive check if model has some layers denoted with '.'""" + if '.' in attribs: + attrib, attribs = attribs.split('.', 1) + if hasattr(obj, attrib): + return _recursive_hasattr(getattr(obj, attrib), attribs, state) + return False + return state and hasattr(obj, attribs) + + +class QuantizationAwareTraining(Callback): + OBSERVER_TYPES = ('histogram', 'average') + + def __init__( + self, + qconfig: Union[str, QConfig] = 'fbgemm', + observer_type: str = "average", + collect_quantization: Optional[Union[int, Callable]] = None, + modules_to_fuse: Optional[Sequence] = None, + input_compatible: bool = True, + ) -> None: + """ + Quantization allows speeding up inference and decreasing memory requirements + by performing computations and storing tensors at lower bitwidths + (such as INT8 or FLOAT16) than floating point precision. + We use native PyTorch API so for more information + see `Quantization `_. + + .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change. + + + Args: + + qconfig: quantization configuration: + + - 'fbgemm' for server inference. + - 'qnnpack' for mobile inference. + - a custom `torch.quantization.QConfig `_. + + observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default) + and ``HistogramObserver`` as "histogram" which is more computationally expensive. + + collect_quantization: count or custom function to collect quantization statistics: + + - ``None`` (deafult). The quantization observer is called in each module forward + (useful for collecting extended statistic when useing image/data augmentation). + - ``int``. Use to set a fixed number of calls, starting from the beginning. + - ``Callable``. Custom function with single trainer argument. + See this example to trigger only the last epoch: + + .. code-block:: python + + def custom_trigger_last(trainer): + return trainer.current_epoch == (trainer.max_epochs - 1) + + QuantizationAwareTraining(collect_quantization=custom_trigger_last) + + modules_to_fuse: allows you fuse a few layers together as shown in + `diagram `_ + to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286. + + input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model, + but break compatibility to torchscript. + + """ # noqa: E501 + _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines + if not isinstance(qconfig, QConfig) and not _valid_qconf_str: + raise MisconfigurationException( + f"Unsupported qconfig: f{qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}" + ) + self._qconfig = qconfig + + if observer_type not in self.OBSERVER_TYPES: + raise MisconfigurationException( + f'Unsupported observer type "{observer_type}", allowed are {self.OBSERVER_TYPES}.' + ) + elif observer_type == 'histogram' and _TORCH_LOWER_EQUAL_1_4: + raise MisconfigurationException(f'For using {observer_type} you need to be using pytorch>=1.5.') + self._observer_type = observer_type + + if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)): + raise MisconfigurationException( + f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.' + ) + self._collect_quantization = collect_quantization + + self.modules_to_fuse = modules_to_fuse + self._input_compatible = input_compatible + self._forward_calls = 0 + + def _check_feasible_fuse(self, model): + if not self.modules_to_fuse: + return False + for group in self.modules_to_fuse: + if not all(_recursive_hasattr(model, m) for m in group): + raise MisconfigurationException( + f'You have requested to fuse {group} but one or more of them is not your model attributes' + ) + return True + + def on_fit_start(self, trainer, pl_module): + # QuantStub converts tensors from floating point to quantized + pl_module.quant = torch.quantization.QuantStub() + # DeQuantStub converts tensors from quantized to floating point + pl_module.dequant = torch.quantization.DeQuantStub() + # manually specify where tensors will be converted from quantized + # to floating point in the quantized model + self.__module_forward = pl_module.forward + pl_module.forward = wrap_qat_forward_context( + quant_cb=self, model=pl_module, func=pl_module.forward, trigger_condition=self._collect_quantization + ) + + # attach a global qconfig, which contains information about what kind + # of observers to attach. Use 'fbgemm' for server inference + if isinstance(self._qconfig, str): + if self._observer_type == 'histogram': + pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig) + elif self._observer_type == 'average': + pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig) + elif isinstance(self._qconfig, QConfig): + pl_module.qconfig = self._qconfig + + if self._check_feasible_fuse(pl_module): + torch.quantization.fuse_modules(pl_module, self.modules_to_fuse, inplace=True) + + # Prepare the model for QAT. This inserts observers and fake_quants in + # the model that will observe weight and activation tensors during calibration. + torch.quantization.prepare_qat(pl_module, inplace=True) + + def on_fit_end(self, trainer, pl_module): + pl_module.eval() + # Convert the observed model to a quantized model. This does several things: + # quantizes the weights, computes and stores the scale and bias value to be + # used with each activation tensor, fuses modules where appropriate, + # and replaces key operators with quantized implementations. + torch.quantization.convert(pl_module, inplace=True) + # check we shall preserve wrapper + if self._input_compatible: + pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward) + else: + pl_module.forward = self.__module_forward diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py new file mode 100644 index 00000000000000..bece2ffe9f1b21 --- /dev/null +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -0,0 +1,288 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Stochastic Weight Averaging Callback +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +""" +from copy import deepcopy +from typing import Callable, Optional, Union + +import torch +from torch import nn + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _TORCH_GREATER_EQUAL_1_6: + from torch.optim.swa_utils import SWALR + +_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] + + +class StochasticWeightAveraging(Callback): + + def __init__( + self, + swa_epoch_start: Union[int, float] = 0.8, + swa_lrs: Optional[Union[float, list]] = None, + annealing_epochs: int = 10, + annealing_strategy: str = "cos", + avg_fn: Optional[_AVG_FN] = None, + device: Optional[Union[torch.device, str]] = torch.device("cpu"), + ): + r""" + + Implements the Stochastic Weight Averaging (SWA) Callback to average a model. + + Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to + Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii + Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson + (UAI 2018). + + This documentation is highly inspired by PyTorch's work on SWA. + The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package. + + For a SWA explanation, please take a look + `here `_. + + .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change. + + .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. + + SWA can easily be activated directly from the Trainer as follow: + + .. code-block:: python + + Trainer(stochastic_weight_avg=True) + + Arguments: + + swa_epoch_start: If provided as int, the procedure will start from + the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1, + the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch + + swa_lrs: the learning rate value for all param groups together or separately for each group. + + annealing_epochs: number of epochs in the annealing phase (default: 10) + + annealing_strategy: Specifies the annealing strategy (default: "cos"): + + - ``"cos"``. For cosine annealing. + - ``"linear"`` For linear annealing + + avg_fn: the averaging function used to update the parameters; + the function must take in the current value of the + :class:`AveragedModel` parameter, the current value of :attr:`model` + parameter and the number of models already averaged; if None, + equally weighted average is used (default: ``None``) + + device: if provided, the averaged model will be stored on the ``device``. + When None is provided, it will infer the `device` from ``pl_module``. + (default: ``"cpu"``) + + """ + + err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." + if isinstance(swa_epoch_start, int) and swa_epoch_start < 1: + raise MisconfigurationException(err_msg) + if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): + raise MisconfigurationException(err_msg) + + wrong_type = not isinstance(swa_lrs, (float, list)) + wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 + wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) + if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)): + raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.") + + if avg_fn is not None and not isinstance(avg_fn, Callable): + raise MisconfigurationException("The `avg_fn` should be callable.") + + if device is not None and not isinstance(device, (torch.device, str)): + raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") + + self._swa_epoch_start = swa_epoch_start + self._swa_lrs = swa_lrs + self._annealing_epochs = annealing_epochs + self._annealing_strategy = annealing_strategy + self._avg_fn = avg_fn or self.avg_fn + self._device = device + self._model_contains_batch_norm = None + self._average_model = None + + @property + def swa_start(self) -> int: + return max(self._swa_epoch_start - 1, 0) # 0-based + + @property + def swa_end(self) -> int: + return self._max_epochs - 1 # 0-based + + @staticmethod + def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'): + return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) + + def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + # copy the model before moving it to accelerator device. + self._average_model = deepcopy(pl_module) + + def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + optimizers = trainer.optimizers + lr_schedulers = trainer.lr_schedulers + + if len(optimizers) != 1: + raise MisconfigurationException("SWA currently works with 1 `optimizer`.") + + if len(lr_schedulers) > 1: + raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") + + if isinstance(self._swa_epoch_start, float): + self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) + + self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) + + self._max_epochs = trainer.max_epochs + if self._model_contains_batch_norm: + # virtually increase max_epochs to perform batch norm update on latest epoch. + trainer.max_epochs += 1 + + def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + if trainer.current_epoch == self.swa_start: + # move average model to request device. + self._average_model = self._average_model.to(self._device or pl_module.device) + + optimizers = trainer.optimizers + + for param_group in optimizers[0].param_groups: + if self._swa_lrs is None: + initial_lr = param_group["lr"] + + elif isinstance(self._swa_lrs, float): + initial_lr = self._swa_lrs + + else: + initial_lr = self._swa_lrs[0] + + param_group["initial_lr"] = initial_lr + + self._swa_lrs = initial_lr + + self._swa_scheduler = SWALR( + optimizers[0], + swa_lr=initial_lr, + anneal_epochs=self._annealing_epochs, + anneal_strategy=self._annealing_strategy, + last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1 + ) + + if trainer.lr_schedulers: + lr_scheduler = trainer.lr_schedulers[0]["scheduler"] + rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}") + trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler + else: + _scheduler_config = _get_default_scheduler_config() + _scheduler_config["scheduler"] = self._swa_scheduler + trainer.lr_schedulers.append(_scheduler_config) + + self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) + + if self.swa_start <= trainer.current_epoch <= self.swa_end: + self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn) + + # Note: No > here in case the callback is saved with the model and training continues + if trainer.current_epoch == self.swa_end + 1: + + # Transfer weights from average model to pl_module + self.transfer_weights(self._average_model, pl_module) + + # Reset BatchNorm for update + self.reset_batch_norm_and_save_state(pl_module) + + # There is no need to perform either backward or optimizer.step as we are + # performing only one pass over the train data-loader to compute activation statistics + # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. + trainer.num_training_batches += 1 + trainer.train_loop._skip_backward = True + self._accumulate_grad_batches = trainer.accumulate_grad_batches + trainer.accumulate_grad_batches = len(trainer.train_dataloader) + + def on_train_epoch_end(self, trainer: 'pl.Trainer', *args): + trainer.train_loop._skip_backward = False + + def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): + if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: + # BatchNorm epoch update. Reset state + trainer.accumulate_grad_batches = self._accumulate_grad_batches + trainer.num_training_batches -= 1 + trainer.max_epochs -= 1 + self.reset_momenta() + elif trainer.current_epoch == self.swa_end: + # Last SWA epoch. Transfer weights from average model to pl_module + self.transfer_weights(self._average_model, pl_module) + + @staticmethod + def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'): + for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): + dst_param.detach().copy_(src_param.to(dst_param.device)) + + def reset_batch_norm_and_save_state(self, pl_module: 'pl.LightningModule'): + """ + Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154 + """ + self.momenta = {} + for module in pl_module.modules(): + if not isinstance(module, nn.modules.batchnorm._BatchNorm): + continue + module.running_mean = torch.zeros_like( + module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype + ) + module.running_var = torch.ones_like( + module.running_var, device=pl_module.device, dtype=module.running_var.dtype + ) + self.momenta[module] = module.momentum + module.momentum = None + module.num_batches_tracked *= 0 + + def reset_momenta(self): + """ + Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165 + """ + for bn_module in self.momenta.keys(): + bn_module.momentum = self.momenta[bn_module] + + @staticmethod + def update_parameters( + average_model: 'pl.LightningModule', model: 'pl.LightningModule', n_averaged: torch.LongTensor, avg_fn: _AVG_FN + ): + """ + Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112 + """ + for p_swa, p_model in zip(average_model.parameters(), model.parameters()): + device = p_swa.device + p_swa_ = p_swa.detach() + p_model_ = p_model.detach().to(device) + src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device)) + p_swa_.copy_(src) + n_averaged += 1 + + @staticmethod + def avg_fn( + averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor + ) -> torch.FloatTensor: + """ + Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97 + """ + return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) diff --git a/pytorch_lightning/core/__init__.py b/pytorch_lightning/core/__init__.py index 83fff7d862fe76..bcab67d821e098 100644 --- a/pytorch_lightning/core/__init__.py +++ b/pytorch_lightning/core/__init__.py @@ -1,339 +1,22 @@ -""" -A :class:`~LightningModule` organizes your PyTorch code into the following sections: - -.. figure:: /_images/lightning_module/pt_to_pl.png - :alt: Convert from PyTorch to Lightning - - -Notice a few things. - -1. It's the SAME code. -2. The PyTorch code IS NOT abstracted - just organized. -3. All the other code that's not in the :class:`~LightningModule` - has been automated for you by the trainer. - - .. code-block:: python - - net = Net() - trainer = Trainer() - trainer.fit(net) - -4. There are no .cuda() or .to() calls... Lightning does these for you. - - .. code-block:: python - - # don't do in lightning - x = torch.Tensor(2, 3) - x = x.cuda() - x = x.to(device) - - # do this instead - x = x # leave it alone! - - # or to init a new tensor - new_x = torch.Tensor(2, 3) - new_x = new_x.type_as(x.type()) - -5. There are no samplers for distributed, Lightning also does this for you. - - .. code-block:: python - - # Don't do in Lightning... - data = MNIST(...) - sampler = DistributedSampler(data) - DataLoader(data, sampler=sampler) - - # do this instead - data = MNIST(...) - DataLoader(data) - -6. A :class:`~LightningModule` is a :class:`torch.nn.Module` but with added functionality. Use it as such! - - .. code-block:: python - - net = Net.load_from_checkpoint(PATH) - net.freeze() - out = net(x) - -Thus, to use Lightning, you just need to organize your code which takes about 30 minutes, -(and let's be real, you probably should do anyhow). - ------------- - -Minimal Example ---------------- - -Here are the only required methods. - -.. code-block:: python - - >>> import pytorch_lightning as pl - >>> class LitModel(pl.LightningModule): - ... - ... def __init__(self): - ... super().__init__() - ... self.l1 = torch.nn.Linear(28 * 28, 10) - ... - ... def forward(self, x): - ... return torch.relu(self.l1(x.view(x.size(0), -1))) - ... - ... def training_step(self, batch, batch_idx): - ... x, y = batch - ... y_hat = self(x) - ... return {'loss': F.cross_entropy(y_hat, y)} - ... - ... def train_dataloader(self): - ... return DataLoader(MNIST(os.getcwd(), train=True, download=True, - ... transform=transforms.ToTensor()), batch_size=32) - ... - ... def configure_optimizers(self): - ... return torch.optim.Adam(self.parameters(), lr=0.02) - -Which you can train by doing: - -.. code-block:: python - - trainer = pl.Trainer() - model = LitModel() - - trainer.fit(model) - ----------- - -Training loop structure ------------------------ - -The general pattern is that each loop (training, validation, test loop) -has 3 methods: - -- ``___step`` -- ``___step_end`` -- ``___epoch_end`` - -To show how Lightning calls these, let's use the validation loop as an example: - -.. code-block:: python - - val_outs = [] - for val_batch in val_data: - # do something with each batch - out = validation_step(val_batch) - val_outs.append(out) - - # do something with the outputs for all batches - # like calculate validation set accuracy or loss - validation_epoch_end(val_outs) - -If we use dp or ddp2 mode, we can also define the ``XXX_step_end`` method to operate -on all parts of the batch:: - - val_outs = [] - for val_batch in val_data: - batches = split_batch(val_batch) - dp_outs = [] - for sub_batch in batches: - dp_out = validation_step(sub_batch) - dp_outs.append(dp_out) - - out = validation_step_end(dp_outs) - val_outs.append(out) - - # do something with the outputs for all batches - # like calculate validation set accuracy or loss - validation_epoch_end(val_outs) - - -Add validation loop -^^^^^^^^^^^^^^^^^^^ - -Thus, if we wanted to add a validation loop you would add this to your -:class:`~LightningModule`: - - >>> class LitModel(pl.LightningModule): - ... def validation_step(self, batch, batch_idx): - ... x, y = batch - ... y_hat = self(x) - ... return {'val_loss': F.cross_entropy(y_hat, y)} - ... - ... def validation_epoch_end(self, outputs): - ... val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() - ... return {'val_loss': val_loss_mean} - ... - ... def val_dataloader(self): - ... # can also return a list of val dataloaders - ... return DataLoader(...) - -Add test loop -^^^^^^^^^^^^^ - - >>> class LitModel(pl.LightningModule): - ... def test_step(self, batch, batch_idx): - ... x, y = batch - ... y_hat = self(x) - ... return {'test_loss': F.cross_entropy(y_hat, y)} - ... - ... def test_epoch_end(self, outputs): - ... test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() - ... return {'test_loss': test_loss_mean} - ... - ... def test_dataloader(self): - ... # can also return a list of test dataloaders - ... return DataLoader(...) - -However, the test loop won't ever be called automatically to make sure you -don't run your test data by accident. Instead you have to explicitly call: - -.. code-block:: python - - # call after training - trainer = Trainer() - trainer.fit(model) - trainer.test() - - # or call with pretrained model - model = MyLightningModule.load_from_checkpoint(PATH) - trainer = Trainer() - trainer.test(model) - ----------- - -Training_step_end method ------------------------- -When using :class:`~pytorch_lightning.overrides.data_parallel.LightningDataParallel` or -:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel`, the -:meth:`~LightningModule.training_step` -will be operating on a portion of the batch. This is normally ok but in special -cases like calculating NCE loss using negative samples, we might want to -perform a softmax across all samples in the batch. - -For these types of situations, each loop has an additional ``__step_end`` method -which allows you to operate on the pieces of the batch: - -.. code-block:: python - - training_outs = [] - for train_batch in train_data: - # dp, ddp2 splits the batch - sub_batches = split_batches_for_dp(batch) - - # run training_step on each piece of the batch - batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches] - - # do softmax with all pieces - out = training_step_end(batch_parts_outputs) - training_outs.append(out) - - # do something with the outputs for all batches - # like calculate validation set accuracy or loss - training_epoch_end(val_outs) - ----------- - -Remove cuda calls ------------------ -In a :class:`~LightningModule`, all calls to ``.cuda()`` -and ``.to(device)`` should be removed. Lightning will do these -automatically. This will allow your code to work on CPUs, TPUs and GPUs. - -When you init a new tensor in your code, just use :meth:`~torch.Tensor.type_as`: - -.. code-block:: python - - def training_step(self, batch, batch_idx): - x, y = batch - - # put the z on the appropriate gpu or tpu core - z = sample_noise() - z = z.type_as(x) - ----------- - -Data preparation ----------------- -Data preparation in PyTorch follows 5 steps: - -1. Download -2. Clean and (maybe) save to disk -3. Load inside :class:`~torch.utils.data.Dataset` -4. Apply transforms (rotate, tokenize, etc...) -5. Wrap inside a :class:`~torch.utils.data.DataLoader` - -When working in distributed settings, steps 1 and 2 have to be done -from a single GPU, otherwise you will overwrite these files from -every GPU. The :class:`~LightningModule` has the -:class:`~LightningModule.prepare_data` method to -allow for this: - - >>> class LitModel(pl.LightningModule): - ... def prepare_data(self): - ... # download - ... mnist_train = MNIST(os.getcwd(), train=True, download=True, - ... transform=transforms.ToTensor()) - ... mnist_test = MNIST(os.getcwd(), train=False, download=True, - ... transform=transforms.ToTensor()) - ... - ... # train/val split - ... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) - ... - ... # assign to use in dataloaders - ... self.train_dataset = mnist_train - ... self.val_dataset = mnist_val - ... self.test_dataset = mnist_test - ... - ... def train_dataloader(self): - ... return DataLoader(self.train_dataset, batch_size=64) - ... - ... def val_dataloader(self): - ... return DataLoader(self.mnist_val, batch_size=64) - ... - ... def test_dataloader(self): - ... return DataLoader(self.mnist_test, batch_size=64) - -Note: - :meth:`~LightningModule.prepare_data` is called once. - -Note: - Do anything with data that needs to happen ONLY once here, like download, tokenize, etc... - - -Lifecycle ---------- -The methods in the :class:`~LightningModule` are called in this order: - -1. :meth:`~LightningModule.__init__` -2. :meth:`~LightningModule.prepare_data` -3. :meth:`~LightningModule.configure_optimizers` -4. :meth:`~LightningModule.train_dataloader` - -If you define a validation loop then - -5. :meth:`~LightningModule.val_dataloader` - -And if you define a test loop: - -6. :meth:`~LightningModule.test_dataloader` - -Note: - :meth:`~LightningModule.test_dataloader` is only called with ``.test()`` - -In every epoch, the loop methods are called in this frequency: - -1. :meth:`~LightningModule.validation_step` called every batch -2. :meth:`~LightningModule.validation_epoch_end` called every epoch - -Live demo ---------- -Check out this -`COLAB `_ -for a live demo. - -LightningModule Class ---------------------- - -""" - -from pytorch_lightning.core.decorators import data_loader +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule -__all__ = ['LightningModule', 'data_loader'] +__all__ = [ + 'LightningDataModule', + 'LightningModule', +] # __call__ = __all__ diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py new file mode 100644 index 00000000000000..4178c9eeacd503 --- /dev/null +++ b/pytorch_lightning/core/datamodule.py @@ -0,0 +1,398 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LightningDataModule for loading DataLoaders with ease.""" + +import functools +from argparse import ArgumentParser, Namespace +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union + +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types + + +class _DataModuleWrapper(type): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__has_added_checks = False + + def __call__(cls, *args, **kwargs): + """A wrapper for LightningDataModule that: + + 1. Runs user defined subclass's __init__ + 2. Assures prepare_data() runs on rank 0 + 3. Lets you check prepare_data and setup to see if they've been called + """ + if not cls.__has_added_checks: + cls.__has_added_checks = True + # Track prepare_data calls and make sure it runs on rank zero + cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) + # Track setup calls + cls.setup = track_data_hook_calls(cls.setup) + # Track teardown calls + cls.teardown = track_data_hook_calls(cls.teardown) + + # Get instance of LightningDataModule by mocking its __init__ via __call__ + obj = type.__call__(cls, *args, **kwargs) + + return obj + + +def track_data_hook_calls(fn): + """A decorator that checks if prepare_data/setup/teardown has been called. + + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True + - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. + Its corresponding `dm_has_setup_{stage}` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` + + Args: + fn (function): Function that will be tracked to see if it has been called. + + Returns: + function: Decorated function that tracks its call status and saves it to private attrs in its obj instance. + """ + + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + + # The object instance from which setup or prepare_data was called + obj = args[0] + name = fn.__name__ + + # If calling setup, we check the stage and assign stage-specific bool args + if name in ("setup", "teardown"): + + # Get stage either by grabbing from args or checking kwargs. + # If not provided, set call status of 'fit', 'validate', and 'test' to True. + # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() + stage = args[1] if len(args) > 1 else kwargs.get("stage", None) + + if stage is None: + for s in ("fit", "validate", "test"): + setattr(obj, f"_has_{name}_{s}", True) + else: + setattr(obj, f"_has_{name}_{stage}", True) + + elif name == "prepare_data": + obj._has_prepared_data = True + + return fn(*args, **kwargs) + + return wrapped_fn + + +class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper): + """ + A DataModule standardizes the training, val, test splits, data preparation and transforms. + The main advantage is consistent data splits, data preparation and transforms across models. + + Example:: + + class MyDataModule(LightningDataModule): + def __init__(self): + super().__init__() + def prepare_data(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + def setup(self): + # make assignments here (val/train/test split) + # called on every process in DDP + def train_dataloader(self): + train_split = Dataset(...) + return DataLoader(train_split) + def val_dataloader(self): + val_split = Dataset(...) + return DataLoader(val_split) + def test_dataloader(self): + test_split = Dataset(...) + return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP + + A DataModule implements 6 key methods: + + * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). + * **setup** (things to do on every accelerator in distributed mode). + * **train_dataloader** the training dataloader. + * **val_dataloader** the val dataloader(s). + * **test_dataloader** the test dataloader(s). + * **teardown** (things to do on every accelerator in distributed mode when finished) + + + This allows you to share a full dataset without explaining how to download, + split transform and process the data + + """ + + name: str = ... + + def __init__( + self, + train_transforms=None, + val_transforms=None, + test_transforms=None, + dims=None, + ): + super().__init__() + self._train_transforms = train_transforms + self._val_transforms = val_transforms + self._test_transforms = test_transforms + self._dims = dims if dims is not None else () + + # Pointer to the trainer object + self.trainer = None + + # Private attrs to keep track of whether or not data hooks have been called yet + self._has_prepared_data = False + + self._has_setup_fit = False + self._has_setup_validate = False + self._has_setup_test = False + self._has_setup_predict = False + + self._has_teardown_fit = False + self._has_teardown_validate = False + self._has_teardown_test = False + self._has_teardown_predict = False + + @property + def train_transforms(self): + """ + Optional transforms (or collection of transforms) you can apply to train dataset + """ + return self._train_transforms + + @train_transforms.setter + def train_transforms(self, t): + self._train_transforms = t + + @property + def val_transforms(self): + """ + Optional transforms (or collection of transforms) you can apply to validation dataset + """ + return self._val_transforms + + @val_transforms.setter + def val_transforms(self, t): + self._val_transforms = t + + @property + def test_transforms(self): + """ + Optional transforms (or collection of transforms) you can apply to test dataset + """ + return self._test_transforms + + @test_transforms.setter + def test_transforms(self, t): + self._test_transforms = t + + @property + def dims(self): + """ + A tuple describing the shape of your data. Extra functionality exposed in ``size``. + """ + return self._dims + + @dims.setter + def dims(self, d): + self._dims = d + + def size(self, dim=None) -> Union[Tuple, int]: + """ + Return the dimension of each input either as a tuple or list of tuples. You can index this + just as you would with a torch tensor. + """ + + if dim is not None: + return self.dims[dim] + + return self.dims + + @property + def has_prepared_data(self) -> bool: + """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not. + + Returns: + bool: True if ``datamodule.prepare_data()`` has been called. False by default. + """ + return self._has_prepared_data + + @property + def has_setup_fit(self) -> bool: + """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not. + + Returns: + bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. + """ + return self._has_setup_fit + + @property + def has_setup_validate(self) -> bool: + """Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. + """ + return self._has_setup_validate + + @property + def has_setup_test(self) -> bool: + """Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. + """ + return self._has_setup_test + + @property + def has_setup_predict(self) -> bool: + """Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. + """ + return self._has_setup_predict + + @property + def has_teardown_fit(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. + + Returns: + bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. + """ + return self._has_teardown_fit + + @property + def has_teardown_validate(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + """ + return self._has_teardown_validate + + @property + def has_teardown_test(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + """ + return self._has_teardown_test + + @property + def has_teardown_predict(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + """ + return self._has_teardown_predict + + @classmethod + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + """Extends existing argparse by default `LightningDataModule` attributes.""" + return add_argparse_args(cls, parent_parser, **kwargs) + + @classmethod + def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + """Create an instance from CLI arguments. + + Args: + args: The parser or namespace to take arguments from. Only known arguments will be + parsed and passed to the :class:`LightningDataModule`. + **kwargs: Additional keyword arguments that may override ones in the parser or namespace. + These must be valid DataModule arguments. + + Example:: + parser = ArgumentParser(add_help=False) + parser = LightningDataModule.add_argparse_args(parser) + module = LightningDataModule.from_argparse_args(args) + """ + return from_argparse_args(cls, args, **kwargs) + + @classmethod + def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: + r"""Scans the DataModule signature and returns argument names, types and default values. + + Returns: + List with tuples of 3 values: + (argument name, set with argument types, argument default value). + """ + return get_init_arguments_and_types(cls) + + @classmethod + def from_datasets( + cls, + train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None, + val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, + test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None, + batch_size: int = 1, + num_workers: int = 0, + ): + r""" + Create an instance from torch.utils.data.Dataset. + + Args: + train_dataset: (optional) Dataset to be used for train_dataloader() + val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader() + test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader() + batch_size: Batch size to use for each dataloader. Default is 1. + num_workers: Number of subprocesses to use for data loading. 0 means that the + data will be loaded in the main process. Number of CPUs available. + + """ + + def dataloader(ds, shuffle=False): + return DataLoader( + ds, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=True, + ) + + def train_dataloader(): + if isinstance(train_dataset, Mapping): + return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()} + if isinstance(train_dataset, Sequence): + return [dataloader(ds, shuffle=True) for ds in train_dataset] + return dataloader(train_dataset, shuffle=True) + + def val_dataloader(): + if isinstance(val_dataset, Sequence): + return [dataloader(ds) for ds in val_dataset] + return dataloader(val_dataset) + + def test_dataloader(): + if isinstance(test_dataset, Sequence): + return [dataloader(ds) for ds in test_dataset] + return dataloader(test_dataset) + + datamodule = cls() + if train_dataset is not None: + datamodule.train_dataloader = train_dataloader + if val_dataset is not None: + datamodule.val_dataloader = val_dataloader + if test_dataset is not None: + datamodule.test_dataloader = test_dataloader + return datamodule diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 3979a4fc6f7ee4..5def3c0caa445b 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -1,14 +1,103 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Decorator for LightningModule methods.""" + +from functools import wraps +from typing import Callable + from pytorch_lightning.utilities import rank_zero_warn -def data_loader(fn): - """Decorator to make any fx with this use the lazy property. +def auto_move_data(fn: Callable) -> Callable: + """ + Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which + input arguments should be moved automatically to the correct device. + It as no effect if applied to a method of an object that is not an instance of + :class:`~pytorch_lightning.core.lightning.LightningModule` and is typically applied to ``__call__`` + or ``forward``. + + Args: + fn: A LightningModule method for which the arguments should be moved to the device + the parameters are on. + + Example:: + + # directly in the source code + class LitModel(LightningModule): + + @auto_move_data + def forward(self, x): + return x + + # or outside + LitModel.forward = auto_move_data(LitModel.forward) + + model = LitModel() + model = model.to('cuda') + model(torch.zeros(1, 3)) + + # input gets moved to device + # tensor([[0., 0., 0.]], device='cuda:0') - Warnings: - This decorator deprecated in v0.7.0 and it will be removed v0.9.0. """ - rank_zero_warn('`data_loader` decorator deprecated in v0.7.0. Will be removed v0.9.0', DeprecationWarning) - def inner_fx(self): - return fn(self) - return inner_fx + @wraps(fn) + def auto_transfer_args(self, *args, **kwargs): + from pytorch_lightning.core.lightning import LightningModule + if not isinstance(self, LightningModule): + return fn(self, *args, **kwargs) + + args, kwargs = self.transfer_batch_to_device((args, kwargs)) + return fn(self, *args, **kwargs) + + return auto_transfer_args + + +def parameter_validation(fn: Callable) -> Callable: + """ + Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method. + Validates that the module parameter lengths match after moving to the device. It is useful + when tying weights on TPU's. + + Args: + fn: ``.to`` method + + Note: + TPU's require weights to be tied/shared after moving the module to the device. + Failure to do this results in the initialization of new weights which are not tied. + To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook + which is called after the module has been moved to the device. + + See Also: + - `XLA Documentation `_ + """ + + @wraps(fn) + def inner_fn(self, *args, **kwargs): + pre_layer_count = len(list(self.parameters())) + module = fn(self, *args, **kwargs) + self.on_post_move_to_device() + post_layer_count = len(list(self.parameters())) + + if not pre_layer_count == post_layer_count: + rank_zero_warn( + f'The model layers do not match after moving to the target device.' + ' If your model employs weight sharing on TPU,' + ' please tie your weights using the `on_post_move_to_device` model hook.\n' + f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]' + ) + + return module + + return inner_fn diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py index b5d2d5616a60fb..21598fcba0a427 100644 --- a/pytorch_lightning/core/grads.py +++ b/pytorch_lightning/core/grads.py @@ -1,30 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Module to describe gradients """ -from typing import Dict +from typing import Dict, Union -from torch import nn +import torch +from torch.nn import Module -class GradInformation(nn.Module): +class GradInformation(Module): - def grad_norm(self, norm_type: float) -> Dict[str, int]: - results = {} - total_norm = 0 + def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]: + """Compute each parameter's gradient's norm and their overall norm. + + The overall norm is computed over all gradients together, as if they + were concatenated into a single vector. + + Args: + norm_type: The type of the used p-norm, cast to float if necessary. + Can be ``'inf'`` for infinity norm. + + Return: + norms: The dictionary of p-norms of each parameter's gradient and + a special entry for the total p-norm of the gradients viewed + as a single vector. + """ + norm_type = float(norm_type) + + norms, all_norms = {}, [] for name, p in self.named_parameters(): - if p.requires_grad: - try: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm ** norm_type - norm = param_norm ** (1 / norm_type) - - grad = round(norm.data.cpu().numpy().flatten()[0], 3) - results['grad_{}_norm_{}'.format(norm_type, name)] = grad - except Exception: - # this param had no grad - pass - - total_norm = total_norm ** (1. / norm_type) - grad = round(total_norm.data.cpu().numpy().flatten()[0], 3) - results['grad_{}_norm_total'.format(norm_type)] = grad - return results + if p.grad is None: + continue + + param_norm = float(p.grad.data.norm(norm_type)) + norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 4) + + all_norms.append(param_norm) + + total_norm = float(torch.tensor(all_norms).norm(norm_type)) + norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 4) + + return norms diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1a3f05be11c507..b320a9b2238406 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -1,31 +1,45 @@ -from typing import Any +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Various hooks to be used in the Lightning code.""" + +from typing import Any, Dict, List, Optional, Union import torch -from torch import Tensor from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True +from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn -class ModelHooks(torch.nn.Module): +class ModelHooks: + """Hooks to be used in LightningModule.""" - # TODO: remove in v0.9.0 - def on_sanity_check_start(self): + def on_fit_start(self) -> None: + """ + Called at the very beginning of fit. + If on DDP it is called on every process """ - Called before starting evaluation. - Warning: - Deprecated. Will be removed in v0.9.0. + def on_fit_end(self) -> None: + """ + Called at the very end of fit. + If on DDP it is called on every process """ def on_train_start(self) -> None: """ - Called at the beginning of training before sanity check. + Called at the beginning of training after sanity check. """ # do something at the start of training @@ -35,7 +49,43 @@ def on_train_end(self) -> None: """ # do something at the end of training - def on_batch_start(self, batch: Any) -> None: + def on_validation_start(self) -> None: + """ + Called at the beginning of validation. + """ + # do something at the start of validation + + def on_validation_end(self) -> None: + """ + Called at the end of validation. + """ + # do something at the end of validation + + def on_pretrain_routine_start(self) -> None: + """ + Called at the beginning of the pretrain routine (between fit and train start). + + - fit + - pretrain_routine start + - pretrain_routine end + - training_start + + """ + # do something at the start of the pretrain routine + + def on_pretrain_routine_end(self) -> None: + """ + Called at the end of the pretrain routine (between fit and train start). + + - fit + - pretrain_routine start + - pretrain_routine end + - training_start + + """ + # do something at the end of the pretrain routine + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ Called in the training loop before anything happens for that batch. @@ -43,38 +93,170 @@ def on_batch_start(self, batch: Any) -> None: Args: batch: The batched data as it is returned by the training DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader """ # do something when the batch starts - def on_batch_end(self) -> None: + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: """ Called in the training loop after the batch. + + Args: + outputs: The outputs of training_step_end(training_step(x)) + batch: The batched data as it is returned by the training DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader """ # do something when the batch ends + def on_validation_model_eval(self) -> None: + """ + Sets the model to eval during the val loop + """ + self.eval() + + def on_validation_model_train(self) -> None: + """ + Sets the model to train during the val loop + """ + self.train() + + def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Called in the validation loop before anything happens for that batch. + + Args: + batch: The batched data as it is returned by the validation DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + # do something when the batch starts + + def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Called in the validation loop after the batch. + + Args: + outputs: The outputs of validation_step_end(validation_step(x)) + batch: The batched data as it is returned by the validation DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + # do something when the batch ends + + def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Called in the test loop before anything happens for that batch. + + Args: + batch: The batched data as it is returned by the test DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + # do something when the batch starts + + def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Called in the test loop after the batch. + + Args: + outputs: The outputs of test_step_end(test_step(x)) + batch: The batched data as it is returned by the test DataLoader. + batch_idx: the index of the batch + dataloader_idx: the index of the dataloader + """ + # do something when the batch ends + + def on_test_model_train(self) -> None: + """ + Sets the model to train during the test loop + """ + self.train() + + def on_test_model_eval(self) -> None: + """ + Sets the model to eval during the test loop + """ + self.eval() + + def on_predict_model_eval(self) -> None: + """ + Sets the model to eval during the predict loop + """ + self.eval() + def on_epoch_start(self) -> None: """ - Called in the training loop at the very beginning of the epoch. + Called when either of train/val/test epoch begins. """ # do something when the epoch starts def on_epoch_end(self) -> None: + """ + Called when either of train/val/test epoch ends. + """ + # do something when the epoch ends + + def on_train_epoch_start(self) -> None: + """ + Called in the training loop at the very beginning of the epoch. + """ + # do something when the epoch starts + + def on_train_epoch_end(self, outputs: List[Any]) -> None: """ Called in the training loop at the very end of the epoch. """ # do something when the epoch ends - def on_pre_performance_check(self) -> None: + def on_validation_epoch_start(self) -> None: + """ + Called in the validation loop at the very beginning of the epoch. + """ + # do something when the epoch starts + + def on_validation_epoch_end(self, outputs: List[Any]) -> None: + """ + Called in the validation loop at the very end of the epoch. + """ + # do something when the epoch ends + + def on_test_epoch_start(self) -> None: + """ + Called in the test loop at the very beginning of the epoch. + """ + # do something when the epoch starts + + def on_test_epoch_end(self, outputs: List[Any]) -> None: + """ + Called in the test loop at the very end of the epoch. + """ + # do something when the epoch ends + + def on_test_start(self) -> None: + """ + Called at the beginning of testing. + """ + # do something at the start of testing + + def on_test_end(self) -> None: + """ + Called at the end of testing. + """ + # do something at the end of testing + + def on_predict_start(self) -> None: """ - Called at the very beginning of the validation loop. + Called at the beginning of predicting. """ - # do something before validation starts + # do something at the start of predicting - def on_post_performance_check(self) -> None: + def on_predict_end(self) -> None: """ - Called at the very end of the validation loop. + Called at the end of predicting. """ - # do something before validation end + # do something at the end of predicting def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ @@ -88,7 +270,7 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None: for optimizer in optimizers: optimizer.step() model.on_before_zero_grad(optimizer) # < ---- called here - optimizer.zero_grad + optimizer.zero_grad() Args: optimizer: The optimizer for which grads should be zeroed. @@ -105,51 +287,509 @@ def on_after_backward(self) -> None: def on_after_backward(self): # example to inspect gradient information in tensorboard if self.trainer.global_step % 25 == 0: # don't make the tf file huge - params = self.state_dict() - for k, v in params.items(): - grads = v - name = k - self.logger.experiment.add_histogram(tag=name, values=grads, - global_step=self.trainer.global_step) + for k, v in self.named_parameters(): + self.logger.experiment.add_histogram( + tag=k, values=v.grad, global_step=self.trainer.global_step + ) + + """ + + def on_post_move_to_device(self) -> None: + """ + Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to` + is called. This is a good place to tie weights between modules after moving them to a device. Can be + used when training models with weight sharing properties on TPU. + + Addresses the handling of shared weights on TPU: + https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks + + Example:: + + def on_post_move_to_device(self): + self.decoder.weight = self.encoder.weight + + """ + + def configure_sharded_model(self) -> None: + """ + Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, + where we'd like to shard the model instantly, which is useful for extremely large models + which can save memory and initialization time. + + The accelerator manages whether to call this hook at every given stage. + For sharded plugins where model parallelism is required, the hook is usually on called once + to initialize the sharded parameters, and not called again in the same process. + + By default for accelerators/plugins that do not use model sharding techniques, + this hook is called during each fit/val/test/predict stages. + """ + + +class DataHooks: + """Hooks to be used for data related stuff.""" + + def prepare_data(self) -> None: + """ + Use this to download and prepare data. + + .. warning:: DO NOT set state to the model (use `setup` instead) + since this is NOT called on every GPU in DDP/TPU + + Example:: + + def prepare_data(self): + # good + download_data() + tokenize() + etc() + + # bad + self.split = data_split + self.some_state = some_other_state() + + In DDP prepare_data can be called in two ways (using Trainer(prepare_data_per_node)): + + 1. Once per node. This is the default and is only called on LOCAL_RANK=0. + 2. Once in total. Only called on GLOBAL_RANK=0. + + Example:: + + # DEFAULT + # called once per node on LOCAL_RANK=0 of that node + Trainer(prepare_data_per_node=True) + # call on GLOBAL_RANK=0 (great for shared file systems) + Trainer(prepare_data_per_node=False) + + This is called before requesting the dataloaders: + + .. code-block:: python + + model.prepare_data() + if ddp/tpu: init() + model.setup(stage) + model.train_dataloader() + model.val_dataloader() + model.test_dataloader() """ - def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None: + def setup(self, stage: Optional[str] = None) -> None: """ - Override backward with your own implementation if you need to. + Called at the beginning of fit (train + validate), validate, test, predict, or tune. + This is a good hook when you need to build models dynamically or adjust something about them. + This hook is called on every process when using DDP. Args: - trainer: Pointer to the trainer - loss: Loss is already scaled by accumulated grads - optimizer: Current optimizer being used - optimizer_idx: Index of the current optimizer being used + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else - Called to perform backward step. - Feel free to override as needed. + def setup(stage): + data = Load_data(...) + self.l1 = nn.Linear(28, data.num_classes) - The loss passed in has already been scaled for accumulated gradients if requested. + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Called at the end of fit (train + validate), validate, test, predict, or tune. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + + def train_dataloader(self) -> Any: + """ + Implement one or more PyTorch DataLoaders for training. + + Return: + Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these + (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see + this :ref:`page ` + + The dataloader you return will not be called every epoch unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. + + For data processing use the following pattern: + + - download in :meth:`prepare_data` + - process and split in :meth:`setup` + + However, the above are only necessary for distributed processing. + + .. warning:: do not assign state in prepare_data + + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`setup` + - :meth:`train_dataloader` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. Example:: - def backward(self, use_amp, loss, optimizer): - if use_amp: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() + # single dataloader + def train_dataloader(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, + download=True) + loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=True + ) + return loader + + # multiple dataloaders, return as list + def train_dataloader(self): + mnist = MNIST(...) + cifar = CIFAR(...) + mnist_loader = torch.utils.data.DataLoader( + dataset=mnist, batch_size=self.batch_size, shuffle=True + ) + cifar_loader = torch.utils.data.DataLoader( + dataset=cifar, batch_size=self.batch_size, shuffle=True + ) + # each batch will be a list of tensors: [batch_mnist, batch_cifar] + return [mnist_loader, cifar_loader] + + # multiple dataloader, return as dict + def train_dataloader(self): + mnist = MNIST(...) + cifar = CIFAR(...) + mnist_loader = torch.utils.data.DataLoader( + dataset=mnist, batch_size=self.batch_size, shuffle=True + ) + cifar_loader = torch.utils.data.DataLoader( + dataset=cifar, batch_size=self.batch_size, shuffle=True + ) + # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar} + return {'mnist': mnist_loader, 'cifar': cifar_loader} + + """ + rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer") + + def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement one or multiple PyTorch DataLoaders for testing. + + The dataloader you return will not be called every epoch unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. + + For data processing use the following pattern: + + - download in :meth:`prepare_data` + - process and split in :meth:`setup` + + However, the above are only necessary for distributed processing. + + .. warning:: do not assign state in prepare_data + + + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`setup` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. + + Return: + Single or multiple PyTorch DataLoaders. + + Example:: + + def test_dataloader(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, + download=True) + loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=False + ) + + return loader + + # can also return multiple dataloaders + def test_dataloader(self): + return [loader_a, loader_b, ..., loader_n] + + Note: + If you don't need a test dataset and a :meth:`test_step`, you don't need to implement + this method. + + Note: + In the case where you return multiple test dataloaders, the :meth:`test_step` + will have an argument ``dataloader_idx`` which matches the order here. + """ + + def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement one or multiple PyTorch DataLoaders for validation. + + The dataloader you return will not be called every epoch unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. + + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. + + Return: + Single or multiple PyTorch DataLoaders. + + Examples:: + + def val_dataloader(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root='/path/to/mnist/', train=False, + transform=transform, download=True) + loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=False + ) + + return loader + + # can also return multiple dataloaders + def val_dataloader(self): + return [loader_a, loader_b, ..., loader_n] + + Note: + If you don't need a validation dataset and a :meth:`validation_step`, you don't need to + implement this method. + + Note: + In the case where you return multiple validation dataloaders, the :meth:`validation_step` + will have an argument ``dataloader_idx`` which matches the order here. + """ + + def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement one or multiple PyTorch DataLoaders for prediction. + + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` + + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. + + Return: + Single or multiple PyTorch DataLoaders. + + Note: + In the case where you return multiple prediction dataloaders, the :meth:`predict` + will have an argument ``dataloader_idx`` which matches the order here. + """ + + def on_train_dataloader(self) -> None: + """Called before requesting the train dataloader.""" + + def on_val_dataloader(self) -> None: + """Called before requesting the val dataloader.""" + + def on_test_dataloader(self) -> None: + """Called before requesting the test dataloader.""" + + def on_predict_dataloader(self) -> None: + """Called before requesting the predict dataloader.""" + + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: + """ + Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors + wrapped in a custom data structure. + + The data types listed below (and any arbitrary nesting of them) are supported out of the box: + + - :class:`torch.Tensor` or anything that implements `.to(...)` + - :class:`list` + - :class:`dict` + - :class:`tuple` + - :class:`torchtext.data.batch.Batch` + + For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). + + Note: + This hook should only transfer the data and not modify it, nor should it move the data to + any other device than the one passed in as argument (unless you know what you are doing). + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. + + Args: + batch: A batch of data that needs to be transferred to a new device. + device: The target device as defined in PyTorch. + + Returns: + A reference to the data on the new device. + + Example:: + + def transfer_batch_to_device(self, batch, device): + if isinstance(batch, CustomBatch): + # move all tensors in your custom data structure to the device + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) else: - loss.backward() + batch = super().transfer_batch_to_device(data, device) + return batch + + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(accelerator='dp')``. + + See Also: + - :meth:`move_data_to_device` + - :meth:`apply_to_collection` + """ + device = device or self.device + return move_data_to_device(batch, device) + + def on_before_batch_transfer(self, batch, dataloader_idx): + """ + Override to alter or apply batch augmentations to your batch before it is transferred to the device. + + .. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future. + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: DataLoader idx for batch + + Returns: + A batch of data + + Example:: + + def on_before_batch_transfer(self, batch, dataloader_idx): + batch['x'] = transforms(batch['x']) + return batch + + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(accelerator='dp')``. + See Also: + - :meth:`on_after_batch_transfer` + - :meth:`transfer_batch_to_device` """ - if trainer.precision == 16: - # .backward is not special on 16-bit with TPUs - if trainer.on_tpu: - return + return batch - if self.trainer.use_native_amp: - self.trainer.scaler.scale(loss).backward() + def on_after_batch_transfer(self, batch, dataloader_idx): + """ + Override to alter or apply batch augmentations to your batch after it is transferred to the device. + + .. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true ``idx`` in the future. + + Note: + This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: DataLoader idx for batch (Default: 0) + + Returns: + A batch of data + + Example:: + + def on_after_batch_transfer(self, batch, dataloader_idx): + batch['x'] = gpu_transforms(batch['x']) + return batch + + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(accelerator='dp')``. + + See Also: + - :meth:`on_before_batch_transfer` + - :meth:`transfer_batch_to_device` + """ + return batch + + +class CheckpointHooks: + """Hooks to be used with Checkpointing.""" + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + r""" + Called by Lightning to restore your model. + If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. - # TODO: remove in v0.8.0 - else: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + Args: + checkpoint: Loaded checkpoint + + Example:: + + def on_load_checkpoint(self, checkpoint): + # 99% of the time you don't need to implement this method + self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save'] + + Note: + Lightning auto-restores global step, epoch, and train state including amp scaling. + There is no need for you to restore anything regarding training. + """ + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + r""" + Called by Lightning when saving a checkpoint to give you a chance to store anything + else you might want to save. + + Args: + checkpoint: Checkpoint to be saved + + Example:: + + def on_save_checkpoint(self, checkpoint): + # 99% of use cases you don't need to implement this method + checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object + + Note: + Lightning saves all aspects of training (epoch, global step, etc...) + including amp scaling. + There is no need for you to store anything about training. + + """ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a534929434a8b7..7efe88515b37e3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1,97 +1,474 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""nn.Module with additional great features.""" + import collections +import copy import inspect +import logging import os -from abc import ABC, abstractmethod +import tempfile +import types +import uuid +from abc import ABC from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch -import torch.distributed as torch_distrib -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel +from torch import ScriptModule, Tensor +from torch.nn import Module from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader -from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation -from pytorch_lightning.core.hooks import ModelHooks +from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities import rank_zero_warn - -try: - import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - - -class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): +from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args + +log = logging.getLogger(__name__) + + +class LightningModule( + ABC, + DeviceDtypeModuleMixin, + GradInformation, + ModelIO, + ModelHooks, + DataHooks, + CheckpointHooks, + Module, +): + # Below is for property support of JIT in PyTorch 1.7 + # since none of them is important when using JIT, we are going to ignore them. + __jit_unused_properties__ = [ + "datamodule", + "example_input_array", + "hparams", + "hparams_initial", + "on_gpu", + "current_epoch", + "global_step", + "global_rank", + "local_rank", + "logger", + "model_size", + ] + DeviceDtypeModuleMixin.__jit_unused_properties__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - #: Current dtype - self.dtype = torch.FloatTensor + # see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/ + # torch/nn/modules/module.py#L227) + torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}") self.exp_save_path = None - #: The current epoch - self.current_epoch = 0 - - #: Total training batches seen across all epochs - self.global_step = 0 - self.loaded_optimizer_states_dict = {} #: Pointer to the trainer object self.trainer = None - #: Pointer to the logger object - self.logger = None - self.example_input_array = None + self._distrib_type = None + self._device_type = None - #: True if your model is currently running on GPUs. - #: Useful to set flags around the LightningModule for different CPU vs GPU behavior. - self.on_gpu = False + #: True if using amp + self.use_amp = False - #: True if using dp - self.use_dp = False + #: The precision used + self.precision = 32 + + # optionally can be set by user + self._example_input_array = None + self._datamodule = None + self._results: Optional[Result] = None + self._current_fx_name = '' + self._running_manual_backward = False + self._current_hook_fx_name = None + self._current_dataloader_idx = None + self._automatic_optimization: bool = True + self._param_requires_grad_state = dict() + + def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: + if use_pl_optimizer: + opts = list(self.trainer.lightning_optimizers.values()) + else: + opts = self.trainer.optimizers + + # single optimizer + if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer): + return opts[0] + # multiple opts + return opts + + @property + def example_input_array(self) -> Any: + return self._example_input_array + + @property + def current_epoch(self) -> int: + """The current epoch""" + return self.trainer.current_epoch if self.trainer else 0 + + @property + def global_step(self) -> int: + """Total training batches seen across all epochs""" + return self.trainer.global_step if self.trainer else 0 + + @property + def global_rank(self) -> int: + """ The index of the current process across all nodes and devices. """ + return self.trainer.global_rank if self.trainer else 0 + + @property + def local_rank(self) -> int: + """ The index of the current process within a single node. """ + return self.trainer.local_rank if self.trainer else 0 + + @example_input_array.setter + def example_input_array(self, example: Any) -> None: + self._example_input_array = example + + @property + def datamodule(self) -> Any: + return self._datamodule + + @datamodule.setter + def datamodule(self, datamodule: Any) -> None: + self._datamodule = datamodule + + @property + def on_gpu(self): + """ + True if your model is currently running on GPUs. + Useful to set flags around the LightningModule for different CPU vs GPU behavior. + """ + return self.device.type == "cuda" - #: True if using ddp - self.use_ddp = False + @property + def automatic_optimization(self) -> bool: + """ + If False you are responsible for calling .backward, .step, zero_grad. + """ + return self._automatic_optimization - #: True if using ddp2 - self.use_ddp2 = False + @automatic_optimization.setter + def automatic_optimization(self, automatic_optimization: bool) -> None: + self._automatic_optimization = automatic_optimization - #: True if using amp - self.use_amp = False + @property + def logger(self): + """ Reference to the logger object in the Trainer. """ + return self.trainer.logger if self.trainer else None - self.hparams = None + def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0): + batch = self.on_before_batch_transfer(batch, dataloader_idx) + batch = self.transfer_batch_to_device(batch, device) + batch = self.on_after_batch_transfer(batch, dataloader_idx) + return batch def print(self, *args, **kwargs) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once. Args: - *args: The thing to print. Will be passed to Python's built-in print function. - **kwargs: Will be passed to Python's built-in print function. + *args: The thing to print. The same as for Python's built-in print function. + **kwargs: The same as for Python's built-in print function. - Example: + Example:: - .. code-block:: python + def forward(self, x): + self.print(x, 'in forward') + + """ + if self.trainer.is_global_zero: + progress_bar = self.trainer.progress_bar_callback + if progress_bar is not None and progress_bar.is_enabled: + progress_bar.print(*args, **kwargs) + else: + print(*args, **kwargs) + + def log( + self, + name: str, + value: Any, + prog_bar: bool = False, + logger: bool = True, + on_step: Optional[bool] = None, + on_epoch: Optional[bool] = None, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + add_dataloader_idx: bool = True, + ): + """ + Log a key, value - def forward(self, x): - self.print(x, 'in forward') + Example:: + self.log('train_loss', loss) + + The default behavior per hook is as follows + + .. csv-table:: ``*`` also applies to the test loop + :header: "LightningMoule Hook", "on_step", "on_epoch", "prog_bar", "logger" + :widths: 20, 10, 10, 10, 10 + + "training_step", "T", "F", "F", "T" + "training_step_end", "T", "F", "F", "T" + "training_epoch_end", "F", "T", "F", "T" + "validation_step*", "F", "T", "F", "T" + "validation_step_end*", "F", "T", "F", "T" + "validation_epoch_end*", "F", "T", "F", "T" + + Args: + name: key name + value: value name + prog_bar: if True logs to the progress bar + logger: if True logs to the logger + on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step + on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step + reduce_fx: reduction function over step values for end of epoch. Torch.mean by default + tbptt_reduce_fx: function to reduce on truncated back prop + tbptt_pad_token: token to use for padding + enable_graph: if True, will not auto detach the graph + sync_dist: if True, reduces the metric across GPUs/TPUs + sync_dist_op: the op to sync across GPUs/TPUs + sync_dist_group: the ddp group to sync across + add_dataloader_idx: if True, appends the index of the current dataloader to + the name (when using multiple). If False, user needs to give unique names for + each dataloader to not mix values """ - if self.trainer.proc_rank == 0: - print(*args, **kwargs) + if self._results is not None: + # in any epoch end can't log step metrics (only epoch metric) + if 'epoch_end' in self._current_fx_name and on_step: + m = f'on_step=True cannot be used on {self._current_fx_name} method' + raise MisconfigurationException(m) + + if 'epoch_end' in self._current_fx_name and on_epoch is False: + m = f'on_epoch cannot be False when called from the {self._current_fx_name} method' + raise MisconfigurationException(m) + + # add log_dict + # TODO: if logged twice fail with crash + + # set the default depending on the fx_name + on_step = self.__auto_choose_log_on_step(on_step) + on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + + if self._current_hook_fx_name is not None: + self.trainer.logger_connector.check_logging_in_callbacks( + self._current_hook_fx_name, on_step=on_step, on_epoch=on_epoch + ) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException( + f"Logged key: {name} should not contain information about dataloader_idx." + ) + + training_type_plugin = self.trainer.training_type_plugin + + # Determine if dataloader index should be added + dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None + + self._results.log( + name, + value, + prog_bar, + logger, + on_step, + on_epoch, + reduce_fx, + tbptt_reduce_fx, + tbptt_pad_token, + enable_graph, + sync_dist, + sync_dist_op, + sync_dist_group, + training_type_plugin.reduce, + dataloader_idx, + self.device, + ) + + def log_dict( + self, + dictionary: dict, + prog_bar: bool = False, + logger: bool = True, + on_step: Optional[bool] = None, + on_epoch: Optional[bool] = None, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + add_dataloader_idx: bool = True, + ): + """ + Log a dictonary of values at once + + Example:: + + values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n} + self.log_dict(values) + + Args: + dictionary: key value pairs (str, tensors) + prog_bar: if True logs to the progress base + logger: if True logs to the logger + on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step + on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step + reduce_fx: reduction function over step values for end of epoch. Torch.mean by default + tbptt_reduce_fx: function to reduce on truncated back prop + tbptt_pad_token: token to use for padding + enable_graph: if True, will not auto detach the graph + sync_dist: if True, reduces the metric across GPUs/TPUs + sync_dist_op: the op to sync across GPUs/TPUs + sync_dist_group: the ddp group sync across + add_dataloader_idx: if True, appends the index of the current dataloader to + the name (when using multiple). If False, user needs to give unique names for + each dataloader to not mix values + """ + for k, v in dictionary.items(): + self.log( + name=k, + value=v, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_dist=sync_dist, + sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + add_dataloader_idx=add_dataloader_idx + ) + + def write_prediction( + self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt' + ): + """ + Write predictions to disk using ``torch.save`` + + Example:: + + self.write_prediction('pred', torch.tensor(...), filename='my_predictions.pt') + + Args: + name: a string indicating the name to save the predictions under + value: the predictions, either a single :class:`~torch.Tensor` or a list of them + filename: name of the file to save the predictions to + + Note: + when running in distributed mode, calling ``write_prediction`` will create a file for + each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ... + + """ + self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename) + + def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str = 'predictions.pt'): + """ + Write a dictonary of predictions to disk at once using ``torch.save`` + + Example:: + + pred_dict = {'pred1': torch.tensor(...), 'pred2': torch.tensor(...)} + self.write_prediction_dict(pred_dict) + + Args: + predictions_dict: dict containing predictions, where each prediction should + either be single :class:`~torch.Tensor` or a list of them + + Note: + when running in distributed mode, calling ``write_prediction_dict`` will create a file for + each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ... + + """ + for k, v in predictions_dict.items(): + self.write_prediction(k, v, filename) + + def __auto_choose_log_on_step(self, on_step): + if on_step is None: + if self._current_fx_name in {'training_step', 'training_step_end'}: + on_step = True + elif self._current_fx_name in { + 'evaluation_step', 'evaluation_step_end', 'evaluation_epoch_end', 'training_epoch_end' + }: + on_step = False + else: + on_step = False + + return on_step + + def __auto_choose_log_on_epoch(self, on_epoch): + if on_epoch is None: + if self._current_fx_name in {'training_step', 'training_step_end'}: + on_epoch = False + elif self._current_fx_name in { + 'evaluation_step', 'evaluation_step_end', 'evaluation_epoch_end', 'training_epoch_end' + }: + on_epoch = True + else: + on_epoch = True + + return on_epoch + + def all_gather( + self, + data: Union[torch.Tensor, Dict, List, Tuple], + group: Optional[Any] = None, + sync_grads: bool = False, + ): + r""" + Allows users to call ``self.all_gather()`` from the LightningModule, thus making + the ```all_gather``` operation accelerator agnostic. + + ```all_gather``` is a function provided by accelerators to gather a tensor from several + distributed processes + + Args: + tensor: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...), or if the input was a collection + the output will also be a collection with tensors of this shape. + """ + group = group if group is not None else torch.distributed.group.WORLD + all_gather = self.trainer.accelerator.all_gather + data = convert_to_tensors(data, device=self.device) + all_gather = partial(all_gather, group=group, sync_grads=sync_grads) + return apply_to_collection(data, torch.Tensor, all_gather) - @abstractmethod def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define @@ -101,6 +478,9 @@ def forward(self, *args, **kwargs): This makes it easy to write a complex system for training with the outputs you'd want in a prediction setting. + You may also find the :func:`~pytorch_lightning.core.decorators.auto_move_data` decorator useful + when using the module outside Lightning in a production setting. + Args: *args: Whatever you decide to pass into the forward method. **kwargs: Keyword arguments are also possible. @@ -108,42 +488,40 @@ def forward(self, *args, **kwargs): Return: Predicted output - Examples: - .. code-block:: python + Examples:: - # example if we were using this model as a feature extractor - def forward(self, x): - feature_maps = self.convnet(x) - return feature_maps + # example if we were using this model as a feature extractor + def forward(self, x): + feature_maps = self.convnet(x) + return feature_maps - def training_step(self, batch, batch_idx): - x, y = batch - feature_maps = self(x) - logits = self.classifier(feature_maps) + def training_step(self, batch, batch_idx): + x, y = batch + feature_maps = self(x) + logits = self.classifier(feature_maps) - # ... - return loss + # ... + return loss - # splitting it this way allows model to be used a feature extractor - model = MyModelAbove() + # splitting it this way allows model to be used a feature extractor + model = MyModelAbove() - inputs = server.get_request() - results = model(inputs) - server.write_results(results) + inputs = server.get_request() + results = model(inputs) + server.write_results(results) - # ------------- - # This is in stark contrast to torch.nn.Module where normally you would have this: - def forward(self, batch): - x, y = batch - feature_maps = self.convnet(x) - logits = self.classifier(feature_maps) - return logits + # ------------- + # This is in stark contrast to torch.nn.Module where normally you would have this: + def forward(self, batch): + x, y = batch + feature_maps = self.convnet(x) + logits = self.classifier(feature_maps) + return logits """ + return super().forward(*args, **kwargs) - def training_step(self, *args, **kwargs) -> Union[ - int, Dict[str, Union[Tensor, Dict[str, Tensor]]] - ]: + def training_step(self, *args, **kwargs): r""" Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. @@ -157,221 +535,162 @@ def training_step(self, *args, **kwargs) -> Union[ :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. Return: - Dict with loss key and optional log or progress bar keys. - When implementing :meth:`training_step`, return whatever you need in that step: + Any of. + + - :class:`~torch.Tensor` - The loss tensor + - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'`` + - ``None`` - Training will skip to the next batch - - loss -> tensor scalar **REQUIRED** - - progress_bar -> Dict for progress bar display. Must have only tensors - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) + Note: + Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. - Examples: - .. code-block:: python - - def training_step(self, batch, batch_idx): - x, y, z = batch - - # implement your own - out = self(x) - loss = self.loss(out, x) + Example:: - logger_logs = {'training_loss': loss} # optional (MUST ALL BE TENSORS) + def training_step(self, batch, batch_idx): + x, y, z = batch + out = self.encoder(x) + loss = self.loss(out, x) + return loss - # if using TestTubeLogger or TensorBoardLogger you can nest scalars - logger_logs = {'losses': logger_logs} # optional (MUST ALL BE TENSORS) + If you define multiple optimizers, this step will be called with an additional + ``optimizer_idx`` parameter. - output = { - 'loss': loss, # required - 'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS) - 'log': logger_logs - } - - # return a dict - return output - - If you define multiple optimizers, this step will be called with an additional - ``optimizer_idx`` parameter. - - .. code-block:: python - - # Multiple optimizers (e.g.: GANs) - def training_step(self, batch, batch_idx, optimizer_idx): - if optimizer_idx == 0: - # do training_step with encoder - if optimizer_idx == 1: - # do training_step with decoder + .. code-block:: python + # Multiple optimizers (e.g.: GANs) + def training_step(self, batch, batch_idx, optimizer_idx): + if optimizer_idx == 0: + # do training_step with encoder + if optimizer_idx == 1: + # do training_step with decoder - If you add truncated back propagation through time you will also get an additional - argument with the hidden states of the previous step. - .. code-block:: python + If you add truncated back propagation through time you will also get an additional + argument with the hidden states of the previous step. - # Truncated back-propagation through time - def training_step(self, batch, batch_idx, hiddens): - # hiddens are the hidden states from the previous truncated backprop step - ... - out, hiddens = self.lstm(data, hiddens) - ... + .. code-block:: python - return { - "loss": ..., - "hiddens": hiddens # remember to detach() this - } + # Truncated back-propagation through time + def training_step(self, batch, batch_idx, hiddens): + # hiddens are the hidden states from the previous truncated backprop step + ... + out, hiddens = self.lstm(data, hiddens) + ... + return {'loss': loss, 'hiddens': hiddens} - Notes: + Note: The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step. """ - rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer') + rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") - def training_end(self, *args, **kwargs): - """ - Warnings: - Deprecated in v0.7.0. Use :meth:`training_step_end` instead. + def training_step_end(self, *args, **kwargs): """ + Use this when training with dp or ddp2 because :meth:`training_step` + will operate on only part of the batch. However, this is still optional + and only needed for things like softmax or NCE loss. - def training_epoch_end( - self, - outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] - ) -> Dict[str, Dict[str, Tensor]]: - """Called at the end of the training epoch with the outputs of all training steps. + Note: + If you later switch to ddp or some other mode, this will still be called + so that you don't have to change your code .. code-block:: python - # the pseudocode for these calls - train_outs = [] - for train_batch in train_data: - out = training_step(train_batch) - train_outs.append(out) - training_epoch_end(train_outs) + # pseudocode + sub_batches = split_batches_for_dp(batch) + batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches] + training_step_end(batch_parts_outputs) Args: - outputs: List of outputs you defined in :meth:`training_step`, or if there are - multiple dataloaders, a list containing a list of outputs for each dataloader. + batch_parts_outputs: What you return in `training_step` for each batch part. Return: - Dict or OrderedDict. - May contain the following optional keys: - - - log (metrics to be added to the logger; only tensors) - - any metric used in a callback (e.g. early stopping). + Anything - Note: - If this method is not overridden, this won't be called. + When using dp/ddp2 distributed backends, only a portion of the batch is inside the training_step: - - The outputs here are strictly for logging or progress bar. - - If you don't need to display anything, don't return anything. - - If you want to manually set current step, you can specify the 'step' key in the 'log' dict. + .. code-block:: python - Examples: - With a single dataloader: + def training_step(self, batch, batch_idx): + # batch is 1/num_gpus big + x, y = batch - .. code-block:: python + out = self(x) - def training_epoch_end(self, outputs): - train_acc_mean = 0 - for output in outputs: - train_acc_mean += output['train_acc'] + # softmax uses only a portion of the batch in the denomintaor + loss = self.softmax(out) + loss = nce_loss(loss) + return loss - train_acc_mean /= len(outputs) + If you wish to do something with all the parts of the batch, then use this method to do it: - # log training accuracy at the end of an epoch - results = { - 'log': {'train_acc': train_acc_mean.item()} - } - return results + .. code-block:: python - With multiple dataloaders, ``outputs`` will be a list of lists. The outer list contains - one entry per dataloader, while the inner list contains the individual outputs of - each training step for that dataloader. + def training_step(self, batch, batch_idx): + # batch is 1/num_gpus big + x, y = batch - .. code-block:: python + out = self.encoder(x) + return {'pred': out} - def training_epoch_end(self, outputs): - train_acc_mean = 0 - i = 0 - for dataloader_outputs in outputs: - for output in dataloader_outputs: - train_acc_mean += output['train_acc'] - i += 1 + def training_step_end(self, training_step_outputs): + gpu_0_pred = training_step_outputs[0]['pred'] + gpu_1_pred = training_step_outputs[1]['pred'] + gpu_n_pred = training_step_outputs[n]['pred'] - train_acc_mean /= i + # this softmax now uses the full batch + loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred]) + return loss - # log training accuracy at the end of an epoch - results = { - 'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch} - } - return results + See Also: + See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def training_step_end(self, *args, **kwargs) -> Dict[ - str, Union[Tensor, Dict[str, Tensor]] - ]: + def training_epoch_end(self, outputs: List[Any]) -> None: """ - Use this when training with dp or ddp2 because :meth:`training_step` - will operate on only part of the batch. However, this is still optional - and only needed for things like softmax or NCE loss. - - Note: - If you later switch to ddp or some other mode, this will still be called - so that you don't have to change your code + Called at the end of the training epoch with the outputs of all training steps. + Use this in case you need to do something with all the outputs for every training_step. .. code-block:: python - # pseudocode - sub_batches = split_batches_for_dp(batch) - batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches] - training_step_end(batch_parts_outputs) + # the pseudocode for these calls + train_outs = [] + for train_batch in train_data: + out = training_step(train_batch) + train_outs.append(out) + training_epoch_end(train_outs) Args: - batch_parts_outputs: What you return in `training_step` for each batch part. + outputs: List of outputs you defined in :meth:`training_step`, or if there are + multiple dataloaders, a list containing a list of outputs for each dataloader. Return: - Dict with loss key and optional log or progress bar keys. + None - - loss -> tensor scalar **REQUIRED** - - progress_bar -> Dict for progress bar display. Must have only tensors - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) - - Examples: - .. code-block:: python - - # WITHOUT training_step_end - # if used in DP or DDP2, this batch is 1/num_gpus large - def training_step(self, batch, batch_idx): - # batch is 1/num_gpus big - x, y = batch - - out = self(x) - loss = self.softmax(out) - loss = nce_loss(loss) - return {'loss': loss} + Note: + If this method is not overridden, this won't be called. - # -------------- - # with training_step_end to do softmax over the full batch - def training_step(self, batch, batch_idx): - # batch is 1/num_gpus big - x, y = batch + Example:: - out = self(x) - return {'out': out} + def training_epoch_end(self, training_step_outputs): + # do something with all training_step outputs + return result - def training_step_end(self, outputs): - # this out is now the full size of the batch - out = outputs['out'] + With multiple dataloaders, ``outputs`` will be a list of lists. The outer list contains + one entry per dataloader, while the inner list contains the individual outputs of + each training step for that dataloader. - # this softmax now uses the full batch size - loss = nce_loss(loss) - return {'loss': loss} + .. code-block:: python - See Also: - See the :ref:`multi-gpu-training` guide for more details. + def training_epoch_end(self, training_step_outputs): + for out in training_step_outputs: + # do something here """ - def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]: + def validation_step(self, *args, **kwargs): r""" Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy. @@ -381,28 +700,33 @@ def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]: # the pseudocode for these calls val_outs = [] for val_batch in val_data: - out = validation_step(train_batch) + out = validation_step(val_batch) val_outs.append(out) - validation_epoch_end(val_outs) + validation_epoch_end(val_outs) Args: batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (int): The index of this batch dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple val datasets used) + (only if multiple val dataloaders used) Return: - Dict or OrderedDict - passed to :meth:`validation_epoch_end`. - If you defined :meth:`validation_step_end` it will go to that first. + Any of. + + - Any object or value + - ``None`` - Validation will skip to the next batch .. code-block:: python # pseudocode of order - out = validation_step() - if defined('validation_step_end'): - out = validation_step_end(out) - out = validation_epoch_end(out) + val_outs = [] + for val_batch in val_data: + out = validation_step(val_batch) + if defined('validation_step_end'): + out = validation_step_end(out) + val_outs.append(out) + val_outs = validation_epoch_end(val_outs) .. code-block:: python @@ -413,44 +737,36 @@ def validation_step(self, batch, batch_idx) # if you have multiple val dataloaders: def validation_step(self, batch, batch_idx, dataloader_idx) - Examples: - .. code-block:: python + Examples:: - # CASE 1: A single validation dataset - def validation_step(self, batch, batch_idx): - x, y = batch + # CASE 1: A single validation dataset + def validation_step(self, batch, batch_idx): + x, y = batch - # implement your own - out = self(x) - loss = self.loss(out, y) + # implement your own + out = self(x) + loss = self.loss(out, y) - # log 6 example images - # or generated text... or whatever - sample_imgs = x[:6] - grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image('example_images', grid, 0) + # log 6 example images + # or generated text... or whatever + sample_imgs = x[:6] + grid = torchvision.utils.make_grid(sample_imgs) + self.logger.experiment.add_image('example_images', grid, 0) - # calculate acc - labels_hat = torch.argmax(out, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + # calculate acc + labels_hat = torch.argmax(out, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - # all optional... - # return whatever you need for the collation function validation_epoch_end - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': torch.tensor(val_acc), # everything must be a tensor - }) + # log the outputs! + self.log_dict({'val_loss': loss, 'val_acc': val_acc}) - # return an optional dict - return output + If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. - If you pass in multiple val datasets, validation_step will have an additional argument. - - .. code-block:: python + .. code-block:: python - # CASE 2: multiple validation datasets - def validation_step(self, batch, batch_idx, dataset_idx): - # dataset_idx tells you which dataset this is. + # CASE 2: multiple validation dataloaders + def validation_step(self, batch, batch_idx, dataloader_idx): + # dataloader_idx tells you which dataset this is. Note: If you don't need to validate you don't need to implement this method. @@ -461,7 +777,7 @@ def validation_step(self, batch, batch_idx, dataset_idx): the model goes back to training mode and gradients are enabled. """ - def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: + def validation_step_end(self, *args, **kwargs): """ Use this when validating with dp or ddp2 because :meth:`validation_step` will operate on only part of the batch. However, this is still optional @@ -483,54 +799,39 @@ def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: for each batch part. Return: - Dict or OrderedDict - passed to the :meth:`validation_epoch_end` method. + None or anything - Examples: - .. code-block:: python - - # WITHOUT validation_step_end - # if used in DP or DDP2, this batch is 1/num_gpus large - def validation_step(self, batch, batch_idx): - # batch is 1/num_gpus big - x, y = batch + .. code-block:: python - out = self(x) - loss = self.softmax(out) - loss = nce_loss(loss) - return {'loss': loss} + # WITHOUT validation_step_end + # if used in DP or DDP2, this batch is 1/num_gpus large + def validation_step(self, batch, batch_idx): + # batch is 1/num_gpus big + x, y = batch - # -------------- - # with validation_step_end to do softmax over the full batch - def validation_step(self, batch, batch_idx): - # batch is 1/num_gpus big - x, y = batch + out = self.encoder(x) + loss = self.softmax(out) + loss = nce_loss(loss) + self.log('val_loss', loss) - out = self(x) - return {'out': out} + # -------------- + # with validation_step_end to do softmax over the full batch + def validation_step(self, batch, batch_idx): + # batch is 1/num_gpus big + x, y = batch - def validation_epoch_end(self, outputs): - # this out is now the full size of the batch - out = outputs['out'] + out = self(x) + return out - # this softmax now uses the full batch size - loss = nce_loss(loss) - return {'loss': loss} + def validation_step_end(self, val_step_outputs): + for out in val_step_outputs: + # do something with these See Also: - See the :ref:`multi-gpu-training` guide for more details. + See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def validation_end(self, outputs): - """ - Warnings: - Deprecated in v0.7.0. Use :meth:`validation_epoch_end` instead. - Will be removed in 1.0.0. - """ - - def validation_epoch_end( - self, - outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] - ) -> Dict[str, Dict[str, Tensor]]: + def validation_epoch_end(self, outputs: List[Any]) -> None: """ Called at the end of the validation epoch with the outputs of all validation steps. @@ -548,38 +849,19 @@ def validation_epoch_end( are multiple dataloaders, a list containing a list of outputs for each dataloader. Return: - Dict or OrderedDict. - May have the following optional keys: - - - progress_bar (dict for progress bar display; only tensors) - - log (dict of metrics to add to logger; only tensors). + None Note: If you didn't define a :meth:`validation_step`, this won't be called. - - The outputs here are strictly for logging or progress bar. - - If you don't need to display anything, don't return anything. - - If you want to manually set current step, you can specify the 'step' key in the 'log' dict. - Examples: With a single dataloader: .. code-block:: python - def validation_epoch_end(self, outputs): - val_acc_mean = 0 - for output in outputs: - val_acc_mean += output['val_acc'] - - val_acc_mean /= len(outputs) - tqdm_dict = {'val_acc': val_acc_mean.item()} - - # show val_acc in progress bar but only log val_loss - results = { - 'progress_bar': tqdm_dict, - 'log': {'val_acc': val_acc_mean.item()} - } - return results + def validation_epoch_end(self, val_step_outputs): + for out in val_step_outputs: + # do something With multiple dataloaders, `outputs` will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of @@ -588,25 +870,13 @@ def validation_epoch_end(self, outputs): .. code-block:: python def validation_epoch_end(self, outputs): - val_acc_mean = 0 - i = 0 - for dataloader_outputs in outputs: - for output in dataloader_outputs: - val_acc_mean += output['val_acc'] - i += 1 - - val_acc_mean /= i - tqdm_dict = {'val_acc': val_acc_mean.item()} - - # show val_loss and val_acc in progress bar but only log val_loss - results = { - 'progress_bar': tqdm_dict, - 'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch} - } - return results + for dataloader_output_result in outputs: + dataloader_outs = dataloader_output_result.dataloader_i_outputs + + self.log('final_metric', final_value) """ - def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: + def test_step(self, *args, **kwargs): r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest @@ -626,11 +896,13 @@ def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (int): The index of this batch. dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple test datasets used). + (only if multiple test dataloaders used). Return: - Dict or OrderedDict - passed to the :meth:`test_epoch_end` method. - If you defined :meth:`test_step_end` it will go to that first. + Any of. + + - Any object or value + - ``None`` - Testing will skip to the next batch .. code-block:: python @@ -640,48 +912,39 @@ def test_step(self, batch, batch_idx) # if you have multiple test dataloaders: def test_step(self, batch, batch_idx, dataloader_idx) - Examples: - .. code-block:: python - - # CASE 1: A single test dataset - def test_step(self, batch, batch_idx): - x, y = batch + Examples:: - # implement your own - out = self(x) - loss = self.loss(out, y) + # CASE 1: A single test dataset + def test_step(self, batch, batch_idx): + x, y = batch - # log 6 example images - # or generated text... or whatever - sample_imgs = x[:6] - grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image('example_images', grid, 0) + # implement your own + out = self(x) + loss = self.loss(out, y) - # calculate acc - labels_hat = torch.argmax(out, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + # log 6 example images + # or generated text... or whatever + sample_imgs = x[:6] + grid = torchvision.utils.make_grid(sample_imgs) + self.logger.experiment.add_image('example_images', grid, 0) - # all optional... - # return whatever you need for the collation function test_epoch_end - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': torch.tensor(val_acc), # everything must be a tensor - }) + # calculate acc + labels_hat = torch.argmax(out, dim=1) + test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - # return an optional dict - return output + # log the outputs! + self.log_dict({'test_loss': loss, 'test_acc': test_acc}) - If you pass in multiple validation datasets, :meth:`test_step` will have an additional - argument. + If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. - .. code-block:: python + .. code-block:: python - # CASE 2: multiple test datasets - def test_step(self, batch, batch_idx, dataset_idx): - # dataset_idx tells you which dataset this is. + # CASE 2: multiple test dataloaders + def test_step(self, batch, batch_idx, dataloader_idx): + # dataloader_idx tells you which dataset this is. Note: - If you don't need to validate you don't need to implement this method. + If you don't need to test you don't need to implement this method. Note: When the :meth:`test_step` is called, the model has been put in eval mode and @@ -689,7 +952,7 @@ def test_step(self, batch, batch_idx, dataset_idx): to training mode and gradients are enabled. """ - def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: + def test_step_end(self, *args, **kwargs): """ Use this when testing with dp or ddp2 because :meth:`test_step` will operate on only part of the batch. However, this is still optional @@ -710,54 +973,40 @@ def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: batch_parts_outputs: What you return in :meth:`test_step` for each batch part. Return: - Dict or OrderedDict - passed to the :meth:`test_epoch_end`. + None or anything - Examples: - .. code-block:: python - - # WITHOUT test_step_end - # if used in DP or DDP2, this batch is 1/num_gpus large - def test_step(self, batch, batch_idx): - # batch is 1/num_gpus big - x, y = batch + .. code-block:: python - out = self(x) - loss = self.softmax(out) - loss = nce_loss(loss) - return {'loss': loss} + # WITHOUT test_step_end + # if used in DP or DDP2, this batch is 1/num_gpus large + def test_step(self, batch, batch_idx): + # batch is 1/num_gpus big + x, y = batch - # -------------- - # with test_step_end to do softmax over the full batch - def test_step(self, batch, batch_idx): - # batch is 1/num_gpus big - x, y = batch + out = self(x) + loss = self.softmax(out) + self.log('test_loss', loss) - out = self(x) - return {'out': out} + # -------------- + # with test_step_end to do softmax over the full batch + def test_step(self, batch, batch_idx): + # batch is 1/num_gpus big + x, y = batch - def test_step_end(self, outputs): - # this out is now the full size of the batch - out = outputs['out'] + out = self.encoder(x) + return out - # this softmax now uses the full batch size - loss = nce_loss(loss) - return {'loss': loss} + def test_step_end(self, output_results): + # this out is now the full size of the batch + all_test_step_outs = output_results.out + loss = nce_loss(all_test_step_outs) + self.log('test_loss', loss) See Also: - See the :ref:`multi-gpu-training` guide for more details. + See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def test_end(self, outputs): - """ - Warnings: - Deprecated in v0.7.0. Use :meth:`test_epoch_end` instead. - Will be removed in 1.0.0. - """ - - def test_epoch_end( - self, - outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] - ) -> Dict[str, Dict[str, Tensor]]: + def test_epoch_end(self, outputs: List[Any]) -> None: """ Called at the end of a test epoch with the output of all test steps. @@ -775,37 +1024,22 @@ def test_epoch_end( are multiple dataloaders, a list containing a list of outputs for each dataloader Return: - Dict or OrderedDict: Dict has the following optional keys: - - - progress_bar -> Dict for progress bar display. Must have only tensors. - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc). + None Note: If you didn't define a :meth:`test_step`, this won't be called. - - The outputs here are strictly for logging or progress bar. - - If you don't need to display anything, don't return anything. - - If you want to manually set current step, specify it with the 'step' key in the 'log' Dict - Examples: With a single dataloader: .. code-block:: python def test_epoch_end(self, outputs): - test_acc_mean = 0 - for output in outputs: - test_acc_mean += output['test_acc'] - - test_acc_mean /= len(outputs) - tqdm_dict = {'test_acc': test_acc_mean.item()} + # do something with the outputs of all test batches + all_test_preds = test_step_outputs.predictions - # show test_loss and test_acc in progress bar but only log test_loss - results = { - 'progress_bar': tqdm_dict, - 'log': {'test_acc': test_acc_mean.item()} - } - return results + some_result = calc_all_results(all_test_preds) + self.log(some_result) With multiple dataloaders, `outputs` will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of @@ -814,176 +1048,48 @@ def test_epoch_end(self, outputs): .. code-block:: python def test_epoch_end(self, outputs): - test_acc_mean = 0 - i = 0 + final_value = 0 for dataloader_outputs in outputs: - for output in dataloader_outputs: - test_acc_mean += output['test_acc'] - i += 1 - - test_acc_mean /= i - tqdm_dict = {'test_acc': test_acc_mean.item()} - - # show test_loss and test_acc in progress bar but only log test_loss - results = { - 'progress_bar': tqdm_dict, - 'log': {'test_acc': test_acc_mean.item(), 'step': self.current_epoch} - } - return results - """ - - def configure_ddp( - self, - model: 'LightningModule', - device_ids: List[int] - ) -> DistributedDataParallel: - r""" - Override to init DDP in your own way or with your own wrapper. - The only requirements are that: - - 1. On a validation batch the call goes to ``model.validation_step``. - 2. On a training batch the call goes to ``model.training_step``. - 3. On a testing batch, the call goes to ``model.test_step``.+ - - Args: - model: the :class:`LightningModule` currently being optimized. - device_ids: the list of GPU ids. - - Return: - DDP wrapped model - - Examples: - .. code-block:: python - - # default implementation used in Trainer - def configure_ddp(self, model, device_ids): - # Lightning DDP simply routes to test_step, val_step, etc... - model = LightningDistributedDataParallel( - model, - device_ids=device_ids, - find_unused_parameters=True - ) - return model + for test_step_out in dataloader_outputs: + # do something + final_value += test_step_out + self.log('final_metric', final_value) """ - model = LightningDistributedDataParallel( - model, - device_ids=device_ids, - find_unused_parameters=True - ) - return model - - def _init_slurm_connection(self) -> None: + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): """ - Sets up environemnt variables necessary for pytorch distributed communications - based on slurm environment. + Use this function with trainer.predict(...). Override if you need to add any processing logic. """ - # use slurm job id for the port number - # guarantees unique ports across jobs from same grid search - try: - # use the last 4 numbers in the job id as the id - default_port = os.environ['SLURM_JOB_ID'] - default_port = default_port[-4:] - - # all ports should be in the 10k+ range - default_port = int(default_port) + 15000 - - except Exception: - default_port = 12910 - - # if user gave a port number, use that one instead - try: - default_port = os.environ['MASTER_PORT'] - except Exception: - os.environ['MASTER_PORT'] = str(default_port) - - # figure out the root node addr - try: - root_node = os.environ['SLURM_NODELIST'].split(' ')[0] - except Exception: - root_node = '127.0.0.1' - - root_node = self.trainer.resolve_root_node_address(root_node) - os.environ['MASTER_ADDR'] = root_node - - def init_ddp_connection( - self, - proc_rank: int, - world_size: int, - is_slurm_managing_tasks: bool = True - ) -> None: - """ - Override to define your custom way of setting up a distributed environment. - - Lightning's implementation uses env:// init by default and sets the first node as root - for SLURM managed cluster. - - Args: - proc_rank: The current process rank within the node. - world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus). - is_slurm_managing_tasks: is cluster managed by SLURM. + return self(batch) + def configure_callbacks(self): """ - if is_slurm_managing_tasks: - self._init_slurm_connection() - - if 'MASTER_ADDR' not in os.environ: - log.warning("MASTER_ADDR environment variable is not defined. Set as localhost") - os.environ['MASTER_ADDR'] = '127.0.0.1' - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - - if 'MASTER_PORT' not in os.environ: - log.warning("MASTER_PORT environment variable is not defined. Set as 12910") - os.environ['MASTER_PORT'] = '12910' - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - - if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size: - log.warning("WORLD_SIZE environment variable is not equal to the computed " - "world size. Ignored.") - - torch_backend = "nccl" if self.trainer.on_gpu else "gloo" - torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size) - - def configure_apex( - self, - amp: object, - model: 'LightningModule', - optimizers: List[Optimizer], - amp_level: str - ) -> Tuple['LightningModule', List[Optimizer]]: - r""" - Override to init AMP your own way. - Must return a model and list of optimizers. - - Args: - amp: pointer to amp library object. - model: pointer to current :class:`LightningModule`. - optimizers: list of optimizers passed in :meth:`configure_optimizers`. - amp_level: AMP mode chosen ('O1', 'O2', etc...) + Configure model-specific callbacks. + When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, + the list returned here will be merged with the list of callbacks passed to the Trainer's ``callbacks`` argument. + If a callback returned here has the same type as one or several callbacks already present in + the Trainer's callbacks list, it will take priority and replace them. + In addition, Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + callbacks run last. Return: - Apex wrapped model and optimizers + A list of callbacks which will extend the list of callbacks in the Trainer. - Examples: - .. code-block:: python + Example:: - # Default implementation used by Trainer. - def configure_apex(self, amp, model, optimizers, amp_level): - model, optimizers = amp.initialize( - model, optimizers, opt_level=amp_level, - ) + def configure_callbacks(self): + early_stop = EarlyStopping(monitor"val_acc", mode="max") + checkpoint = ModelCheckpoint(monitor="val_loss") + return [early_stop, checkpoint] - return model, optimizers + Note: + Certain callback methods like :meth:`~pytorch_lightning.callbacks.base.Callback.on_init_start` + will never be invoked on the new callbacks returned here. """ - model, optimizers = amp.initialize( - model, optimizers, opt_level=amp_level, - ) - - return model, optimizers + return [] - def configure_optimizers(self) -> Optional[Union[ - Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List] - ]]: + def configure_optimizers(self): r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. @@ -993,8 +1099,9 @@ def configure_optimizers(self) -> Optional[Union[ - Single optimizer. - List or Tuple - List of optimizers. - - Two lists - The first list has multiple optimizers, the second a list of LR schedulers. - - Dictionary, with an 'optimizer' key and (optionally) a 'lr_scheduler' key. + - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict). + - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' + key whose value is a single LR scheduler or lr_dict. - Tuple of dictionaries as described, with an optional 'frequency' key. - None - Fit will run without any optimizer. @@ -1006,47 +1113,63 @@ def configure_optimizers(self) -> Optional[Union[ In the former case, all optimizers will operate on the given batch in each optimization step. In the latter, only one optimizer will operate on the given batch at every step. - Examples: + The lr_dict is a dictionary which contains the scheduler and its associated configuration. + The default configuration is shown below. + .. code-block:: python - # most cases - def configure_optimizers(self): - opt = Adam(self.parameters(), lr=1e-3) - return opt - - # multiple optimizer case (e.g.: GAN) - def configure_optimizers(self): - generator_opt = Adam(self.model_gen.parameters(), lr=0.01) - disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) - return generator_opt, disriminator_opt - - # example with learning rate schedulers - def configure_optimizers(self): - generator_opt = Adam(self.model_gen.parameters(), lr=0.01) - disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) - discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10) - return [generator_opt, disriminator_opt], [discriminator_sched] - - # example with step-based learning rate schedulers - def configure_optimizers(self): - gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_disc.parameters(), lr=0.02) - gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99), - 'interval': 'step'} # called after each training step - dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch - return [gen_opt, dis_opt], [gen_sched, dis_sched] - - # example with optimizer frequencies - # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1 - # https://arxiv.org/abs/1704.00028 - def configure_optimizers(self): - gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_disc.parameters(), lr=0.02) - n_critic = 5 - return ( - {'optimizer': dis_opt, 'frequency': n_critic}, - {'optimizer': gen_opt, 'frequency': 1} - ) + { + 'scheduler': lr_scheduler, # The LR scheduler instance (required) + 'interval': 'epoch', # The unit of the scheduler's step size + 'frequency': 1, # The frequency of the scheduler + 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler + 'monitor': 'val_loss', # Metric for ReduceLROnPlateau to monitor + 'strict': True, # Whether to crash the training if `monitor` is not found + 'name': None, # Custom name for LearningRateMonitor to use + } + + Only the ``scheduler`` key is required, the rest will be set to the defaults above. + + Examples:: + + # most cases + def configure_optimizers(self): + opt = Adam(self.parameters(), lr=1e-3) + return opt + + # multiple optimizer case (e.g.: GAN) + def configure_optimizers(self): + generator_opt = Adam(self.model_gen.parameters(), lr=0.01) + disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) + return generator_opt, disriminator_opt + + # example with learning rate schedulers + def configure_optimizers(self): + generator_opt = Adam(self.model_gen.parameters(), lr=0.01) + disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) + discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10) + return [generator_opt, disriminator_opt], [discriminator_sched] + + # example with step-based learning rate schedulers + def configure_optimizers(self): + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_disc.parameters(), lr=0.02) + gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99), + 'interval': 'step'} # called after each training step + dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch + return [gen_opt, dis_opt], [gen_sched, dis_sched] + + # example with optimizer frequencies + # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1 + # https://arxiv.org/abs/1704.00028 + def configure_optimizers(self): + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_disc.parameters(), lr=0.02) + n_critic = 5 + return ( + {'optimizer': dis_opt, 'frequency': n_critic}, + {'optimizer': gen_opt, 'frequency': 1} + ) Note: @@ -1070,27 +1193,139 @@ def configure_optimizers(self): default ``.step()`` schedule, override the :meth:`optimizer_step` hook. - If you only want to call a learning rate scheduler every ``x`` step or epoch, - or want to monitor a custom metric, you can specify these in a dictionary: + or want to monitor a custom metric, you can specify these in a lr_dict: .. code-block:: python { 'scheduler': lr_scheduler, - 'interval': 'step' # or 'epoch' + 'interval': 'step', # or 'epoch' 'monitor': 'val_f1', - 'frequency': x + 'frequency': x, } """ - rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer') + rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer") + + def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None: + """ + Call this directly from your training_step when doing optimizations manually. + By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you + + This function forwards all args to the .backward() call as well. + + .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set + + .. tip:: In manual mode we still automatically accumulate grad over batches if + Trainer(accumulate_grad_batches=x) is set and you use `optimizer.step()` + + Example:: + + def training_step(...): + opt_a, opt_b = self.optimizers() + loss = ... + # automatically applies scaling, etc... + self.manual_backward(loss) + opt_a.step() + """ + if optimizer is not None: + rank_zero_deprecation( + "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" + ) + + # make sure we're using manual opt + self._verify_is_manual_optimization('manual_backward') + + # backward + self._running_manual_backward = True + self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self._running_manual_backward = False + + def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + """ + Override backward with your own implementation if you need to. + + Args: + loss: Loss is already scaled by accumulated grads + optimizer: Current optimizer being used + optimizer_idx: Index of the current optimizer being used + + Called to perform backward step. + Feel free to override as needed. + The loss passed in has already been scaled for accumulated gradients if requested. + + Example:: + + def backward(self, loss, optimizer, optimizer_idx): + loss.backward() + + """ + if self.trainer.train_loop.automatic_optimization or self._running_manual_backward: + loss.backward(*args, **kwargs) + + def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): + """ + Makes sure only the gradients of the current optimizer's parameters are calculated + in the training step to prevent dangling gradients in multiple-optimizer setup. + + .. note:: Only called when using multiple optimizers + + Override for your own behavior + + It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset. + + Args: + optimizer: Current optimizer used in training_loop + optimizer_idx: Current optimizer idx in training_loop + """ + + # Iterate over all optimizer parameters to preserve their `requires_grad` information + # in case these are pre-defined during `configure_optimizers` + param_requires_grad_state = {} + for opt in self.optimizers(use_pl_optimizer=False): + for group in opt.param_groups: + for param in group['params']: + # If a param already appear in param_requires_grad_state, continue + if param in param_requires_grad_state: + continue + param_requires_grad_state[param] = param.requires_grad + param.requires_grad = False + + # Then iterate over the current optimizer's parameters and set its `requires_grad` + # properties accordingly + for group in optimizer.param_groups: + for param in group['params']: + param.requires_grad = param_requires_grad_state[param] + self._param_requires_grad_state = param_requires_grad_state + + def untoggle_optimizer(self, optimizer_idx: int): + """ + .. note:: Only called when using multiple optimizers + + Override for your own behavior + + Args: + optimizer_idx: Current optimizer idx in training_loop + """ + for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): + if optimizer_idx != opt_idx: + for group in opt.param_groups: + for param in group['params']: + if param in self._param_requires_grad_state: + param.requires_grad = self._param_requires_grad_state[param] + # save memory + self._param_requires_grad_state = dict() def optimizer_step( - self, - epoch: int, - batch_idx: int, - optimizer: Optimizer, - optimizer_idx: int, - second_order_closure: Optional[Callable] = None, + self, + epoch: int = None, + batch_idx: int = None, + optimizer: Optimizer = None, + optimizer_idx: int = None, + optimizer_closure: Optional[Callable] = None, + on_tpu: bool = None, + using_native_amp: bool = None, + using_lbfgs: bool = None, ) -> None: r""" Override this method to adjust the default way the @@ -1098,88 +1333,73 @@ def optimizer_step( By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example once per optimizer. + Warning: + If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter + to ``optimizer.step()`` function as shown in the examples. This ensures that + ``train_step_and_backward_closure`` is called within + :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`. + Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer optimizer_idx: If you used multiple optimizers this indexes into that list. - second_order_closure: closure for second order methods + optimizer_closure: closure for all optimizers + on_tpu: true if TPU backward is required + using_native_amp: True if using native amp + using_lbfgs: True if the matching optimizer is lbfgs - Examples: - .. code-block:: python - - # DEFAULT - def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, - second_order_closure=None): - optimizer.step() - optimizer.zero_grad() + Examples:: - # Alternating schedule for optimizer steps (i.e.: GANs) - def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, - second_order_closure=None): - # update generator opt every 2 steps - if optimizer_idx == 0: - if batch_idx % 2 == 0 : - optimizer.step() - optimizer.zero_grad() + # DEFAULT + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + optimizer.step(closure=optimizer_closure) - # update discriminator opt every 4 steps - if optimizer_idx == 1: - if batch_idx % 4 == 0 : - optimizer.step() - optimizer.zero_grad() + # Alternating schedule for optimizer steps (i.e.: GANs) + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # update generator opt every 2 steps + if optimizer_idx == 0: + if batch_idx % 2 == 0 : + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() - # ... - # add as many optimizers as you want + # update discriminator opt every 4 steps + if optimizer_idx == 1: + if batch_idx % 4 == 0 : + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + # ... + # add as many optimizers as you want - Here's another example showing how to use this for more advanced things such as - learning rate warm-up: - .. code-block:: python + Here's another example showing how to use this for more advanced things such as + learning rate warm-up: - # learning rate warm-up - def optimizer_step(self, current_epoch, batch_idx, optimizer, - optimizer_idx, second_order_closure=None): - # warm up lr - if self.trainer.global_step < 500: - lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) - for pg in optimizer.param_groups: - pg['lr'] = lr_scale * self.hparams.learning_rate + .. code-block:: python - # update params - optimizer.step() - optimizer.zero_grad() + # learning rate warm-up + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * self.learning_rate - Note: - If you also override the :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad` - model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself. + # update params + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() """ - if self.trainer.use_tpu and XLA_AVAILABLE: - xm.optimizer_step(optimizer) - elif isinstance(optimizer, torch.optim.LBFGS): - - # native amp + lbfgs is a no go right now - if self.trainer.use_amp and self.trainer.use_native_amp: - raise MisconfigurationException( - 'native PyTorch amp and lbfgs are not compatible.' - ' To request, please file a Github issue in PyTorch and tag @mcarilli') - optimizer.step(second_order_closure) - else: - if self.trainer.use_amp and self.trainer.use_native_amp: - self.trainer.scaler.step(optimizer) - else: - optimizer.step() - - # in native 16-bit we need to update scaler after optimizer step - if self.trainer.use_amp and self.trainer.use_native_amp: - self.trainer.scaler.update() + if not isinstance(optimizer, LightningOptimizer): + # wraps into LightingOptimizer only for running step + optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx) + optimizer.step(closure=optimizer_closure) - # model hook - self.on_before_zero_grad(optimizer) - - # clear gradients + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): optimizer.zero_grad() def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: @@ -1197,26 +1417,25 @@ def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: back propagation through time. The default implementation splits root level Tensors and Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length. - Examples: - .. code-block:: python + Examples:: - def tbptt_split_batch(self, batch, split_size): - splits = [] - for t in range(0, time_dims[0], split_size): - batch_split = [] - for i, x in enumerate(batch): - if isinstance(x, torch.Tensor): - split_x = x[:, t:t + split_size] - elif isinstance(x, collections.Sequence): - split_x = [None] * len(x) - for batch_idx in range(len(x)): - split_x[batch_idx] = x[batch_idx][t:t + split_size] + def tbptt_split_batch(self, batch, split_size): + splits = [] + for t in range(0, time_dims[0], split_size): + batch_split = [] + for i, x in enumerate(batch): + if isinstance(x, torch.Tensor): + split_x = x[:, t:t + split_size] + elif isinstance(x, collections.Sequence): + split_x = [None] * len(x) + for batch_idx in range(len(x)): + split_x[batch_idx] = x[batch_idx][t:t + split_size] - batch_split.append(split_x) + batch_split.append(split_x) - splits.append(batch_split) + splits.append(batch_split) - return splits + return splits Note: Called in the training loop after @@ -1246,432 +1465,390 @@ def tbptt_split_batch(self, batch, split_size): return splits - def prepare_data(self) -> None: - """ - Use this to download and prepare data. - In distributed (GPU, TPU), this will only be called once. - This is called before requesting the dataloaders: + def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]: + model_summary = None - .. code-block:: python + if mode in ModelSummary.MODES: + model_summary = ModelSummary(self, mode=mode) + log.info("\n" + str(model_summary)) + elif mode is not None: + raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}") - model.prepare_data() - model.train_dataloader() - model.val_dataloader() - model.test_dataloader() + return model_summary - Examples: - .. code-block:: python - - def prepare_data(self): - download_imagenet() - clean_imagenet() - cache_imagenet() - """ - - def train_dataloader(self) -> DataLoader: - """ - Implement a PyTorch DataLoader for training. + def freeze(self) -> None: + r""" + Freeze all params for inference. - Return: - Single PyTorch :class:`~torch.utils.data.DataLoader`. + Example:: - The dataloader you return will not be called every epoch unless you set - :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. + model = MyLightningModule(...) + model.freeze() - It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. + """ + for param in self.parameters(): + param.requires_grad = False - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... - - :meth:`prepare_data` - - :meth:`train_dataloader` + self.eval() - Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. - There is no need to set it yourself. + def unfreeze(self) -> None: + """ + Unfreeze all parameters for training. - Example: - .. code-block:: python + .. code-block:: python - def train_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, - download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.hparams.batch_size, - shuffle=True - ) - return loader + model = MyLightningModule(...) + model.unfreeze() """ - rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer') + for param in self.parameters(): + param.requires_grad = True - def tng_dataloader(self): # todo: remove in v1.0.0 - """ - Warnings: - Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0. - """ - output = self.train_dataloader() - rank_zero_warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0." - " and this method will be removed in v1.0.0", DeprecationWarning) - return output + self.train() - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" - Implement one or multiple PyTorch DataLoaders for testing. - - The dataloader you return will not be called every epoch unless you set - :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. - - It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... - - :meth:`prepare_data` - - :meth:`train_dataloader` - - :meth:`val_dataloader` - - :meth:`test_dataloader` - - Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. - There is no need to set it yourself. + Implement this to override the default items displayed in the progress bar. + By default it includes the average loss value, split index of BPTT (if used) + and the version of the experiment when using a logger. - Return: - Single or multiple PyTorch DataLoaders. + .. code-block:: - Example: - .. code-block:: python + Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10] - def test_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, - download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.hparams.batch_size, - shuffle=False - ) - - return loader + Here is an example how to override the defaults: - Note: - If you don't need a test dataset and a :meth:`test_step`, you don't need to implement - this method. - - """ - - def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: - r""" - Implement one or multiple PyTorch DataLoaders for validation. - - The dataloader you return will not be called every epoch unless you set - :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. - - It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... - - :meth:`prepare_data` - - :meth:`train_dataloader` - - :meth:`val_dataloader` - - :meth:`test_dataloader` + .. code-block:: python - Note: - Lightning adds the correct sampler for distributed and arbitrary hardware - There is no need to set it yourself. + def get_progress_bar_dict(self): + # don't show the version number + items = super().get_progress_bar_dict() + items.pop("v_num", None) + return items Return: - Single or multiple PyTorch DataLoaders. - - Examples: - .. code-block:: python + Dictionary with the items to be displayed in the progress bar. + """ + # call .item() only once but store elements without graphs + running_train_loss = self.trainer.train_loop.running_loss.mean() + avg_training_loss = None + if running_train_loss is not None: + avg_training_loss = running_train_loss.cpu().item() + elif self.trainer.train_loop.automatic_optimization: + avg_training_loss = float('NaN') - def val_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=False, - transform=transform, download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.hparams.batch_size, - shuffle=False - ) + tqdm_dict = {} + if avg_training_loss is not None: + tqdm_dict["loss"] = f"{avg_training_loss:.3g}" - return loader + if self.trainer.truncated_bptt_steps is not None: + tqdm_dict["split_idx"] = self.trainer.split_idx - # can also return multiple dataloaders - def val_dataloader(self): - return [loader_a, loader_b, ..., loader_n] + if self.trainer.logger is not None and self.trainer.logger.version is not None: + version = self.trainer.logger.version + # show last 4 places of long version strings + version = version[-4:] if isinstance(version, str) else version + tqdm_dict["v_num"] = version - Note: - If you don't need a validation dataset and a :meth:`validation_step`, you don't need to - implement this method. + return tqdm_dict - Note: - In the case where you return multiple validation dataloaders, the :meth:`validation_step` - will have an argument ``dataset_idx`` which matches the order here. - """ + def _verify_is_manual_optimization(self, fn_name): + if self.trainer.train_loop.automatic_optimization: + raise MisconfigurationException( + f'to use {fn_name}, please disable automatic optimization:' + ' set model property `automatic_optimization` as False' + ) @classmethod - def load_from_metrics(cls, weights_path, tags_csv, map_location=None): - r""" - Warning: - Deprecated in version 0.7.0. You should use :meth:`load_from_checkpoint` instead. - Will be removed in v0.9.0. + def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: """ - rank_zero_warn( - "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0." - " The deprecated method will be removed in v0.9.0.", DeprecationWarning - ) - return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location) - - @classmethod - def load_from_checkpoint( - cls, - checkpoint_path: str, - map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, - tags_csv: Optional[str] = None, - *args, **kwargs - ) -> 'LightningModule': - r""" - Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint - it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule` - with an argument called ``hparams`` which is a :class:`~argparse.Namespace` - (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments). - Any other arguments specified through \*args and \*\*kwargs will be passed to the model. - - Example: - .. code-block:: python - - from argparse import Namespace - hparams = Namespace(**{'learning_rate': 0.1}) - - model = MyModel(hparams) - - class MyModel(LightningModule): - def __init__(self, hparams): - self.learning_rate = hparams.learning_rate + Collect all module arguments in the current constructor and all child constructors. + The child constructors are all the ``__init__`` methods that reach the current class through + (chained) ``super().__init__()`` calls. Args: - checkpoint_path: Path to checkpoint. - model_args: Any keyword args needed to init the model. - map_location: - If your checkpoint saved a GPU model and you now load on CPUs - or a different number of GPUs, use this to map to the new setup. - The behaviour is the same as in :func:`torch.load`. - tags_csv: Optional path to a .csv file with two columns (key, value) - as in this example:: - - key,value - drop_prob,0.2 - batch_size,32 - - You most likely won't need this since Lightning will always save the hyperparameters - to the checkpoint. - However, if your checkpoint weights don't have the hyperparameters saved, - use this method to pass in a .csv file with the hparams you'd like to use. - These will be converted into a :class:`~argparse.Namespace` and passed into your - :class:`LightningModule` for use. - - Return: - :class:`LightningModule` with loaded weights and hyperparameters (if available). - - Example: - .. code-block:: python - - # load weights without mapping ... - MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') - - # or load weights mapping all weights from GPU 1 to GPU 0 ... - map_location = {'cuda:1':'cuda:0'} - MyLightningModule.load_from_checkpoint( - 'path/to/checkpoint.ckpt', - map_location=map_location - ) - - # or load weights and hyperparameters from separate files. - MyLightningModule.load_from_checkpoint( - 'path/to/checkpoint.ckpt', - tags_csv='/path/to/hparams_file.csv' - ) + frame: instance frame - # or load passing whatever args the model takes to load - MyLightningModule.load_from_checkpoint( - 'path/to/checkpoint.ckpt', - learning_rate=0.1, # These arguments will be passed to the model using **kwargs - layers=2, - pretrained_model=some_model - ) + Returns: + self_arguments: arguments dictionary of the first instance + parents_arguments: arguments dictionary of the parent's instances + """ + if not frame: + frame = inspect.currentframe() + + frame_args = collect_init_args(frame.f_back, []) + self_arguments = frame_args[-1] + + # set hyper_parameters in child + self_arguments = self_arguments + parents_arguments = {} + + # add all arguments from parents + for args in frame_args[:-1]: + parents_arguments.update(args) + return self_arguments, parents_arguments + + def save_hyperparameters( + self, + *args, + ignore: Optional[Union[Sequence[str], str]] = None, + frame: Optional[types.FrameType] = None + ) -> None: + """Save model arguments to ``hparams`` attribute. - # predict - pretrained_model.eval() - pretrained_model.freeze() - y_hat = pretrained_model(x) + Args: + args: single object of `dict`, `NameSpace` or `OmegaConf` + or string names or arguments from class ``__init__`` + ignore: an argument name or a list of argument names from + class ``__init__`` to be ignored + frame: a frame object. Default is None + + Example:: + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # manually assign arguments + ... self.save_hyperparameters('arg1', 'arg3') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 + + >>> class AutomaticArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # equivalent automatic + ... self.save_hyperparameters() + ... def forward(self, *args, **kwargs): + ... ... + >>> model = AutomaticArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg2": abc + "arg3": 3.14 + + >>> class SingleArgModel(LightningModule): + ... def __init__(self, params): + ... super().__init__() + ... # manually assign single argument + ... self.save_hyperparameters(params) + ... def forward(self, *args, **kwargs): + ... ... + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) + >>> model.hparams + "p1": 1 + "p2": abc + "p3": 3.14 + + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # pass argument(s) to ignore as a string or in a list + ... self.save_hyperparameters(ignore='arg2') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 """ - if map_location is not None: - checkpoint = torch.load(checkpoint_path, map_location=map_location) + if not frame: + frame = inspect.currentframe().f_back + init_args = get_init_args(frame) + assert init_args, "failed to inspect the self init" + + if ignore is not None: + if isinstance(ignore, str): + ignore = [ignore] + if isinstance(ignore, (list, tuple)): + ignore = [arg for arg in ignore if isinstance(arg, str)] + init_args = {k: v for k, v in init_args.items() if k not in ignore} + + if not args: + # take all arguments + hp = init_args + self._hparams_name = "kwargs" if hp else None else: - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - - if tags_csv is not None: - # add the hparams from csv file to checkpoint - hparams = load_hparams_from_tags_csv(tags_csv) - hparams.__setattr__('on_gpu', False) - checkpoint['hparams'] = vars(hparams) - - model = cls._load_model_state(checkpoint, *args, **kwargs) - return model - - @classmethod - def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule': - cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters - ckpt_hparams = checkpoint.get('hparams') - - if cls_takes_hparams: - if ckpt_hparams is not None: - is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace' - hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams + # take only listed arguments in `save_hparams` + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] + if len(isx_non_str) == 1: + hp = args[isx_non_str[0]] + cand_names = [k for k, v in init_args.items() if v == hp] + self._hparams_name = cand_names[0] if cand_names else None else: - rank_zero_warn( - f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__" - " contains argument 'hparams'. Will pass in an empty Namespace instead." - " Did you forget to store your model hyperparameters in self.hparams?" - ) - hparams = Namespace() - else: # The user's LightningModule does not define a hparams argument - if ckpt_hparams is None: - hparams = None - else: - raise MisconfigurationException( - f"Checkpoint contains hyperparameters but {cls.__name__}'s __init__ " - f"is missing the argument 'hparams'. Are you loading the correct checkpoint?" - ) - - # load the state_dict on the model automatically - if hparams: - kwargs.update(hparams=hparams) - model = cls(*args, **kwargs) - model.load_state_dict(checkpoint['state_dict']) - - # give model a chance to load something - model.on_load_checkpoint(checkpoint) - - return model - - def summarize(self, mode: str) -> None: - model_summary = ModelSummary(self, mode=mode) - log.info('\n' + model_summary.__str__()) + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} + self._hparams_name = "kwargs" + + # `hparams` are expected here + if hp: + self._set_hparams(hp) + # make deep copy so there is not other runtime changes reflected + self._hparams_initial = copy.deepcopy(self._hparams) + + def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: + if isinstance(hp, Namespace): + hp = vars(hp) + if isinstance(hp, dict): + hp = AttributeDict(hp) + elif isinstance(hp, PRIMITIVE_TYPES): + raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.") + elif not isinstance(hp, ALLOWED_CONFIG_TYPES): + raise ValueError(f"Unsupported config type of {type(hp)}.") + + if isinstance(hp, dict) and isinstance(self.hparams, dict): + self.hparams.update(hp) + else: + self._hparams = hp + + @torch.no_grad() + def to_onnx( + self, + file_path: Union[str, Path], + input_sample: Optional[Any] = None, + **kwargs, + ): + """ + Saves the model in ONNX format - def freeze(self) -> None: - r""" - Freeze all params for inference. + Args: + file_path: The path of the file the onnx model should be saved to. + input_sample: An input for tracing. Default: None (Use self.example_input_array) + **kwargs: Will be passed to torch.onnx.export function. Example: - .. code-block:: python - - model = MyLightningModule(...) - model.freeze() - + >>> class SimpleModel(LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.l1 = torch.nn.Linear(in_features=64, out_features=4) + ... + ... def forward(self, x): + ... return torch.relu(self.l1(x.view(x.size(0), -1))) + + >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile: + ... model = SimpleModel() + ... input_sample = torch.randn((1, 64)) + ... model.to_onnx(tmpfile.name, input_sample, export_params=True) + ... os.path.isfile(tmpfile.name) + True """ - for param in self.parameters(): - param.requires_grad = False + mode = self.training - self.eval() + if input_sample is None: + if self.example_input_array is None: + raise ValueError( + "Could not export to ONNX since neither `input_sample` nor" + " `model.example_input_array` attribute is set." + ) + input_sample = self.example_input_array - def unfreeze(self) -> None: - """ - Unfreeze all parameters for training. + input_sample = self._apply_batch_transfer_handler(input_sample) - .. code-block:: python + if "example_outputs" not in kwargs: + self.eval() + kwargs["example_outputs"] = self(input_sample) - model = MyLightningModule(...) - model.unfreeze() + torch.onnx.export(self, input_sample, file_path, **kwargs) + self.train(mode) + @torch.no_grad() + def to_torchscript( + self, + file_path: Optional[Union[str, Path]] = None, + method: Optional[str] = 'script', + example_inputs: Optional[Any] = None, + **kwargs, + ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ - for param in self.parameters(): - param.requires_grad = True - - self.train() - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - r""" - Called by Lightning to restore your model. - If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. + By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. + If you want to use tracing, please provided the argument `method='trace'` and make sure that either the + example_inputs argument is provided, or the model has self.example_input_array set. + If you would like to customize the modules that are scripted you should override this method. + In case you want to return multiple modules, we recommend using a dictionary. Args: - checkpoint: Loaded checkpoint - - - Example: - .. code-block:: python - - def on_load_checkpoint(self, checkpoint): - # 99% of the time you don't need to implement this method - self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save'] + file_path: Path where to save the torchscript. Default: None (no file saved). + method: Whether to use TorchScript's script or trace method. Default: 'script' + example_inputs: An input to be used to do tracing when method is set to 'trace'. + Default: None (Use self.example_input_array) + **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or + :func:`torch.jit.trace` function. Note: - Lightning auto-restores global step, epoch, and train state including amp scaling. - There is no need for you to restore anything regarding training. - """ - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - r""" - Called by Lightning when saving a checkpoint to give you a chance to store anything - else you might want to save. - - Args: - checkpoint: Checkpoint to be saved + - Requires the implementation of the + :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. + - The exported script will be set to evaluation mode. + - It is recommended that you install the latest supported version of PyTorch + to use this feature without limitations. See also the :mod:`torch.jit` + documentation for supported features. Example: - .. code-block:: python - - - def on_save_checkpoint(self, checkpoint): - # 99% of use cases you don't need to implement this method - checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object - - Note: - Lightning saves all aspects of training (epoch, global step, etc...) - including amp scaling. - There is no need for you to store anything about training. - - """ - - def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: - r""" - Additional items to be displayed in the progress bar. + >>> class SimpleModel(LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.l1 = torch.nn.Linear(in_features=64, out_features=4) + ... + ... def forward(self, x): + ... return torch.relu(self.l1(x.view(x.size(0), -1))) + ... + >>> model = SimpleModel() + >>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP + >>> os.path.isfile("model.pt") # doctest: +SKIP + >>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP + ... example_inputs=torch.randn(1, 64))) # doctest: +SKIP + >>> os.path.isfile("model_trace.pt") # doctest: +SKIP + True Return: - Dictionary with the items to be displayed in the progress bar. + This LightningModule as a torchscript, regardless of whether file_path is + defined or not. """ - # call .item() only once but store elements without graphs - running_train_loss = self.trainer.running_loss.mean() - avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN') - tqdm_dict = { - 'loss': '{:.3f}'.format(avg_training_loss) - } - - if self.trainer.truncated_bptt_steps is not None: - tqdm_dict['split_idx'] = self.trainer.split_idx - - if self.trainer.logger is not None and self.trainer.logger.version is not None: - tqdm_dict['v_num'] = self.trainer.logger.version - - return tqdm_dict - - def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: - """ - Additional items to be displayed in the progress bar. - - Return: - Dictionary with the items to be displayed in the progress bar. + mode = self.training + + if method == 'script': + torchscript_module = torch.jit.script(self.eval(), **kwargs) + elif method == 'trace': + # if no example inputs are provided, try to see if model has example_input_array set + if example_inputs is None: + if self.example_input_array is None: + raise ValueError( + 'Choosing method=`trace` requires either `example_inputs`' + ' or `model.example_input_array` to be defined.' + ) + example_inputs = self.example_input_array - Warning: - Deprecated since v0.7.3. - Use :meth:`get_progress_bar_dict` instead. - """ - rank_zero_warn("`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" - " and this method will be removed in v1.0.0", DeprecationWarning) - return self.get_progress_bar_dict() + # automatically send example inputs to the right device and use trace + example_inputs = self._apply_batch_transfer_handler(example_inputs) + torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) + else: + raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}") + + self.train(mode) + + if file_path is not None: + torch.jit.save(torchscript_module, file_path) + + return torchscript_module + + @property + def hparams(self) -> Union[AttributeDict, dict, Namespace]: + if not hasattr(self, "_hparams"): + self._hparams = AttributeDict() + return self._hparams + + @property + def hparams_initial(self) -> AttributeDict: + if not hasattr(self, "_hparams_initial"): + return AttributeDict() + # prevent any change + return copy.deepcopy(self._hparams_initial) + + @property + def model_size(self) -> float: + # todo: think about better way without need to dump model to drive + tmp_name = f"{uuid.uuid4().hex}.pt" + torch.save(self.state_dict(), tmp_name) + size_mb = os.path.getsize(tmp_name) / 1e6 + os.remove(tmp_name) + return size_mb diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 8cbd281d6b0241..a3eab728f8ea80 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -1,167 +1,309 @@ -""" -Generates a summary of a model's layers and dimensionality -""" +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import gc import os +import shutil import subprocess -from subprocess import PIPE -from typing import Tuple, Dict, Union, List +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from torch.nn import Module +import torch.nn as nn +from torch.utils.hooks import RemovableHandle -import pytorch_lightning as pl +from pytorch_lightning.utilities import AMPType, DeviceType -from pytorch_lightning import _logger as log +PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] +UNKNOWN_SIZE = "?" -class ModelSummary(object): +class LayerSummary(object): + """ + Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + It collects the following information: + + - Type of the layer (e.g. Linear, BatchNorm1d, ...) + - Input shape + - Output shape + - Number of parameters + + The input and output shapes are only known after the example input array was + passed through the model. + + Example:: + + >>> model = torch.nn.Conv2d(3, 8, 3) + >>> summary = LayerSummary(model) + >>> summary.num_parameters + 224 + >>> summary.layer_type + 'Conv2d' + >>> output = model(torch.rand(1, 3, 5, 5)) + >>> summary.in_size + [1, 3, 5, 5] + >>> summary.out_size + [1, 8, 3, 3] + + Args: + module: A module to summarize - def __init__(self, model: 'pl.LightningModule', mode: str = 'full'): - """ Generates summaries of model layers and dimensions. """ - self.model = model - self.mode = mode - self.in_sizes = [] - self.out_sizes = [] + """ - self.summarize() + def __init__(self, module: nn.Module): + super().__init__() + self._module = module + self._hook_handle = self._register_hook() + self._in_size = None + self._out_size = None - def __str__(self): - return self.summary.__str__() + def __del__(self): + self.detach_hook() - def __repr__(self): - return self.summary.__str__() + def _register_hook(self) -> Optional[RemovableHandle]: + """ + Registers a hook on the module that computes the input- and output size(s) on the first forward pass. + If the hook is called, it will remove itself from the from the module, meaning that + recursive models will only record their input- and output shapes once. + Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. + + Return: + A handle for the installed hook, or ``None`` if registering the hook is not possible. + """ + + def hook(module, inp, out): + if len(inp) == 1: + inp = inp[0] + self._in_size = parse_batch_shape(inp) + self._out_size = parse_batch_shape(out) + self._hook_handle.remove() + + handle = None + if not isinstance(self._module, torch.jit.ScriptModule): + handle = self._module.register_forward_hook(hook) + return handle + + def detach_hook(self): + """ + Removes the forward hook if it was not already removed in the forward pass. + Will be called after the summary is created. + """ + if self._hook_handle is not None: + self._hook_handle.remove() - def named_modules(self) -> List[Tuple[str, Module]]: - if self.mode == 'full': - mods = self.model.named_modules() + @property + def in_size(self) -> Union[str, List]: + return self._in_size or UNKNOWN_SIZE + + @property + def out_size(self) -> Union[str, List]: + return self._out_size or UNKNOWN_SIZE + + @property + def layer_type(self) -> str: + """ Returns the class name of the module. """ + return str(self._module.__class__.__name__) + + @property + def num_parameters(self) -> int: + """ Returns the number of parameters in this module. """ + return sum(np.prod(p.shape) for p in self._module.parameters()) + + +class ModelSummary(object): + """ + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + + Args: + model: The model to summarize (also referred to as the root module) + mode: Can be one of + + - `top` (default): only the top-level modules will be recorded (the children of the root module) + - `full`: summarizes all layers and their submodules in the root module + + The string representation of this summary prints a table with columns containing + the name, type and number of parameters for each layer. + + The root module may also have an attribute ``example_input_array`` as shown in the example below. + If present, the root module will be called with it as input to determine the + intermediate input- and output shapes of all layers. Supported are tensors and + nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?` + in the summary table. The summary will also display `?` for layers not used in the forward pass. + + Example:: + + >>> import pytorch_lightning as pl + >>> class LitModel(pl.LightningModule): + ... + ... def __init__(self): + ... super().__init__() + ... self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512)) + ... self.example_input_array = torch.zeros(10, 256) # optional + ... + ... def forward(self, x): + ... return self.net(x) + ... + >>> model = LitModel() + >>> ModelSummary(model, mode='top') # doctest: +NORMALIZE_WHITESPACE + | Name | Type | Params | In sizes | Out sizes + ------------------------------------------------------------ + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + ------------------------------------------------------------ + 132 K Trainable params + 0 Non-trainable params + 132 K Total params + 0.530 Total estimated model params size (MB) + >>> ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE + | Name | Type | Params | In sizes | Out sizes + -------------------------------------------------------------- + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] + 2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512] + -------------------------------------------------------------- + 132 K Trainable params + 0 Non-trainable params + 132 K Total params + 0.530 Total estimated model params size (MB) + """ + + MODE_TOP = "top" + MODE_FULL = "full" + MODE_DEFAULT = MODE_TOP + MODES = [MODE_FULL, MODE_TOP] + + def __init__(self, model, mode: str = MODE_DEFAULT): + self._model = model + self._mode = mode + self._layer_summary = self.summarize() + # 1 byte -> 8 bits + # TODO: how do we compute precisin_megabytes in case of mixed precision? + precision = self._model.precision if isinstance(self._model.precision, int) else 32 + self._precision_megabytes = (precision / 8.0) * 1e-6 + + @property + def named_modules(self) -> List[Tuple[str, nn.Module]]: + if self._mode == ModelSummary.MODE_FULL: + mods = self._model.named_modules() mods = list(mods)[1:] # do not include root module (LightningModule) - elif self.mode == 'top': + elif self._mode == ModelSummary.MODE_TOP: # the children are the top-level modules - mods = self.model.named_children() + mods = self._model.named_children() else: mods = [] return list(mods) - def get_variable_sizes(self) -> None: - """ Run sample input through each layer to get output sizes. """ - mods = self.named_modules() - in_sizes = [] - out_sizes = [] - input_ = self.model.example_input_array + @property + def layer_names(self) -> List[str]: + return list(self._layer_summary.keys()) - if self.model.on_gpu: - device = next(self.model.parameters()).get_device() - # test if input is a list or a tuple - if isinstance(input_, (list, tuple)): - input_ = [input_i.cuda(device) if torch.is_tensor(input_i) else input_i - for input_i in input_] - else: - input_ = input_.cuda(device) + @property + def layer_types(self) -> List[str]: + return [layer.layer_type for layer in self._layer_summary.values()] - if self.model.trainer.use_amp: - # test if it is not a list or a tuple - if isinstance(input_, (list, tuple)): - input_ = [input_i.half() if torch.is_tensor(input_i) else input_i - for input_i in input_] - else: - input_ = input_.half() + @property + def in_sizes(self) -> List: + return [layer.in_size for layer in self._layer_summary.values()] + + @property + def out_sizes(self) -> List: + return [layer.out_size for layer in self._layer_summary.values()] + + @property + def param_nums(self) -> List[int]: + return [layer.num_parameters for layer in self._layer_summary.values()] + + @property + def total_parameters(self) -> int: + return sum(p.numel() for p in self._model.parameters()) + + @property + def trainable_parameters(self) -> int: + return sum(p.numel() for p in self._model.parameters() if p.requires_grad) + + @property + def model_size(self) -> float: + # todo: seems it does not work with quantized models - it returns 0.0 + return self.total_parameters * self._precision_megabytes + + def summarize(self) -> Dict[str, LayerSummary]: + summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) + if self._model.example_input_array is not None: + self._forward_example_input() + for layer in summary.values(): + layer.detach_hook() + return summary + + def _forward_example_input(self) -> None: + """ Run the example input through each layer to get input- and output sizes. """ + model = self._model + trainer = self._model.trainer + input_ = model.example_input_array + input_ = model._apply_batch_transfer_handler(input_, model.device) + + if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: + model.forward = torch.cuda.amp.autocast()(model.forward) + + mode = model.training + model.eval() with torch.no_grad(): + # let the model hooks collect the input- and output shapes + if isinstance(input_, (list, tuple)): + model(*input_) + elif isinstance(input_, dict): + model(**input_) + else: + model(input_) + model.train(mode) # restore mode of module - for _, m in mods: - if isinstance(input_, (list, tuple)): # pragma: no-cover - out = m(*input_) - else: - out = m(input_) - - if isinstance(input_, (list, tuple)): # pragma: no-cover - in_size = [] - for x in input_: - if isinstance(x, list): - in_size.append(len(x)) - else: - in_size.append(x.size()) - else: - in_size = np.array(input_.size()) - - in_sizes.append(in_size) - - if isinstance(out, (list, tuple)): # pragma: no-cover - out_size = np.asarray([x.size() for x in out]) - else: - out_size = np.array(out.size()) - - out_sizes.append(out_size) - input_ = out - - self.in_sizes = in_sizes - self.out_sizes = out_sizes - assert len(in_sizes) == len(out_sizes) - - def get_layer_names(self) -> None: - """ Collect Layer Names """ - mods = self.named_modules() - names = [] - layers = [] - for name, m in mods: - names += [name] - layers += [str(m.__class__)] - - layer_types = [x.split('.')[-1][:-2] for x in layers] - - self.layer_names = names - self.layer_types = layer_types - - def get_parameter_sizes(self) -> None: - """ Get sizes of all parameters in `model`. """ - mods = self.named_modules() - sizes = [] - for _, m in mods: - p = list(m.parameters()) - modsz = [np.array(param.size()) for param in p] - sizes.append(modsz) - - self.param_sizes = sizes - - def get_parameter_nums(self) -> None: - """ Get number of parameters in each layer. """ - param_nums = [] - for mod in self.param_sizes: - all_params = 0 - for p in mod: - all_params += np.prod(p) - param_nums.append(all_params) - self.param_nums = param_nums - - def make_summary(self) -> None: + def __str__(self): """ Makes a summary listing with: - Layer Name, Layer Type, Input Size, Output Size, Number of Parameters + Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size """ - arrays = [['Name', self.layer_names], - ['Type', self.layer_types], - ['Params', list(map(get_human_readable_count, self.param_nums))]] - if self.model.example_input_array is not None: - arrays.append(['In sizes', self.in_sizes]) - arrays.append(['Out sizes', self.out_sizes]) + arrays = [ + [" ", list(map(str, range(len(self._layer_summary))))], + ["Name", self.layer_names], + ["Type", self.layer_types], + ["Params", list(map(get_human_readable_count, self.param_nums))], + ] + if self._model.example_input_array is not None: + arrays.append(["In sizes", self.in_sizes]) + arrays.append(["Out sizes", self.out_sizes]) + total_parameters = self.total_parameters + trainable_parameters = self.trainable_parameters + model_size = self.model_size + + return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays) + + def __repr__(self): + return str(self) - self.summary = _format_summary_table(*arrays) - def summarize(self) -> None: - self.get_layer_names() - self.get_parameter_sizes() - self.get_parameter_nums() +def parse_batch_shape(batch: Any) -> Union[str, List]: + if hasattr(batch, "shape"): + return list(batch.shape) - if self.model.example_input_array is not None: - self.get_variable_sizes() - self.make_summary() + if isinstance(batch, (list, tuple)): + shape = [parse_batch_shape(el) for el in batch] + return shape + return UNKNOWN_SIZE -def _format_summary_table(*cols) -> str: + +def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str: """ Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big @@ -170,68 +312,39 @@ def _format_summary_table(*cols) -> str: n_rows = len(cols[0][1]) n_cols = 1 + len(cols) - # Layer counter - counter = list(map(str, list(range(n_rows)))) - counter_len = max([len(c) for c in counter]) - - # Get formatting length of each column - length = [] + # Get formatting width of each column + col_widths = [] for c in cols: - str_l = len(c[0]) # default length is header length - for a in c[1]: - if isinstance(a, np.ndarray): - array_string = '[' + ', '.join([str(j) for j in a]) + ']' - str_l = max(len(array_string), str_l) - else: - str_l = max(len(a), str_l) - length.append(str_l) + col_width = max(len(str(a)) for a in c[1]) if n_rows else 0 + col_width = max(col_width, len(c[0])) # minimum length is header length + col_widths.append(col_width) # Formatting - s = '{:<{}}' - full_length = sum(length) + 3 * n_cols - header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)] + s = "{:<{}}" + total_width = sum(col_widths) + 3 * n_cols + header = [s.format(c[0], l) for c, l in zip(cols, col_widths)] # Summary = header + divider + Rest of table - summary = ' | '.join(header) + '\n' + '-' * full_length + summary = " | ".join(header) + "\n" + "-" * total_width for i in range(n_rows): - line = s.format(counter[i], counter_len) - for c, l in zip(cols, length): - if isinstance(c[1][i], np.ndarray): - array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']' - line += ' | ' + array_string + ' ' * (l - len(array_string)) - else: - line += ' | ' + s.format(c[1][i], l) - summary += '\n' + line + line = [] + for c, l in zip(cols, col_widths): + line.append(s.format(str(c[1][i]), l)) + summary += "\n" + " | ".join(line) + summary += "\n" + "-" * total_width + + summary += "\n" + s.format(get_human_readable_count(trainable_parameters), 10) + summary += "Trainable params" + summary += "\n" + s.format(get_human_readable_count(total_parameters - trainable_parameters), 10) + summary += "Non-trainable params" + summary += "\n" + s.format(get_human_readable_count(total_parameters), 10) + summary += "Total params" + summary += "\n" + s.format(get_formatted_model_size(model_size), 10) + summary += "Total estimated model params size (MB)" return summary -def print_mem_stack() -> None: # pragma: no-cover - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): - log.info(type(obj), obj.size()) - except Exception: - pass - - -def count_mem_items() -> Tuple[int, int]: # pragma: no-cover - num_params = 0 - num_tensors = 0 - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): - obj_type = str(type(obj)) - if 'parameter' in obj_type: - num_params += 1 - else: - num_tensors += 1 - except Exception: - pass - - return num_params, num_tensors - - def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: """ Get a profile of the current memory usage. @@ -251,38 +364,42 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: """ memory_map = get_gpu_memory_map() - if mode == 'min_max': + if mode == "min_max": min_index, min_memory = min(memory_map.items(), key=lambda item: item[1]) max_index, max_memory = max(memory_map.items(), key=lambda item: item[1]) - memory_map = {'min_gpu_mem': min_memory, 'max_gpu_mem': max_memory} + memory_map = {"min_gpu_mem": min_memory, "max_gpu_mem": max_memory} return memory_map def get_gpu_memory_map() -> Dict[str, int]: - """Get the current gpu usage. + """ + Get the current gpu usage. Return: A dictionary in which the keys are device ids as integers and values are memory usage as integers in MB. """ result = subprocess.run( - [ - 'nvidia-smi', - '--query-gpu=memory.used', - '--format=csv,nounits,noheader', - ], - encoding='utf-8', + [shutil.which("nvidia-smi"), "--query-gpu=memory.used", "--format=csv,nounits,noheader"], + encoding="utf-8", # capture_output=True, # valid for python version >=3.7 - stdout=PIPE, stderr=PIPE, # for backward compatibility with python version 3.6 - check=True) + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + check=True, + ) + # Convert lines into a dictionary - gpu_memory = [int(x) for x in result.stdout.strip().split(os.linesep)] - gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)} + gpu_memory = [float(x) for x in result.stdout.strip().split(os.linesep)] + gpu_memory_map = {f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)} return gpu_memory_map +def get_formatted_model_size(total_model_size: float) -> float: + return f"{total_model_size:,.3f}" + + def get_human_readable_count(number: int) -> str: """ Abbreviates an integer number with K, M, B, T for thousands, millions, @@ -292,13 +409,13 @@ def get_human_readable_count(number: int) -> str: >>> get_human_readable_count(123) '123 ' >>> get_human_readable_count(1234) # (one thousand) - '1 K' + '1.2 K' >>> get_human_readable_count(2e6) # (two million) - '2 M' + '2.0 M' >>> get_human_readable_count(3e9) # (three billion) - '3 B' - >>> get_human_readable_count(4e12) # (four trillion) - '4 T' + '3.0 B' + >>> get_human_readable_count(4e14) # (four hundred trillion) + '400 T' >>> get_human_readable_count(5e15) # (more than trillion) '5,000 T' @@ -310,11 +427,14 @@ def get_human_readable_count(number: int) -> str: """ assert number >= 0 - labels = [' ', 'K', 'M', 'B', 'T'] + labels = PARAMETER_NUM_UNITS num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) num_groups = int(np.ceil(num_digits / 3)) num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions shift = -3 * (num_groups - 1) - number = number * (10 ** shift) + number = number * (10**shift) index = num_groups - 1 - return f'{int(number):,d} {labels[index]}' + if index < 1 or number >= 100: + return f"{int(number):,d} {labels[index]}" + + return f"{number:,.1f} {labels[index]}" diff --git a/pytorch_lightning/core/model_saving.py b/pytorch_lightning/core/model_saving.py deleted file mode 100644 index 8c3630237c1fc0..00000000000000 --- a/pytorch_lightning/core/model_saving.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `model_saving` module has been renamed to `saving` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`model_saving` module has been renamed to `saving` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.saving import * # noqa: F403 diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py new file mode 100644 index 00000000000000..162e17ca47bf5e --- /dev/null +++ b/pytorch_lightning/core/optimizer.py @@ -0,0 +1,220 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import types +from contextlib import contextmanager +from typing import Callable, Optional +from weakref import proxy + +from torch.optim import Optimizer + +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def is_lightning_optimizer(optimizer): + return isinstance(optimizer, LightningOptimizer) + + +def do_nothing_closure(): + return + + +class LightningOptimizer: + """ + This class is used to wrap the user optimizers and handle properly + the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches + """ + + def __init__(self, optimizer: Optimizer): + + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")} + + # For Horovod + if hasattr(optimizer, "skip_synchronize"): + self.__class__ = type( + "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {} + ) + self.skip_synchronize = optimizer.skip_synchronize + self.synchronize = optimizer.synchronize + else: + self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) + + self._optimizer = optimizer + self._trainer = None + self._optimizer_idx = None + self._total_optimizer_step_calls = 0 + + @property + def optimizer(self): + return self._optimizer + + @property + def defaults(self): + return self._optimizer.defaults + + @defaults.setter + def defaults(self, defaults): + self._optimizer.defaults = defaults + + @property + def state(self): + return self._optimizer.state + + @state.setter + def state(self, state): + self._optimizer.state = state + + @property + def param_groups(self): + return self._optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._optimizer.param_groups = param_groups + + def _on_trainer_init(self, trainer): + self._trainer = proxy(trainer) + for opt_idx, opt in enumerate(trainer.optimizers): + if opt == self._optimizer: + self._optimizer_idx = opt_idx + break + + @classmethod + def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): + # apex overrides .step function and need to be wrapped on each step + if trainer.amp_backend == AMPType.APEX: + optimizer = cls(optimizer) + optimizer._on_trainer_init(trainer) + else: + optimizer = trainer.lightning_optimizers[opt_idx] + return optimizer + + def _toggle_model(self): + model_ref = self._trainer.lightning_module + model_ref.toggle_optimizer(self, self._optimizer_idx) + + def _untoggle_model(self): + model_ref = self._trainer.lightning_module + model_ref.untoggle_optimizer(self) + + @contextmanager + def toggle_model(self, sync_grad: bool = True): + """ + This function is just a helper for advanced users. + + Considering the current optimizer as A and all other optimizers as B. + Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False. + + + When performing gradient accumulation, there is no need to perform grad synchronization + during the accumulation phase. + Setting `sync_grad` to False will block this synchronization and improve performance. + """ + with self._trainer.train_loop.block_ddp_sync_behaviour(not sync_grad): + self._toggle_model() + yield + self._untoggle_model() + + def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs): + trainer = self._trainer + optimizer = self._optimizer + + with trainer.profiler.profile(profiler_name): + trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) + + def step(self, *args, closure: Optional[Callable] = None, **kwargs): + """ + Call this directly from your training_step when doing optimizations manually. + By using this we can ensure that all the proper scaling when using 16-bit, accelerator etc + is been done properly for you. + + .. note:: In Manual Optimization, the user is expected to know when to call zero_grad, + perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators + + Args: + + closure: One could provide its own optimizer_closure. Set to None by default. + + args: Any parameters provided to wrapped optimizer.step() + + kwargs: Any parameters provided to wrapped optimizer.step() + + Example:: + + # Scenario for a GAN. + + def training_step(...): + opt_gen, opt_dis = self.optimizers() + + ... + + # compute generator loss + loss_gen = self.compute_generator_loss(...) + # zero_grad needs to be called before backward + opt_gen.zero_grad() + self.manual_backward(loss_gen) + opt_gen.step() + + # compute discriminator loss + loss_dis = self.compute_discriminator_loss(...) + + # zero_grad needs to be called before backward + opt_dis.zero_grad() + self.manual_backward(loss_dis) + opt_dis.step() + + + # Scenario for a GAN advanced + + def training_step(self, batch, batch_idx, ...): + opt_gen, opt_dis = self.optimizers() + + ... + accumulated_grad_batches = batch_idx % 2 == 0 + + # compute generator loss + def closure_gen(): + loss_gen = self.compute_generator_loss(...) + self.manual_backward(loss_gen) + if accumulated_grad_batches: + opt_gen.zero_grad() + + with opt_gen.toggle_model(sync_grad=accumulated_grad_batches): + opt_gen.step(closure=closure_gen) + + def closure_dis(): + loss_dis = self.compute_discriminator_loss(...) + self.manual_backward(loss_dis) + if accumulated_grad_batches: + opt_dis.zero_grad() + + with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): + opt_dis.step(closure=closure_dis) + + """ + if closure is None: + profiler_name = "closure_{self._optimizer_idx}" + closure = do_nothing_closure + else: + if not isinstance(closure, types.FunctionType): + raise MisconfigurationException("When closure is provided, it should be a function") + profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" + + self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs) + self._total_optimizer_step_calls += 1 + + def __repr__(self): + groups = [{k: round(v, 12) if isinstance(v, float) else v + for k, v in sorted(group.items()) if k != "params"} for group in self.param_groups] + return f"{self.__class__.__name__}(groups={groups})" diff --git a/pytorch_lightning/core/root_module.py b/pytorch_lightning/core/root_module.py deleted file mode 100644 index b8e602da68a20f..00000000000000 --- a/pytorch_lightning/core/root_module.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module` module has been renamed to `lightning` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module` module has been renamed to `lightning` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.lightning import * # noqa: F403 diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index d5e8d2a6005277..280eca55260a7d 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -1,12 +1,210 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast import csv +import inspect +import logging import os from argparse import Namespace -from typing import Union, Dict, Any +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union +from warnings import warn + +import torch +import yaml + +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.parsing import parse_class_init_keys + +log = logging.getLogger(__name__) +PRIMITIVE_TYPES = (bool, int, float, str) +ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) + +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig + from omegaconf.errors import UnsupportedValueType, ValidationError -from pytorch_lightning import _logger as log +# the older shall be on the top +CHECKPOINT_PAST_HPARAMS_KEYS = ( + 'hparams', + 'module_arguments', # used in 0.7.6 +) class ModelIO(object): + CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters' + CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name' + CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type' + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: Union[str, IO], + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + r""" + Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint + it stores the arguments passed to `__init__` in the checkpoint under `hyper_parameters` + + Any arguments specified through \*args and \*\*kwargs will override args stored in `hyper_parameters`. + + Args: + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object + map_location: + If your checkpoint saved a GPU model and you now load on CPUs + or a different number of GPUs, use this to map to the new setup. + The behaviour is the same as in :func:`torch.load`. + hparams_file: Optional path to a .yaml file with hierarchical structure + as in this example:: + + drop_prob: 0.2 + dataloader: + batch_size: 32 + + You most likely won't need this since Lightning will always save the hyperparameters + to the checkpoint. + However, if your checkpoint weights don't have the hyperparameters saved, + use this method to pass in a .yaml file with the hparams you'd like to use. + These will be converted into a :class:`~dict` and passed into your + :class:`LightningModule` for use. + + If your model's `hparams` argument is :class:`~argparse.Namespace` + and .yaml file has hierarchical structure, you need to refactor your model to treat + `hparams` as :class:`~dict`. + strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys + returned by this module's state dict. Default: `True`. + kwargs: Any extra keyword args needed to init the model. Can also be used to override saved + hyperparameter values. + + Return: + :class:`LightningModule` with loaded weights and hyperparameters (if available). + + Example:: + + # load weights without mapping ... + MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') + + # or load weights mapping all weights from GPU 1 to GPU 0 ... + map_location = {'cuda:1':'cuda:0'} + MyLightningModule.load_from_checkpoint( + 'path/to/checkpoint.ckpt', + map_location=map_location + ) + + # or load weights and hyperparameters from separate files. + MyLightningModule.load_from_checkpoint( + 'path/to/checkpoint.ckpt', + hparams_file='/path/to/hparams_file.yaml' + ) + + # override some of the params with new values + MyLightningModule.load_from_checkpoint( + PATH, + num_layers=128, + pretrained_ckpt_path: NEW_PATH, + ) + + # predict + pretrained_model.eval() + pretrained_model.freeze() + y_hat = pretrained_model(x) + """ + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + if hparams_file is not None: + extension = hparams_file.split('.')[-1] + if extension.lower() in ('csv'): + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ('yml', 'yaml'): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError('.csv, .yml or .yaml is required for `hparams_file`') + + hparams['on_gpu'] = False + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + # override the hparams with values that were passed in + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs) + + model = cls._load_model_state(checkpoint, strict=strict, **kwargs) + return model + + @classmethod + def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new): + cls_spec = inspect.getfullargspec(cls.__init__) + cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() + + self_var, args_var, kwargs_var = parse_class_init_keys(cls) + drop_names = [n for n in (self_var, args_var, kwargs_var) if n] + cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name)) + + cls_kwargs_loaded = {} + # pass in the values we saved automatically + if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + + # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys + for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: + cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) + + # 2. Try to restore model hparams from checkpoint using the new key + _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY + cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key)) + + # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace + cls_kwargs_loaded = _convert_loaded_hparams( + cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE) + ) + + # 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority + args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME) + if args_name and args_name in cls_init_args_name: + cls_kwargs_loaded = {args_name: cls_kwargs_loaded} + + _cls_kwargs = {} + _cls_kwargs.update(cls_kwargs_loaded) + _cls_kwargs.update(cls_kwargs_new) + + if not cls_spec.varkw: + # filter kwargs according to class init unless it allows any argument via kwargs + _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} + + model = cls(**_cls_kwargs) + + # give model a chance to load something + model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + model.load_state_dict(checkpoint['state_dict'], strict=strict) + + return model def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ @@ -48,30 +246,172 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: """ -def load_hparams_from_tags_csv(tags_csv: str) -> Namespace: - if not os.path.isfile(tags_csv): - log.warning(f'Missing Tags: {tags_csv}.') - return Namespace() +def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object: + """Convert hparams according given type in callable or string (past) format.""" + # if not hparams type define + if not hparams_type: + return model_args + # if past checkpoint loaded, convert str to callable + if isinstance(hparams_type, str): + hparams_type = AttributeDict + # convert hparams + return hparams_type(model_args) + + +def update_hparams(hparams: dict, updates: dict) -> None: + """ + Overrides hparams with new values + + >>> hparams = {'c': 4} + >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1}) + >>> hparams['a']['b'], hparams['c'] + (2, 1) + >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7}) + >>> hparams['a']['b'], hparams['c'] + (4, 7) + + Args: + hparams: the original params and also target object + updates: new params to be used as update - with open(tags_csv) as f: - csv_reader = csv.reader(f, delimiter=',') + """ + for k, v in updates.items(): + # if missing, add the key + if k not in hparams: + hparams[k] = v + continue + + # recurse if dictionary + if isinstance(v, dict): + update_hparams(hparams[k], updates[k]) + else: + # update the value + hparams.update({k: v}) + + +def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: + """Load hparams from a file. + + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') + >>> path_csv = os.path.join('.', 'testing-hparams.csv') + >>> save_hparams_to_tags_csv(path_csv, hparams) + >>> hparams_new = load_hparams_from_tags_csv(path_csv) + >>> vars(hparams) == hparams_new + True + >>> os.remove(path_csv) + """ + fs = get_filesystem(tags_csv) + if not fs.exists(tags_csv): + rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning) + return {} + + with fs.open(tags_csv, "r", newline="") as fp: + csv_reader = csv.reader(fp, delimiter=",") tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} - ns = Namespace(**tags) - return ns + return tags -def convert(val: str) -> Union[int, float, bool, str]: - constructors = [int, float, str] - if isinstance(val, str): - if val.lower() == 'true': - return True - if val.lower() == 'false': - return False +def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None: + fs = get_filesystem(tags_csv) + if not fs.isdir(os.path.dirname(tags_csv)): + raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.") + + if isinstance(hparams, Namespace): + hparams = vars(hparams) + + with fs.open(tags_csv, "w", newline="") as fp: + fieldnames = ["key", "value"] + writer = csv.DictWriter(fp, fieldnames=fieldnames) + writer.writerow({"key": "key", "value": "value"}) + for k, v in hparams.items(): + writer.writerow({"key": k, "value": v}) - for c in constructors: + +def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]: + """Load hparams from a file. + + Args: + config_yaml: Path to config yaml file + use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True, + the hparams will be converted to `DictConfig` if possible + + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') + >>> path_yaml = './testing-hparams.yaml' + >>> save_hparams_to_yaml(path_yaml, hparams) + >>> hparams_new = load_hparams_from_yaml(path_yaml) + >>> vars(hparams) == hparams_new + True + >>> os.remove(path_yaml) + """ + fs = get_filesystem(config_yaml) + if not fs.exists(config_yaml): + rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning) + return {} + + with fs.open(config_yaml, "r") as fp: + hparams = yaml.load(fp, Loader=yaml.UnsafeLoader) + + if _OMEGACONF_AVAILABLE: + if use_omegaconf: + try: + return OmegaConf.create(hparams) + except (UnsupportedValueType, ValidationError): + pass + return hparams + + +def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: + """ + Args: + config_yaml: path to new YAML file + hparams: parameters to be saved + """ + fs = get_filesystem(config_yaml) + if not fs.isdir(os.path.dirname(config_yaml)): + raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") + + # convert Namespace or AD to dict + if isinstance(hparams, Namespace): + hparams = vars(hparams) + elif isinstance(hparams, AttributeDict): + hparams = dict(hparams) + + # saving with OmegaConf objects + if _OMEGACONF_AVAILABLE: + # deepcopy: hparams from user shouldn't be resolved + hparams = deepcopy(hparams) + to_container = partial(OmegaConf.to_container, resolve=True) + hparams = apply_to_collection(hparams, DictConfig, to_container) + with fs.open(config_yaml, "w", encoding="utf-8") as fp: + try: + OmegaConf.save(hparams, fp) + return + except (UnsupportedValueType, ValidationError): + pass + + if not isinstance(hparams, dict): + raise TypeError("hparams must be dictionary") + + hparams_allowed = {} + # drop paramaters which contain some strange datatypes as fsspec + for k, v in hparams.items(): try: - return c(val) - except ValueError: - pass - return val + yaml.dump(v) + except TypeError: + warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.") + hparams[k] = type(v).__name__ + else: + hparams_allowed[k] = v + + # saving the standard way + with fs.open(config_yaml, "w", newline="") as fp: + yaml.dump(hparams_allowed, fp) + + +def convert(val: str) -> Union[int, float, bool, str]: + try: + return ast.literal_eval(val) + except (ValueError, SyntaxError) as err: + log.debug(err) + return val diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py new file mode 100644 index 00000000000000..1db9bd4927cea4 --- /dev/null +++ b/pytorch_lightning/core/step_result.py @@ -0,0 +1,703 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Result class for easier logging and epoch-wise reduction.""" + +import numbers +from copy import copy +from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union + +import torch +from torch import Tensor +from torchmetrics import Metric + +from pytorch_lightning.utilities.distributed import sync_ddp_if_available + + +class Result(Dict): + + def __init__(self, minimize: Optional[Tensor] = None): + super().__init__() + + if minimize is not None: + err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end' + self._assert_grad_tensor_metric('minimize', minimize, err) + self.minimize = minimize + + self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}} + + def __getitem__(self, key: Union[str, Any]) -> Any: + try: + return super().__getitem__(key) + except KeyError: + return super().__getitem__(f'{key}_step') + + def __getattr__(self, key: str) -> Any: + try: + if key == 'batch_log_metrics': + return self.get_batch_log_metrics() + elif key == 'batch_pbar_metrics': + return self.get_batch_pbar_metrics() + elif key == 'epoch_log_metrics': + return self.get_epoch_log_metrics() + elif key == 'epoch_pbar_metrics': + return self.get_epoch_pbar_metrics() + else: + return self[key] + except KeyError: + return None + + def __setattr__(self, key: str, val: Union[Tensor, Any]): + # ensure tensors are detached + if isinstance(val, torch.Tensor) and key != 'minimize': + val = val.detach() + self[key] = val + + def __getstate__(self): + return self + + def __setstate__(self, d): + self.update(d) + + def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''): + if x is not None: + if not isinstance(x, Tensor): + raise TypeError(f'{name} must be a torch.Tensor') + + m = f'{name} must have a computational graph.' + + if additional_err: + m += f' {additional_err}' + assert x.grad_fn is not None, m + + def log( + self, + name: str, + value: Any, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + sync_fn: Callable = None, + dataloader_idx: Optional[int] = None, + device: torch.device = None, + ): + # no metrics should be logged with graphs + if not enable_graph and isinstance(value, torch.Tensor): + value = value.detach() + + # sync across workers when using distributed training + sync_fn = sync_fn or sync_ddp_if_available + if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): + is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() + # TODO: Find a way to make the reduction only once, so we don't need to clone. + if is_dist_initialized and isinstance(value, torch.Tensor): + value = value.clone() + else: + value = torch.tensor(value, device=device, dtype=torch.float) + value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) + + if isinstance(value, torch.Tensor) and value.device.type == "xla": + value = value.cpu() + + if 'meta' not in self: + self.__setitem__('meta', {}) + + # if user requests both step and epoch, then we split the metric in two automatically + # one will be logged per step. the other per epoch + was_forked = False + if on_step and on_epoch: + was_forked = True + + # set step version + step_name = f'{name}_step' + + self.__set_meta( + step_name, + value, + prog_bar, + logger, + on_step=True, + on_epoch=False, + reduce_fx=reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + forked=False, + dataloader_idx=dataloader_idx, + ) + + self.__setitem__(step_name, value) + + # set epoch version + epoch_name = f'{name}_epoch' + + self.__set_meta( + epoch_name, + value, + prog_bar, + logger, + on_step=False, + on_epoch=True, + reduce_fx=reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + forked=False, + dataloader_idx=dataloader_idx, + ) + self.__setitem__(epoch_name, value) + + # always log the original metric + self.__set_meta( + name, + value, + prog_bar, + logger, + on_step, + on_epoch, + reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + forked=was_forked, + dataloader_idx=dataloader_idx, + ) + + # set the value + self.__setitem__(name, value) + + def __set_meta( + self, + name: str, + value: Any, + prog_bar: bool, + logger: bool, + on_step: bool, + on_epoch: bool, + reduce_fx: Callable, + tbptt_pad_token: int, + tbptt_reduce_fx: Callable, + forked: bool, + dataloader_idx: Union[int, None], + ): + # set the meta for the item + meta_value = value + meta = dict( + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + value=meta_value, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + forked=forked, + dataloader_idx=dataloader_idx, + ) + + self['meta'][name] = meta + + # track whether any input requires reduction on epoch end + _internal = self['meta']['_internal'] + _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch) + + def track_batch_size(self, batch): + batch_size = Result.extract_batch_size(batch) + Result.attach_batch_size(batch_size, self) + + @staticmethod + def extract_batch_size(batch): + try: + batch_size = Result.unpack_batch_size(batch) + except RecursionError: + batch_size = 1 + return batch_size + + @staticmethod + def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None: + if batch_size is not None: + meta = result['meta'] + meta['_internal']['batch_sizes'].append(batch_size) + + def get_batch_sizes(self): + meta = self['meta'] + return torch.tensor(meta['_internal']['batch_sizes']) + + def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str: + if dataloader_idx is not None and add_dataloader_idx: + return f"{k}/dataloader_idx_{dataloader_idx}" + return k + + def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict: + """ + Gets the metrics to log at the end of the batch step + + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + + if options['forked'] and not include_forked_originals: + continue + + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + + if options['logger'] and options['on_step']: + if isinstance(self[k], Metric) and self[k]._forward_cache is not None: + result[dl_key] = self[k]._forward_cache.detach() + else: + result[dl_key] = self[k] + + return result + + def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: + """ + Gets the metrics to log at the end of epoch + """ + result = {} + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + + if options['forked']: + continue + + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + + if options['logger'] and options['on_epoch']: + if isinstance(self[k], Metric): + result[dl_key] = self[k].compute().detach() + self[k].reset() + else: + result[dl_key] = self[k] + + if k in self and not options['on_epoch'] and isinstance(self[k], Metric): + # reset metric anyway so state does not accumulate + # NOTE: we must compute before reseting just in case the computed value is needed + # later (i.e. if the step metric gets visited first, and then the epoch metric) + self[k].compute() + self[k].reset() + + return result + + def get_epoch_pbar_metrics(self, add_dataloader_idx=False): + """ + Gets the metrics to log at the end of epoch + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + + if options['forked']: + continue + + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + + if options['prog_bar'] and options['on_epoch']: + if isinstance(self[k], Metric): + result[dl_key] = self[k].compute().detach() + self[k].reset() + else: + result[dl_key] = self[k] + + if k in self and not options['on_epoch'] and isinstance(self[k], Metric): + # reset metric anyway so state does not accumulate + # NOTE: we must compute before reseting just in case the computed value is needed + # later (i.e. if the step metric gets visited first, and then the epoch metric) + self[k].compute() + self[k].reset() + + return result + + def get_forked_metrics(self, add_dataloader_idx=False): + """ + Gets the metrics to log at the end of epoch + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + + if options['forked']: + if isinstance(self[k], Metric): + result[dl_key] = self[k].compute().detach() + self[k].reset() + else: + result[dl_key] = self[k] + + return result + + def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False): + """ + Gets the metrics to log at the end of the batch step + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + + if options['forked'] and not include_forked_originals: + continue + + dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) + + if options['prog_bar'] and options['on_step']: + if isinstance(self[k], Metric) and self[k]._forward_cache is not None: + result[dl_key] = self[k]._forward_cache + else: + result[dl_key] = self[k] + + return result + + def detach(self) -> 'Result': + for k, v in self.items(): + if isinstance(v, torch.Tensor): + self.__setitem__(k, v.detach()) + return self + + def to(self, *args, **kwargs) -> 'Result': + """Move all self attributes to the given device.""" + for k, v in self.items(): + if isinstance(v, torch.Tensor): + self.__setitem__(k, v.to(*args, **kwargs)) + return self + + def cpu(self) -> 'Result': + """Move all self attributes to CPU.""" + return self.to(torch.device("cpu")) + + def __repr__(self): + self_copy = self.copy() + + if 'meta' in self_copy: + del self_copy['meta'] + + return str(self_copy) + + def __str__(self): + copy = self.copy() + del copy['meta'] + + return str(copy) + + def __copy__(self): + newone = type(self)() + for k, v in self.items(): + if isinstance(v, torch.Tensor): + v = v.detach() + newone[k] = copy(v) + return newone + + @staticmethod + def unpack_batch_size(sample): + """ + Recursively unpack sample to find a torch.Tensor. + returns len(tensor) when found, or 1 when it hits an empty or non iterable. + """ + if isinstance(sample, torch.Tensor): + size = sample.size(0) + elif isinstance(sample, str): + return len(sample) + elif isinstance(sample, dict): + sample = next(iter(sample.values()), 1) + size = Result.unpack_batch_size(sample) + elif isinstance(sample, Iterable): + sample = next(iter(sample), 1) + size = Result.unpack_batch_size(sample) + else: + size = 1 + return size + + @classmethod + def gather(cls, outputs): + meta = outputs[0].get('meta') + result = cls() + result = recursive_gather(outputs, result) + recursive_stack(result) + + if meta: + result['meta'] = meta + return result + + @classmethod + def padded_gather(cls, outputs): + meta = outputs[0].get('meta') + result = cls() + result = recursive_gather(outputs, result) + + # find the padding used for other values + default_padding_idx = 0 + for name, value in result.items(): + if ( + name != 'minimize' and isinstance(value, list) and len(value) > 0 + and isinstance(value[0], torch.Tensor) + ): + default_padding_idx = meta[name]['tbptt_pad_token'] + break + + # pad across each key individually + for name, value in result.items(): + if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)): + padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token'] + padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key) + result[name] = padded + + # also update the result + if meta and name != "minimize": + meta[name]['value'] = padded + if meta: + result['meta'] = meta + return result + + @classmethod + def reduce_on_epoch_end(cls, outputs): + # get the batch sizes for all outputs + batch_sizes = [] + meta = {} + for x in outputs: + batch_sizes.append(x.get_batch_sizes()) + meta.update(x['meta']) + + batch_sizes = torch.stack(batch_sizes).view(-1) + + result = cls() + result = recursive_gather(outputs, result) + recursive_stack(result) + + for k, option in meta.items(): + if k == '_internal' or isinstance(result[k], Metric): + continue + + # for forked metrics don't reduce, just take the last val + if option['forked']: + result[k] = choose_last(result[k]) + continue + + if option['on_epoch']: + fx = option['reduce_fx'] + if fx == torch.mean: + if isinstance(result[k], list): + result[k] = torch.tensor(result[k]).float() + try: + reduced_val = weighted_mean(result[k], batch_sizes) + # todo: specify the expected Exceptions to come + except Exception: + reduced_val = torch.mean(result[k]) + else: + reduced_val = fx(result[k]) + + result[k] = reduced_val + else: + del result[k] + + result['meta'] = meta + return result + + @classmethod + def reduce_across_time(cls, time_outputs): + # auto-reduce across time for tbptt + meta = time_outputs[0]['meta'] + + # in 1.0 the results have 'extra'. Once we deprecate 0.10.0 we may not need this + if 'extra' in time_outputs[0]: + [x.pop('extra', None) for x in time_outputs] + + result = cls() + result = recursive_gather(time_outputs, result) + recursive_stack(result) + + for k, value in result.items(): + if k in ['meta', 'extra'] or isinstance(value, Metric): + continue + + # pick the reduce fx + tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx'] + + if isinstance(value, list): + value = torch.tensor(value) + + if isinstance(value, dict): + # TODO: recursive reduce: + _recursive_fx_apply(value, tbptt_reduce_fx) + else: + result[k] = tbptt_reduce_fx(value.float()) + + result['meta'] = meta + return result + + def dp_reduce(self): + for k, value in self.items(): + if k == 'meta' or isinstance(value, Metric): + continue + + if isinstance(value, list): + value = torch.tensor(value) + + self[k] = value.mean(dim=-1) + + @property + def should_reduce_on_epoch_end(self) -> bool: + return self['meta']['_internal']['_reduce_on_epoch'] + + def rename_keys(self, map_dict: dict): + """ + Maps key values to the target values. Useful when renaming variables in mass. + + Args: + map_dict: + """ + meta = self.meta + for source, dest in map_dict.items(): + # map the main keys + self[dest] = self[source] + del self[source] + + # map meta + meta[dest] = meta[source] + del meta[source] + + def get_non_metrics_keys(self): + """ + This function is used to filter metric keys for which the value isn't a Metric + """ + return [k for k, v in self.items() if not isinstance(v, Metric)] + + +def choose_last(x): + if isinstance(x, (torch.Tensor, list)): + return x[-1] + if isinstance(x, dict): + for k, v in x.items(): + x[k] = x[k][-1] + + +def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]: + for out in outputs: + if 'meta' in out: + del out['meta'] + + for k, v in out.items(): + # support manual opt where the user does not return a minimize key + if k == 'minimize' and v is None: + continue + + if isinstance(v, dict): + in_d = result.get(k, {}) + v = recursive_gather([v], in_d) + result[k] = v + else: + if isinstance(v, Metric): + # if v is a metric, just keep one of them, + # don't keep on adding a list of them + result[k] = v + else: + if k not in result: + result[k] = [] + result[k].append(v) + + return result + + +def recursive_stack(result: MutableMapping): + for k, v in result.items(): + if isinstance(v, dict): + recursive_stack(v) + + result[k] = collate_tensors(v) + + +def _recursive_fx_apply(input: dict, fx): + for k, v in input.items(): + if isinstance(v, list): + v = torch.tensor(v) + + if isinstance(v, torch.Tensor): + v = fx(v.float()) + input[k] = v + else: + _recursive_fx_apply(v, fx) + + +def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]: + if not items or not isinstance(items, (list, tuple)) or any(not isinstance(item, Tensor) for item in items): + # items is not a sequence, empty, or contains non-tensors + return items + + if all(item.ndim == 0 for item in items): + # all tensors are scalars, we need to stack + return torch.stack(items) + + if all(item.ndim >= 1 and item.shape[1:] == items[0].shape[1:] for item in items): + # we can concatenate along the first dimension + return torch.cat(items) + + return items + + +def weighted_mean(result, weights): + + if isinstance(result, dict): + _process_dataloader_aggregated_steps(result, weights) + else: + if isinstance(result, list): + result = torch.tensor(result) + + weights = weights.to(result.device)[:result.size(0)] + numerator = torch.dot(result.float(), weights.transpose(-1, 0).float()) + result = numerator / weights.sum().float() + return result + + +def _process_dataloader_aggregated_steps(result, weights): + internal_keys = {'meta'} + + moved = False + + for k, v in result.items(): + if k in internal_keys: + continue + + # make sure v is a tensor + if not isinstance(v, torch.Tensor): + v = torch.tensor(v) + + # move to memory only once + if not moved: + weights = weights.to(v.device) + moved = True + + # move weights to same device as value to reduce + weights_t = weights[:v.size(0)] + + # weighted mean + numerator = torch.dot(v.float(), weights_t.transpose(-1, 0).float()) + v = numerator / weights.sum().float() + result[k] = v diff --git a/pytorch_lightning/distributed/__init__.py b/pytorch_lightning/distributed/__init__.py new file mode 100644 index 00000000000000..ea060e551ad9da --- /dev/null +++ b/pytorch_lightning/distributed/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.distributed.dist import LightningDistributed # noqa: F401 diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py new file mode 100644 index 00000000000000..37ac5d8b134629 --- /dev/null +++ b/pytorch_lightning/distributed/dist.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.utilities.distributed import group as _group + + +class LightningDistributed: + + def __init__(self, rank=None, device=None): + self.rank = rank + self.device = device + + def broadcast(self, obj: Any, group=_group.WORLD): + # always wrap into a list so list can be brodcasted. + obj = [obj] + + if self.rank != 0: + obj = [None] * len(obj) + + broadcast_object_list(obj, 0, group=group or _group.WORLD) + + return obj[0] diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py new file mode 100644 index 00000000000000..b00d1946424e7e --- /dev/null +++ b/pytorch_lightning/info.py @@ -0,0 +1,36 @@ +import time + +_this_year = time.strftime("%Y") +__version__ = '1.3.0dev' +__author__ = 'William Falcon et al.' +__author_email__ = 'waf2107@columbia.edu' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' +__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' +__docs_url__ = "https://pytorch-lightning.readthedocs.io/en/stable/" +# this has to be simple string, see: https://github.com/pypa/twine/issues/522 +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." +) +__long_docs__ = """ +Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. + It's more of a style-guide than a framework. + +In Lightning, you organize your code into 3 distinct categories: + +1. Research code (goes in the LightningModule). +2. Engineering code (you delete, and is handled by the Trainer). +3. Non-essential research code (logging, etc. this goes in Callbacks). + +Although your research/production project might start simple, once you add things like GPU AND TPU training, + 16-bit precision, etc, you end up spending more time engineering than researching. + Lightning automates AND rigorously tests those parts for you. + +Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. + +Documentation +------------- +- https://pytorch-lightning.readthedocs.io/en/latest +- https://pytorch-lightning.readthedocs.io/en/stable +""" diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py index c71a9aec0f99de..4d13d1842cd565 100644 --- a/pytorch_lightning/loggers/__init__.py +++ b/pytorch_lightning/loggers/__init__.py @@ -1,150 +1,48 @@ -""" -Lightning supports the most popular logging frameworks (TensorBoard, Comet, Weights and Biases, etc...). -To use a logger, simply pass it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`. -Lightning uses TensorBoard by default. - -.. code-block:: python - - from pytorch_lightning import Trainer - from pytorch_lightning import loggers - tb_logger = loggers.TensorBoardLogger('logs/') - trainer = Trainer(logger=tb_logger) - -Choose from any of the others such as MLflow, Comet, Neptune, WandB, ... - -.. code-block:: python - - comet_logger = loggers.CometLogger(save_dir='logs/') - trainer = Trainer(logger=comet_logger) - -To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers ... - -.. code-block:: python - - tb_logger = loggers.TensorBoardLogger('logs/') - comet_logger = loggers.CometLogger(save_dir='logs/') - trainer = Trainer(logger=[tb_logger, comet_logger]) - -Note: - All loggers log by default to ``os.getcwd()``. To change the path without creating a logger set - ``Trainer(default_root_dir='/your/path/to/save/checkpoints')`` - -Custom Logger -------------- - -You can implement your own logger by writing a class that inherits from -:class:`LightningLoggerBase`. Use the :func:`~pytorch_lightning.loggers.base.rank_zero_only` -decorator to make sure that only the first process in DDP training logs data. - -.. code-block:: python - - from pytorch_lightning.utilities import rank_zero_only - from pytorch_lightning.loggers import LightningLoggerBase - class MyLogger(LightningLoggerBase): - - @rank_zero_only - def log_hyperparams(self, params): - # params is an argparse.Namespace - # your code to record hyperparameters goes here - pass - - @rank_zero_only - def log_metrics(self, metrics, step): - # metrics is a dictionary of metric names and values - # your code to record metrics goes here - pass - - def save(self): - # Optional. Any code necessary to save logger data goes here - pass - - @rank_zero_only - def finalize(self, status): - # Optional. Any code that needs to be run after training - # finishes goes here - pass - -If you write a logger that may be useful to others, please send -a pull request to add it to Lighting! - -Using loggers -------------- - -Call the logger anywhere except ``__init__`` in your -:class:`~pytorch_lightning.core.lightning.LightningModule` by doing: - -.. code-block:: python - - from pytorch_lightning import LightningModule - class LitModel(LightningModule): - def training_step(self, batch, batch_idx): - # example - self.logger.experiment.whatever_method_summary_writer_supports(...) - - # example if logger is a tensorboard logger - self.logger.experiment.add_image('images', grid, 0) - self.logger.experiment.add_graph(model, images) - - def any_lightning_module_function_or_hook(self): - self.logger.experiment.add_histogram(...) - -Read more in the `Experiment Logging use case <./experiment_logging.html>`_. - -Supported Loggers ------------------ -""" +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from os import environ from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection +from pytorch_lightning.loggers.csv_logs import CSVLogger from pytorch_lightning.loggers.tensorboard import TensorBoardLogger __all__ = [ 'LightningLoggerBase', 'LoggerCollection', 'TensorBoardLogger', + 'CSVLogger', ] -try: - # needed to prevent ImportError and duplicated logs. - environ["COMET_DISABLE_AUTO_LOGGING"] = "1" +from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger # noqa: F401 +from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401 +from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE, NeptuneLogger # noqa: F401 +from pytorch_lightning.loggers.test_tube import _TESTTUBE_AVAILABLE, TestTubeLogger # noqa: F401 +from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE, WandbLogger # noqa: F401 - from pytorch_lightning.loggers.comet import CometLogger -except ImportError: # pragma: no-cover - del environ["COMET_DISABLE_AUTO_LOGGING"] # pragma: no-cover -else: +if _COMET_AVAILABLE: __all__.append('CometLogger') + # needed to prevent ImportError and duplicated logs. + environ["COMET_DISABLE_AUTO_LOGGING"] = "1" -try: - from pytorch_lightning.loggers.mlflow import MLFlowLogger -except ImportError: # pragma: no-cover - pass # pragma: no-cover -else: +if _MLFLOW_AVAILABLE: __all__.append('MLFlowLogger') -try: - from pytorch_lightning.loggers.neptune import NeptuneLogger -except ImportError: # pragma: no-cover - pass # pragma: no-cover -else: +if _NEPTUNE_AVAILABLE: __all__.append('NeptuneLogger') -try: - from pytorch_lightning.loggers.test_tube import TestTubeLogger -except ImportError: # pragma: no-cover - pass # pragma: no-cover -else: +if _TESTTUBE_AVAILABLE: __all__.append('TestTubeLogger') -try: - from pytorch_lightning.loggers.wandb import WandbLogger -except ImportError: # pragma: no-cover - pass # pragma: no-cover -else: +if _WANDB_AVAILABLE: __all__.append('WandbLogger') - -try: - from pytorch_lightning.loggers.trains import TrainsLogger -except ImportError: # pragma: no-cover - pass # pragma: no-cover -else: - __all__.append('TrainsLogger') diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 857d661fdb5b37..035a42338fe681 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -1,16 +1,48 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract base class used to build new loggers.""" + import argparse import functools import operator from abc import ABC, abstractmethod from argparse import Namespace -from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple +from functools import wraps +from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import numpy as np import torch +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only +def rank_zero_experiment(fn: Callable) -> Callable: + """ Returns the real experiment on rank 0 and otherwise the DummyExperiment. """ + + @wraps(fn) + def experiment(self): + + @rank_zero_only + def get_experiment(): + return fn(self) + + return get_experiment() or DummyExperiment() + + return experiment + + class LightningLoggerBase(ABC): """ Base class for experiment loggers. @@ -30,20 +62,19 @@ class LightningLoggerBase(ABC): """ def __init__( - self, - agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, - agg_default_func: Callable[[Sequence[float]], float] = np.mean + self, + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Callable[[Sequence[float]], float] = np.mean ): - self._rank = 0 self._prev_step: int = -1 self._metrics_to_agg: List[Dict[str, float]] = [] self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func def update_agg_funcs( - self, - agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, - agg_default_func: Callable[[Sequence[float]], float] = np.mean + self, + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Callable[[Sequence[float]], float] = np.mean ): """ Update aggregation methods. @@ -67,9 +98,9 @@ def update_agg_funcs( def experiment(self) -> Any: """Return the experiment object associated with this logger.""" - def _aggregate_metrics( - self, metrics: Dict[str, float], step: Optional[int] = None - ) -> Tuple[int, Optional[Dict[str, float]]]: + def _aggregate_metrics(self, + metrics: Dict[str, float], + step: Optional[int] = None) -> Tuple[int, Optional[Dict[str, float]]]: """ Aggregates metrics. @@ -154,7 +185,34 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: return params @staticmethod - def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]: + def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. + + Args: + params: Dictionary containing the hyperparameters + + Returns: + dictionary with all callables sanitized + """ + + def _sanitize_callable(val): + # Give them one chance to return a value. Don't go rabbit hole of recursive call + if isinstance(val, Callable): + try: + _val = val() + if isinstance(_val, Callable): + return val.__name__ + return _val + # todo: specify the possible exception + except Exception: + return getattr(val, "__name__", None) + return val + + return {key: _sanitize_callable(val) for key, val in params.items()} + + @staticmethod + def _flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]: """ Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. @@ -170,13 +228,16 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any {'a/b': 'c'} >>> LightningLoggerBase._flatten_dict({'a': {'b': 123}}) {'a/b': 123} + >>> LightningLoggerBase._flatten_dict({5: {'a': 123}}) + {'5/a': 123} """ def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] - if isinstance(input_dict, dict): + if isinstance(input_dict, MutableMapping): for key, value in input_dict.items(): - if isinstance(value, (dict, Namespace)): + key = str(key) + if isinstance(value, (MutableMapping, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): yield d @@ -209,16 +270,34 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: 'namespace': 'Namespace(foo=3)', 'string': 'abc'} """ - return {k: v if type(v) in [bool, int, float, str, torch.Tensor] else str(v) for k, v in params.items()} + for k in params.keys(): + # convert relevant np scalars to python types first (instead of str) + if isinstance(params[k], (np.bool_, np.integer, np.floating)): + params[k] = params[k].item() + elif type(params[k]) not in [bool, int, float, str, torch.Tensor]: + params[k] = str(params[k]) + return params @abstractmethod - def log_hyperparams(self, params: argparse.Namespace): + def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): """ Record hyperparameters. Args: params: :class:`~argparse.Namespace` containing the hyperparameters + args: Optional positional arguments, depends on the specific logger being used + kwargs: Optional keywoard arguments, depends on the specific logger being used + """ + + def log_graph(self, model: LightningModule, input_array=None) -> None: """ + Record model graph + + Args: + model: lightning model + input_array: input passes to `model.forward` + """ + pass def save(self) -> None: """Save log data.""" @@ -237,6 +316,14 @@ def close(self) -> None: """Do any cleanup that is necessary to close an experiment.""" self.save() + @property + def save_dir(self) -> Optional[str]: + """ + Return the root directory where experiment logs get saved, or `None` if the logger does not + save data locally. + """ + return None + @property @abstractmethod def name(self) -> str: @@ -247,6 +334,12 @@ def name(self) -> str: def version(self) -> Union[int, str]: """Return the experiment version.""" + def _add_prefix(self, metrics: Dict[str, float]): + if self._prefix: + metrics = {f'{self._prefix}{self.LOGGER_JOIN_CHAR}{k}': v for k, v in metrics.items()} + + return metrics + class LoggerCollection(LightningLoggerBase): """ @@ -264,24 +357,50 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] + def update_agg_funcs( + self, + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Callable[[Sequence[float]], float] = np.mean + ): + for logger in self._logger_iterable: + logger.update_agg_funcs(agg_key_funcs, agg_default_func) + @property def experiment(self) -> List[Any]: return [logger.experiment for logger in self._logger_iterable] + def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + for logger in self._logger_iterable: + logger.agg_and_log_metrics(metrics, step) + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - [logger.log_metrics(metrics, step) for logger in self._logger_iterable] + for logger in self._logger_iterable: + logger.log_metrics(metrics, step) def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - [logger.log_hyperparams(params) for logger in self._logger_iterable] + for logger in self._logger_iterable: + logger.log_hyperparams(params) + + def log_graph(self, model: LightningModule, input_array=None) -> None: + for logger in self._logger_iterable: + logger.log_graph(model, input_array) def save(self) -> None: - [logger.save() for logger in self._logger_iterable] + for logger in self._logger_iterable: + logger.save() def finalize(self, status: str) -> None: - [logger.finalize(status) for logger in self._logger_iterable] + for logger in self._logger_iterable: + logger.finalize(status) def close(self) -> None: - [logger.close() for logger in self._logger_iterable] + for logger in self._logger_iterable: + logger.close() + + @property + def save_dir(self) -> Optional[str]: + # Checkpoints should be saved to default / chosen location when using multiple loggers + return None @property def name(self) -> str: @@ -292,10 +411,57 @@ def version(self) -> str: return '_'.join([str(logger.version) for logger in self._logger_iterable]) +class DummyExperiment(object): + """ Dummy experiment """ + + def nop(*args, **kw): + pass + + def __getattr__(self, _): + return self.nop + + def __getitem__(self, idx) -> "DummyExperiment": + # enables self.logger.experiment[0].add_image(...) + return self + + +class DummyLogger(LightningLoggerBase): + """ + Dummy logger for internal use. It is useful if we want to disable user's + logger for a feature, but still ensure that user code can run + """ + + def __init__(self): + super().__init__() + self._experiment = DummyExperiment() + + @property + def experiment(self) -> DummyExperiment: + return self._experiment + + def log_metrics(self, *args, **kwargs) -> None: + pass + + def log_hyperparams(self, *args, **kwargs) -> None: + pass + + @property + def name(self) -> str: + return "" + + @property + def version(self) -> str: + return "" + + def __getitem__(self, idx) -> "DummyLogger": + # enables self.logger[0].experiment.add_image(...) + return self + + def merge_dicts( - dicts: Sequence[Mapping], - agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, - default_func: Callable[[Sequence[float]], float] = np.mean + dicts: Sequence[Mapping], + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + default_func: Callable[[Sequence[float]], float] = np.mean ) -> Dict: """ Merge a sequence with dictionaries into one dictionary by aggregating the diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 1d24676d8018a1..148e512f5e4392 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -1,37 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Comet ------ +Comet Logger +------------ """ +import logging +import os from argparse import Namespace -from typing import Optional, Dict, Union, Any +from typing import Any, Dict, Optional, Union -try: - from comet_ml import Experiment as CometExperiment +import torch +from torch import is_tensor + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities import _module_available, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +log = logging.getLogger(__name__) +_COMET_AVAILABLE = _module_available("comet_ml") + +if _COMET_AVAILABLE: + import comet_ml from comet_ml import ExistingExperiment as CometExistingExperiment + from comet_ml import Experiment as CometExperiment from comet_ml import OfflineExperiment as CometOfflineExperiment - from comet_ml import BaseExperiment as CometBaseExperiment + try: from comet_ml.api import API except ImportError: # pragma: no-cover # For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300 from comet_ml.papi import API # pragma: no-cover -except ImportError: # pragma: no-cover - raise ImportError('You want to use `comet_ml` logger which is not installed yet,' # pragma: no-cover - ' install it with `pip install comet-ml`.') - -import torch -from torch import is_tensor - -from pytorch_lightning import _logger as log -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities import rank_zero_only +else: + # needed for test mocks, these tests shall be updated + comet_ml = None + CometExperiment, CometExistingExperiment, CometOfflineExperiment = None, None, None + API = None class CometLogger(LightningLoggerBase): r""" - Log using `Comet.ml `_. Install it with pip: + Log using `Comet.ml `_. + + Install it with pip: .. code-block:: bash @@ -41,39 +64,44 @@ class CometLogger(LightningLoggerBase): **ONLINE MODE** - Example: - >>> import os - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.loggers import CometLogger - >>> # arguments made to CometLogger are passed on to the comet_ml.Experiment class - >>> comet_logger = CometLogger( - ... api_key=os.environ.get('COMET_API_KEY'), - ... workspace=os.environ.get('COMET_WORKSPACE'), # Optional - ... save_dir='.', # Optional - ... project_name='default_project', # Optional - ... rest_api_key=os.environ.get('COMET_REST_API_KEY'), # Optional - ... experiment_name='default' # Optional - ... ) - >>> trainer = Trainer(logger=comet_logger) + .. code-block:: python + + import os + from pytorch_lightning import Trainer + from pytorch_lightning.loggers import CometLogger + # arguments made to CometLogger are passed on to the comet_ml.Experiment class + comet_logger = CometLogger( + api_key=os.environ.get('COMET_API_KEY'), + workspace=os.environ.get('COMET_WORKSPACE'), # Optional + save_dir='.', # Optional + project_name='default_project', # Optional + rest_api_key=os.environ.get('COMET_REST_API_KEY'), # Optional + experiment_key=os.environ.get('COMET_EXPERIMENT_KEY'), # Optional + experiment_name='default' # Optional + ) + trainer = Trainer(logger=comet_logger) **OFFLINE MODE** - Example: - >>> from pytorch_lightning.loggers import CometLogger - >>> # arguments made to CometLogger are passed on to the comet_ml.Experiment class - >>> comet_logger = CometLogger( - ... save_dir='.', - ... workspace=os.environ.get('COMET_WORKSPACE'), # Optional - ... project_name='default_project', # Optional - ... rest_api_key=os.environ.get('COMET_REST_API_KEY'), # Optional - ... experiment_name='default' # Optional - ... ) - >>> trainer = Trainer(logger=comet_logger) + .. code-block:: python + + from pytorch_lightning.loggers import CometLogger + # arguments made to CometLogger are passed on to the comet_ml.Experiment class + comet_logger = CometLogger( + save_dir='.', + workspace=os.environ.get('COMET_WORKSPACE'), # Optional + project_name='default_project', # Optional + rest_api_key=os.environ.get('COMET_REST_API_KEY'), # Optional + experiment_name='default' # Optional + ) + trainer = Trainer(logger=comet_logger) Args: - api_key: Required in online mode. API key, found on Comet.ml - save_dir: Required in offline mode. The path for the directory to save local comet logs - workspace: Optional. Name of workspace for this user + api_key: Required in online mode. API key, found on Comet.ml. If not given, this + will be loaded from the environment variable COMET_API_KEY or ~/.comet.config + if either exists. + save_dir: Required in offline mode. The path for the directory to save local + comet logs. If given, this also sets the directory for saving checkpoints. project_name: Optional. Send your experiment to a specific project. Otherwise will be sent to Uncategorized Experiments. If the project name does not already exist, Comet.ml will create a new project. @@ -81,38 +109,69 @@ class CometLogger(LightningLoggerBase): This is used to determine version number experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. experiment_key: Optional. If set, restores from existing experiment. + offline: If api_key and save_dir are both given, this determines whether + the experiment will be in online or offline mode. This is useful if you use + save_dir to control the checkpoints directory and have a ~/.comet.config + file but still want to run offline experiments. + prefix: A string to put at the beginning of metric keys. + \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + :class:`CometExperiment` can be passed as keyword arguments in this logger. + + Raises: + ImportError: + If required Comet package is not installed on the device. + MisconfigurationException: + If neither ``api_key`` nor ``save_dir`` are passed as arguments. """ - def __init__(self, - api_key: Optional[str] = None, - save_dir: Optional[str] = None, - workspace: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, - experiment_key: Optional[str] = None, - **kwargs): - + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + api_key: Optional[str] = None, + save_dir: Optional[str] = None, + project_name: Optional[str] = None, + rest_api_key: Optional[str] = None, + experiment_name: Optional[str] = None, + experiment_key: Optional[str] = None, + offline: bool = False, + prefix: str = '', + **kwargs + ): + if comet_ml is None: + raise ImportError( + "You want to use `comet_ml` logger which is not installed yet," + " install it with `pip install comet-ml`." + ) super().__init__() self._experiment = None # Determine online or offline mode based on which arguments were passed to CometLogger - if api_key is not None: + api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) + + if api_key is not None and save_dir is not None: + self.mode = "offline" if offline else "online" + self.api_key = api_key + self._save_dir = save_dir + elif api_key is not None: self.mode = "online" self.api_key = api_key + self._save_dir = None elif save_dir is not None: self.mode = "offline" - self.save_dir = save_dir + self._save_dir = save_dir else: # If neither api_key nor save_dir are passed as arguments, raise an exception raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") log.info(f"CometLogger will be initialized in {self.mode} mode") - self.workspace = workspace - self.project_name = project_name - self.experiment_key = experiment_key + self._project_name = project_name + self._experiment_key = experiment_key + self._experiment_name = experiment_name + self._prefix = prefix self._kwargs = kwargs + self._future_experiment_key = None if rest_api_key is not None: # Comet.ml rest API, used to determine version number @@ -122,15 +181,11 @@ def __init__(self, self.rest_api_key = None self.comet_api = None - if experiment_name: - try: - self.name = experiment_name - except TypeError as e: - log.exception("Failed to set experiment name for comet.ml logger") self._kwargs = kwargs @property - def experiment(self) -> CometBaseExperiment: + @rank_zero_experiment + def experiment(self): r""" Actual Comet object. To use Comet features in your :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. @@ -143,30 +198,38 @@ def experiment(self) -> CometBaseExperiment: if self._experiment is not None: return self._experiment - if self.mode == "online": - if self.experiment_key is None: - self._experiment = CometExperiment( - api_key=self.api_key, - workspace=self.workspace, - project_name=self.project_name, - **self._kwargs - ) - self.experiment_key = self._experiment.get_key() + if self._future_experiment_key is not None: + os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key + + try: + if self.mode == "online": + if self._experiment_key is None: + self._experiment = CometExperiment( + api_key=self.api_key, + project_name=self._project_name, + **self._kwargs, + ) + self._experiment_key = self._experiment.get_key() + else: + self._experiment = CometExistingExperiment( + api_key=self.api_key, + project_name=self._project_name, + previous_experiment=self._experiment_key, + **self._kwargs, + ) else: - self._experiment = CometExistingExperiment( - api_key=self.api_key, - workspace=self.workspace, - project_name=self.project_name, - previous_experiment=self.experiment_key, - **self._kwargs + self._experiment = CometOfflineExperiment( + offline_directory=self.save_dir, + project_name=self._project_name, + **self._kwargs, ) - else: - self._experiment = CometOfflineExperiment( - offline_directory=self.save_dir, - workspace=self.workspace, - project_name=self.project_name, - **self._kwargs - ) + finally: + if self._future_experiment_key is not None: + os.environ.pop("COMET_EXPERIMENT_KEY") + self._future_experiment_key = None + + if self._experiment_name: + self._experiment.set_name(self._experiment_name) return self._experiment @@ -177,17 +240,17 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: self.experiment.log_parameters(params) @rank_zero_only - def log_metrics( - self, - metrics: Dict[str, Union[torch.Tensor, float]], - step: Optional[int] = None - ) -> None: + def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None: + assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" # Comet.ml expects metrics to be a dictionary of detached tensors on CPU for key, val in metrics.items(): if is_tensor(val): metrics[key] = val.cpu().detach() - self.experiment.log_metrics(metrics, step=step) + metrics_without_epoch = metrics.copy() + epoch = metrics_without_epoch.pop('epoch', None) + metrics_without_epoch = self._add_prefix(metrics_without_epoch) + self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) def reset_experiment(self): self._experiment = None @@ -206,14 +269,55 @@ def finalize(self, status: str) -> None: self.experiment.end() self.reset_experiment() + @property + def save_dir(self) -> Optional[str]: + return self._save_dir + @property def name(self) -> str: - return str(self.experiment.project_name) + # Don't create an experiment if we don't have one + if self._experiment is not None and self._experiment.project_name is not None: + return self._experiment.project_name + + if self._project_name is not None: + return self._project_name - @name.setter - def name(self, value: str) -> None: - self.experiment.set_name(value) + return "comet-default" @property def version(self) -> str: - return self.experiment.id + # Don't create an experiment if we don't have one + if self._experiment is not None: + return self._experiment.id + + if self._experiment_key is not None: + return self._experiment_key + + if "COMET_EXPERIMENT_KEY" in os.environ: + return os.environ["COMET_EXPERIMENT_KEY"] + + if self._future_experiment_key is not None: + return self._future_experiment_key + + # Pre-generate an experiment key + self._future_experiment_key = comet_ml.generate_guid() + + return self._future_experiment_key + + def __getstate__(self): + state = self.__dict__.copy() + + # Save the experiment id in case an experiment object already exists, + # this way we could create an ExistingExperiment pointing to the same + # experiment + state["_experiment_key"] = self._experiment.id if self._experiment is not None else None + + # Remove the experiment object as it contains hard to pickle objects + # (like network connections), the experiment object will be recreated if + # needed later + state["_experiment"] = None + return state + + def log_graph(self, model: LightningModule, input_array=None) -> None: + if self._experiment is not None: + self._experiment.set_model_graph(model) diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py new file mode 100644 index 00000000000000..4df672fa6e3b53 --- /dev/null +++ b/pytorch_lightning/loggers/csv_logs.py @@ -0,0 +1,230 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CSV logger +---------- + +CSV logger for basic experiment logging that does not require opening ports + +""" +import csv +import io +import logging +import os +from argparse import Namespace +from typing import Any, Dict, Optional, Union + +import torch + +from pytorch_lightning.core.saving import save_hparams_to_yaml +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn + +log = logging.getLogger(__name__) + + +class ExperimentWriter(object): + r""" + Experiment writer for CSVLogger. + + Currently supports to log hyperparameters and metrics in YAML and CSV + format, respectively. + + Args: + log_dir: Directory for the experiment logs + """ + + NAME_HPARAMS_FILE = 'hparams.yaml' + NAME_METRICS_FILE = 'metrics.csv' + + def __init__(self, log_dir: str) -> None: + self.hparams = {} + self.metrics = [] + + self.log_dir = log_dir + if os.path.exists(self.log_dir) and os.listdir(self.log_dir): + rank_zero_warn( + f"Experiment logs directory {self.log_dir} exists and is not empty." + " Previous log files in this directory will be deleted when the new ones are saved!" + ) + os.makedirs(self.log_dir, exist_ok=True) + + self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) + + def log_hparams(self, params: Dict[str, Any]) -> None: + """Record hparams""" + self.hparams.update(params) + + def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: + """Record metrics""" + + def _handle_value(value): + if isinstance(value, torch.Tensor): + return value.item() + return value + + if step is None: + step = len(self.metrics) + + metrics = {k: _handle_value(v) for k, v in metrics_dict.items()} + metrics['step'] = step + self.metrics.append(metrics) + + def save(self) -> None: + """Save recorded hparams and metrics into files""" + hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) + save_hparams_to_yaml(hparams_file, self.hparams) + + if not self.metrics: + return + + last_m = {} + for m in self.metrics: + last_m.update(m) + metrics_keys = list(last_m.keys()) + + with io.open(self.metrics_file_path, 'w', newline='') as f: + self.writer = csv.DictWriter(f, fieldnames=metrics_keys) + self.writer.writeheader() + self.writer.writerows(self.metrics) + + +class CSVLogger(LightningLoggerBase): + r""" + Log to local file system in yaml and CSV format. + + Logs are saved to ``os.path.join(save_dir, name, version)``. + + Example: + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.loggers import CSVLogger + >>> logger = CSVLogger("logs", name="my_exp_name") + >>> trainer = Trainer(logger=logger) + + Args: + save_dir: Save directory + name: Experiment name. Defaults to ``'default'``. + version: Experiment version. If version is not specified the logger inspects the save + directory for existing versions, then automatically assigns the next available version. + prefix: A string to put at the beginning of metric keys. + """ + + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + save_dir: str, + name: Optional[str] = "default", + version: Optional[Union[int, str]] = None, + prefix: str = '', + ): + super().__init__() + self._save_dir = save_dir + self._name = name or '' + self._version = version + self._prefix = prefix + self._experiment = None + + @property + def root_dir(self) -> str: + """ + Parent directory for all checkpoint subdirectories. + If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used + and the checkpoint will be saved in "save_dir/version_dir" + """ + if not self.name: + return self.save_dir + return os.path.join(self.save_dir, self.name) + + @property + def log_dir(self) -> str: + """ + The log directory for this run. By default, it is named + ``'version_${self.version}'`` but it can be overridden by passing a string value + for the constructor's version parameter instead of ``None`` or an int. + """ + # create a pseudo standard path ala test-tube + version = self.version if isinstance(self.version, str) else f"version_{self.version}" + log_dir = os.path.join(self.root_dir, version) + return log_dir + + @property + def save_dir(self) -> Optional[str]: + return self._save_dir + + @property + @rank_zero_experiment + def experiment(self) -> ExperimentWriter: + r""" + + Actual ExperimentWriter object. To use ExperimentWriter features in your + :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. + + Example:: + + self.logger.experiment.some_experiment_writer_function() + + """ + if self._experiment: + return self._experiment + + os.makedirs(self.root_dir, exist_ok=True) + self._experiment = ExperimentWriter(log_dir=self.log_dir) + return self._experiment + + @rank_zero_only + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + params = self._convert_params(params) + self.experiment.log_hparams(params) + + @rank_zero_only + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + metrics = self._add_prefix(metrics) + self.experiment.log_metrics(metrics, step) + + @rank_zero_only + def save(self) -> None: + super().save() + self.experiment.save() + + @rank_zero_only + def finalize(self, status: str) -> None: + self.save() + + @property + def name(self) -> str: + return self._name + + @property + def version(self) -> int: + if self._version is None: + self._version = self._get_next_version() + return self._version + + def _get_next_version(self): + root_dir = os.path.join(self._save_dir, self.name) + + if not os.path.isdir(root_dir): + log.warning('Missing logger folder: %s', root_dir) + return 0 + + existing_versions = [] + for d in os.listdir(root_dir): + if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): + existing_versions.append(int(d.split("_")[1])) + + if len(existing_versions) == 0: + return 0 + + return max(existing_versions) + 1 diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 74c51893e71e7e..88bed79904cf39 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -1,76 +1,126 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -MLflow ------- +MLflow Logger +------------- """ -import os +import logging +import re from argparse import Namespace from time import time -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn + +log = logging.getLogger(__name__) +LOCAL_FILE_URI_PREFIX = "file:" +_MLFLOW_AVAILABLE = _module_available("mlflow") try: import mlflow from mlflow.tracking import MlflowClient -except ImportError: # pragma: no-cover - raise ImportError('You want to use `mlflow` logger which is not installed yet,' # pragma: no-cover - ' install it with `pip install mlflow`.') - -from pytorch_lightning import _logger as log -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only +# todo: there seems to be still some remaining import error with Conda env +except ImportError: + _MLFLOW_AVAILABLE = False + mlflow, MlflowClient = None, None class MLFlowLogger(LightningLoggerBase): """ - Log using `MLflow `_. Install it with pip: + Log using `MLflow `_. + + Install it with pip: .. code-block:: bash pip install mlflow - Example: - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.loggers import MLFlowLogger - >>> mlf_logger = MLFlowLogger( - ... experiment_name="default", - ... tracking_uri="file:./ml-runs" - ... ) - >>> trainer = Trainer(logger=mlf_logger) - - Use the logger anywhere in you :class:`~pytorch_lightning.core.lightning.LightningModule` as follows: - - >>> from pytorch_lightning import LightningModule - >>> class LitModel(LightningModule): - ... def training_step(self, batch, batch_idx): - ... # example - ... self.logger.experiment.whatever_ml_flow_supports(...) - ... - ... def any_lightning_module_function_or_hook(self): - ... self.logger.experiment.whatever_ml_flow_supports(...) + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.loggers import MLFlowLogger + mlf_logger = MLFlowLogger( + experiment_name="default", + tracking_uri="file:./ml-runs" + ) + trainer = Trainer(logger=mlf_logger) + + Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows: + + .. code-block:: python + + from pytorch_lightning import LightningModule + class LitModel(LightningModule): + def training_step(self, batch, batch_idx): + # example + self.logger.experiment.whatever_ml_flow_supports(...) + + def any_lightning_module_function_or_hook(self): + self.logger.experiment.whatever_ml_flow_supports(...) Args: experiment_name: The name of the experiment tracking_uri: Address of local or remote tracking server. - If not provided, defaults to the service set by ``mlflow.tracking.set_tracking_uri``. + If not provided, defaults to `file:`. tags: A dictionary tags for the experiment. - + save_dir: A path to a local directory where the MLflow runs get saved. + Defaults to `./mlflow` if `tracking_uri` is not provided. + Has no effect if `tracking_uri` is provided. + prefix: A string to put at the beginning of metric keys. + artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate + default. + + Raises: + ImportError: + If required MLFlow package is not installed on the device. """ - def __init__(self, - experiment_name: str = 'default', - tracking_uri: Optional[str] = None, - tags: Optional[Dict[str, Any]] = None, - save_dir: Optional[str] = None): + + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + experiment_name: str = 'default', + tracking_uri: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + save_dir: Optional[str] = './mlruns', + prefix: str = '', + artifact_location: Optional[str] = None, + ): + if mlflow is None: + raise ImportError( + 'You want to use `mlflow` logger which is not installed yet,' + ' install it with `pip install mlflow`.' + ) super().__init__() - if not tracking_uri and save_dir: - tracking_uri = f'file:{os.sep * 2}{save_dir}' - self._mlflow_client = MlflowClient(tracking_uri) - self.experiment_name = experiment_name + if not tracking_uri: + tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}' + + self._experiment_name = experiment_name + self._experiment_id = None + self._tracking_uri = tracking_uri self._run_id = None self.tags = tags + self._prefix = prefix + self._artifact_location = artifact_location + + self._mlflow_client = MlflowClient(tracking_uri) @property + @rank_zero_experiment def experiment(self) -> MlflowClient: r""" - Actual MLflow object. To use mlflow features in your + Actual MLflow object. To use MLflow features in your :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. Example:: @@ -78,52 +128,92 @@ def experiment(self) -> MlflowClient: self.logger.experiment.some_mlflow_function() """ + if self._experiment_id is None: + expt = self._mlflow_client.get_experiment_by_name(self._experiment_name) + if expt is not None: + self._experiment_id = expt.experiment_id + else: + log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.') + self._experiment_id = self._mlflow_client.create_experiment( + name=self._experiment_name, + artifact_location=self._artifact_location, + ) + + if self._run_id is None: + run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags) + self._run_id = run.info.run_id return self._mlflow_client @property def run_id(self): - if self._run_id is not None: - return self._run_id - - expt = self._mlflow_client.get_experiment_by_name(self.experiment_name) - - if expt: - self._expt_id = expt.experiment_id - else: - log.warning(f'Experiment with name {self.experiment_name} not found. Creating it.') - self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name) - - run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags) - self._run_id = run.info.run_id + # create the experiment if it does not exist to get the run id + _ = self.experiment return self._run_id + @property + def experiment_id(self): + # create the experiment if it does not exist to get the experiment id + _ = self.experiment + return self._experiment_id + @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = self._convert_params(params) params = self._flatten_dict(params) for k, v in params.items(): + if len(str(v)) > 250: + rank_zero_warn( + f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", RuntimeWarning + ) + continue + self.experiment.log_param(self.run_id, k, v) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' + + metrics = self._add_prefix(metrics) + timestamp_ms = int(time() * 1000) for k, v in metrics.items(): if isinstance(v, str): log.warning(f'Discarding metric with string value {k}={v}.') continue + + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) + if k != new_k: + rank_zero_warn( + "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name." + f" Replacing {k} with {new_k}.", RuntimeWarning + ) + k = new_k + self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) @rank_zero_only def finalize(self, status: str = 'FINISHED') -> None: super().finalize(status) - if status == 'success': - status = 'FINISHED' - self.experiment.set_terminated(self.run_id, status) + status = 'FINISHED' if status == 'success' else status + if self.experiment.get_run(self.run_id): + self.experiment.set_terminated(self.run_id, status) + + @property + def save_dir(self) -> Optional[str]: + """ + The root file directory in which MLflow experiments are saved. + + Return: + Local path to the root experiment directory if the tracking uri is local. + Otherwhise returns `None`. + """ + if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX): + return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX) @property def name(self) -> str: - return self.experiment_name + return self.experiment_id @property def version(self) -> str: - return self._run_id + return self.run_id diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 374b513b618e7b..a9209f1cbdd7be 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -1,89 +1,110 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Neptune -------- +Neptune Logger +-------------- """ +import logging from argparse import Namespace -from typing import Optional, List, Dict, Any, Union, Iterable - -from PIL.Image import Image - -try: - import neptune - from neptune.experiments import Experiment -except ImportError: # pragma: no-cover - raise ImportError('You want to use `neptune` logger which is not installed yet,' # pragma: no-cover - ' install it with `pip install neptune-client`.') +from typing import Any, Dict, Iterable, Optional, Union import torch from torch import is_tensor -from pytorch_lightning import _logger as log -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities import _module_available, rank_zero_only + +log = logging.getLogger(__name__) +_NEPTUNE_AVAILABLE = _module_available("neptune") + +if _NEPTUNE_AVAILABLE: + import neptune + from neptune.experiments import Experiment +else: + # needed for test mocks, these tests shall be updated + neptune, Experiment = None, None class NeptuneLogger(LightningLoggerBase): r""" - Log using `Neptune `_. Install it with pip: + Log using `Neptune `_. + + Install it with pip: .. code-block:: bash pip install neptune-client The Neptune logger can be used in the online mode or offline (silent) mode. - To log experiment data in online mode, :class:`NeptuneLogger` requries an API key. - In offline mode, Neptune will log to a local directory. + To log experiment data in online mode, :class:`NeptuneLogger` requires an API key. + In offline mode, the logger does not connect to Neptune. **ONLINE MODE** - Example: - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.loggers import NeptuneLogger - >>> # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class - >>> # We are using an api_key for the anonymous user "neptuner" but you can use your own. - >>> neptune_logger = NeptuneLogger( - ... api_key='ANONYMOUS', - ... project_name='shared/pytorch-lightning-integration', - ... experiment_name='default', # Optional, - ... params={'max_epochs': 10}, # Optional, - ... tags=['pytorch-lightning', 'mlp'] # Optional, - ... ) - >>> trainer = Trainer(max_epochs=10, logger=neptune_logger) + .. testcode:: + + from pytorch_lightning import Trainer + from pytorch_lightning.loggers import NeptuneLogger + + # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class + # We are using an api_key for the anonymous user "neptuner" but you can use your own. + neptune_logger = NeptuneLogger( + api_key='ANONYMOUS', + project_name='shared/pytorch-lightning-integration', + experiment_name='default', # Optional, + params={'max_epochs': 10}, # Optional, + tags=['pytorch-lightning', 'mlp'] # Optional, + ) + trainer = Trainer(max_epochs=10, logger=neptune_logger) **OFFLINE MODE** - Example: - >>> from pytorch_lightning.loggers import NeptuneLogger - >>> # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class - >>> neptune_logger = NeptuneLogger( - ... offline_mode=True, - ... project_name='USER_NAME/PROJECT_NAME', - ... experiment_name='default', # Optional, - ... params={'max_epochs': 10}, # Optional, - ... tags=['pytorch-lightning', 'mlp'] # Optional, - ... ) - >>> trainer = Trainer(max_epochs=10, logger=neptune_logger) + .. testcode:: + + from pytorch_lightning.loggers import NeptuneLogger + + # arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class + neptune_logger = NeptuneLogger( + offline_mode=True, + project_name='USER_NAME/PROJECT_NAME', + experiment_name='default', # Optional, + params={'max_epochs': 10}, # Optional, + tags=['pytorch-lightning', 'mlp'] # Optional, + ) + trainer = Trainer(max_epochs=10, logger=neptune_logger) Use the logger anywhere in you :class:`~pytorch_lightning.core.lightning.LightningModule` as follows: - >>> from pytorch_lightning import LightningModule - >>> class LitModel(LightningModule): - ... def training_step(self, batch, batch_idx): - ... # log metrics - ... self.logger.experiment.log_metric('acc_train', ...) - ... # log images - ... self.logger.experiment.log_image('worse_predictions', ...) - ... # log model checkpoint - ... self.logger.experiment.log_artifact('model_checkpoint.pt', ...) - ... self.logger.experiment.whatever_neptune_supports(...) - ... - ... def any_lightning_module_function_or_hook(self): - ... self.logger.experiment.log_metric('acc_train', ...) - ... self.logger.experiment.log_image('worse_predictions', ...) - ... self.logger.experiment.log_artifact('model_checkpoint.pt', ...) - ... self.logger.experiment.whatever_neptune_supports(...) - - If you want to log objects after the training is finished use ``close_after_train=False``: + .. code-block:: python + + class LitModel(LightningModule): + def training_step(self, batch, batch_idx): + # log metrics + self.logger.experiment.log_metric('acc_train', ...) + # log images + self.logger.experiment.log_image('worse_predictions', ...) + # log model checkpoint + self.logger.experiment.log_artifact('model_checkpoint.pt', ...) + self.logger.experiment.whatever_neptune_supports(...) + + def any_lightning_module_function_or_hook(self): + self.logger.experiment.log_metric('acc_train', ...) + self.logger.experiment.log_image('worse_predictions', ...) + self.logger.experiment.log_artifact('model_checkpoint.pt', ...) + self.logger.experiment.whatever_neptune_supports(...) + + If you want to log objects after the training is finished use ``close_after_fit=False``: .. code-block:: python @@ -135,7 +156,7 @@ class NeptuneLogger(LightningLoggerBase): "namespace/project_name" for example "tom/minst-classification". If ``None``, the value of `NEPTUNE_PROJECT` environment variable will be taken. You need to create the project in https://neptune.ai first. - offline_mode: Optional default False. If ``True`` no logs will be sent + offline_mode: Optional default ``False``. If ``True`` no logs will be sent to Neptune. Usually used for debug purposes. close_after_fit: Optional default ``True``. If ``False`` the experiment will not be closed after training and additional metrics, @@ -144,71 +165,62 @@ class NeptuneLogger(LightningLoggerBase): experiment_name: Optional. Editable name of the experiment. Name is displayed in the experiment’s Details (Metadata section) and in experiments view as a column. - upload_source_files: Optional. List of source files to be uploaded. - Must be list of str or single str. Uploaded sources are displayed - in the experiment’s Source code tab. - If ``None`` is passed, the Python file from which the experiment was created will be uploaded. - Pass an empty list (``[]``) to upload no files. - Unix style pathname pattern expansion is supported. - For example, you can pass ``'\*.py'`` - to upload all python source files from the current directory. - For recursion lookup use ``'\**/\*.py'`` (for Python 3.5 and later). - For more information see :mod:`glob` library. - params: Optional. Parameters of the experiment. - After experiment creation params are read-only. - Parameters are displayed in the experiment’s Parameters section and - each key-value pair can be viewed in the experiments view as a column. - properties: Optional. Default is ``{}``. Properties of the experiment. - They are editable after the experiment is created. - Properties are displayed in the experiment’s Details section and - each key-value pair can be viewed in the experiments view as a column. - tags: Optional. Default is ``[]``. Must be list of str. Tags of the experiment. - They are editable after the experiment is created (see: ``append_tag()`` and ``remove_tag()``). - Tags are displayed in the experiment’s Details section and can be viewed - in the experiments view as a column. + experiment_id: Optional. Default is ``None``. The ID of the existing experiment. + If specified, connect to experiment with experiment_id in project_name. + Input arguments "experiment_name", "params", "properties" and "tags" will be overriden based + on fetched experiment data. + prefix: A string to put at the beginning of metric keys. + \**kwargs: Additional arguments like `params`, `tags`, `properties`, etc. used by + :func:`neptune.Session.create_experiment` can be passed as keyword arguments in this logger. + + Raises: + ImportError: + If required Neptune package is not installed on the device. """ - def __init__(self, - api_key: Optional[str] = None, - project_name: Optional[str] = None, - close_after_fit: Optional[bool] = True, - offline_mode: bool = False, - experiment_name: Optional[str] = None, - upload_source_files: Optional[List[str]] = None, - params: Optional[Dict[str, Any]] = None, - properties: Optional[Dict[str, Any]] = None, - tags: Optional[List[str]] = None, - **kwargs): + + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + api_key: Optional[str] = None, + project_name: Optional[str] = None, + close_after_fit: Optional[bool] = True, + offline_mode: bool = False, + experiment_name: Optional[str] = None, + experiment_id: Optional[str] = None, + prefix: str = '', + **kwargs + ): + if neptune is None: + raise ImportError( + 'You want to use `neptune` logger which is not installed yet,' + ' install it with `pip install neptune-client`.' + ) super().__init__() self.api_key = api_key self.project_name = project_name self.offline_mode = offline_mode self.close_after_fit = close_after_fit self.experiment_name = experiment_name - self.upload_source_files = upload_source_files - self.params = params - self.properties = properties - self.tags = tags - self._experiment = None + self._prefix = prefix self._kwargs = kwargs + self.experiment_id = experiment_id + self._experiment = None - if offline_mode: - self.mode = 'offline' - neptune.init(project_qualified_name='dry-run/project', - backend=neptune.OfflineBackend()) - else: - self.mode = 'online' - neptune.init(api_token=self.api_key, - project_qualified_name=self.project_name) - - log.info(f'NeptuneLogger was initialized in {self.mode} mode') + log.info(f'NeptuneLogger will work in {"offline" if self.offline_mode else "online"} mode') def __getstate__(self): state = self.__dict__.copy() - # cannot be pickled + + # Experiment cannot be pickled, and additionally its ID cannot be pickled in offline mode state['_experiment'] = None + if self.offline_mode: + state['experiment_id'] = None + return state @property + @rank_zero_experiment def experiment(self) -> Experiment: r""" Actual Neptune object. To use neptune features in your @@ -220,14 +232,11 @@ def experiment(self) -> Experiment: """ + # Note that even though we initialize self._experiment in __init__, + # it may still end up being None after being pickled and un-pickled if self._experiment is None: - self._experiment = neptune.create_experiment( - name=self.experiment_name, - params=self.params, - properties=self.properties, - tags=self.tags, - upload_source_files=self.upload_source_files, - **self._kwargs) + self._experiment = self._create_or_get_experiment() + return self._experiment @rank_zero_only @@ -238,20 +247,21 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: self.experiment.set_property(f'param__{key}', val) @rank_zero_only - def log_metrics( - self, - metrics: Dict[str, Union[torch.Tensor, float]], - step: Optional[int] = None - ) -> None: + def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None: """ Log metrics (numeric values) in Neptune experiments. Args: metrics: Dictionary with metric names as keys and measured quantities as values - step: Step number at which the metrics should be recorded, must be strictly increasing + step: Step number at which the metrics should be recorded, currently ignored """ + assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' + + metrics = self._add_prefix(metrics) for key, val in metrics.items(): - self.log_metric(key, val, step=step) + # `step` is ignored because Neptune expects strictly increasing step values which + # Lighting does not always guarantee. + self.log_metric(key, val) @rank_zero_only def finalize(self, status: str) -> None: @@ -259,26 +269,28 @@ def finalize(self, status: str) -> None: if self.close_after_fit: self.experiment.stop() + @property + def save_dir(self) -> Optional[str]: + # Neptune does not save any local files + return None + @property def name(self) -> str: - if self.mode == 'offline': + if self.offline_mode: return 'offline-name' else: return self.experiment.name @property def version(self) -> str: - if self.mode == 'offline': + if self.offline_mode: return 'offline-id-1234' else: return self.experiment.id @rank_zero_only def log_metric( - self, - metric_name: str, - metric_value: Union[torch.Tensor, float, str], - step: Optional[int] = None + self, metric_name: str, metric_value: Union[torch.Tensor, float, str], step: Optional[int] = None ) -> None: """ Log metrics (numeric values) in Neptune experiments. @@ -306,13 +318,10 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None text: The value of the log (data-point). step: Step number at which the metrics should be recorded, must be strictly increasing """ - self.log_metric(log_name, text, step=step) + self.experiment.log_text(log_name, text, step=step) @rank_zero_only - def log_image(self, - log_name: str, - image: Union[str, Image, Any], - step: Optional[int] = None) -> None: + def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None) -> None: """ Log image data in Neptune experiment @@ -363,3 +372,22 @@ def append_tags(self, tags: Union[str, Iterable[str]]) -> None: if str(tags) == tags: tags = [tags] # make it as an iterable is if it is not yet self.experiment.append_tags(*tags) + + def _create_or_get_experiment(self): + if self.offline_mode: + project = neptune.Session(backend=neptune.OfflineBackend()).get_project('dry-run/project') + else: + session = neptune.Session.with_default_backend(api_token=self.api_key) + project = session.get_project(self.project_name) + + if self.experiment_id is None: + exp = project.create_experiment(name=self.experiment_name, **self._kwargs) + self.experiment_id = exp.id + else: + exp = project.get_experiments(id=self.experiment_id)[0] + self.experiment_name = exp.get_system_properties()['name'] + self.params = exp.get_parameters() + self.properties = exp.get_properties() + self.tags = exp.get_tags() + + return exp diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index fc33c9e942ecd4..4aaf64c6b7a599 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -1,35 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -TensorBoard ------------ +TensorBoard Logger +------------------ """ -import csv +import logging import os from argparse import Namespace -from typing import Optional, Dict, Union, Any -from warnings import warn +from typing import Any, Dict, Optional, Union import torch -from pkg_resources import parse_version from torch.utils.tensorboard import SummaryWriter +from torch.utils.tensorboard.summary import hparams -from pytorch_lightning import _logger as log -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.saving import save_hparams_to_yaml +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import get_filesystem + +log = logging.getLogger(__name__) + +if _OMEGACONF_AVAILABLE: + from omegaconf import Container, OmegaConf class TensorBoardLogger(LightningLoggerBase): r""" Log to local file system in `TensorBoard `_ format. + Implemented using :class:`~torch.utils.tensorboard.SummaryWriter`. Logs are saved to ``os.path.join(save_dir, name, version)``. This is the default logger in Lightning, it comes preinstalled. Example: - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.loggers import TensorBoardLogger - >>> logger = TensorBoardLogger("tb_logs", name="my_model") - >>> trainer = Trainer(logger=logger) + + .. testcode:: + + from pytorch_lightning import Trainer + from pytorch_lightning.loggers import TensorBoardLogger + logger = TensorBoardLogger("tb_logs", name="my_model") + trainer = Trainer(logger=logger) Args: save_dir: Save directory @@ -39,23 +62,40 @@ class TensorBoardLogger(LightningLoggerBase): directory for existing versions, then automatically assigns the next available version. If it is a string then it is used as the run-specific subdirectory name, otherwise ``'version_${version}'`` is used. - \**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor. + log_graph: Adds the computational graph to tensorboard. This requires that + the user has defined the `self.example_input_array` attribute in their + model. + default_hp_metric: Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is + called without a metric (otherwise calls to log_hyperparams without a metric are ignored). + prefix: A string to put at the beginning of metric keys. + \**kwargs: Additional arguments like `comment`, `filename_suffix`, etc. used by + :class:`SummaryWriter` can be passed as keyword arguments in this logger. """ - NAME_CSV_TAGS = 'meta_tags.csv' - - def __init__(self, - save_dir: str, - name: Optional[str] = "default", - version: Optional[Union[int, str]] = None, - **kwargs): + NAME_HPARAMS_FILE = 'hparams.yaml' + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + save_dir: str, + name: Optional[str] = "default", + version: Optional[Union[int, str]] = None, + log_graph: bool = False, + default_hp_metric: bool = True, + prefix: str = '', + **kwargs + ): super().__init__() - self.save_dir = save_dir - self._name = name + self._save_dir = save_dir + self._name = name or '' self._version = version + self._log_graph = log_graph + self._default_hp_metric = default_hp_metric + self._prefix = prefix + self._fs = get_filesystem(save_dir) self._experiment = None - self.tags = {} + self.hparams = {} self._kwargs = kwargs @property @@ -83,6 +123,11 @@ def log_dir(self) -> str: return log_dir @property + def save_dir(self) -> Optional[str]: + return self._save_dir + + @property + @rank_zero_experiment def experiment(self) -> SummaryWriter: r""" Actual tensorboard object. To use TensorBoard features in your @@ -96,69 +141,106 @@ def experiment(self) -> SummaryWriter: if self._experiment is not None: return self._experiment - os.makedirs(self.root_dir, exist_ok=True) + assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0' + if self.root_dir: + self._fs.makedirs(self.root_dir, exist_ok=True) self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) return self._experiment @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], - metrics: Optional[Dict[str, Any]] = None) -> None: + def log_hyperparams( + self, + params: Union[Dict[str, Any], Namespace], + metrics: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Record hyperparameters. TensorBoard logs with and without saved hyperparameters + are incompatible, the hyperparameters are then not displayed in the TensorBoard. + Please delete or move the previously saved logs to display the new ones with hyperparameters. + + Args: + params: a dictionary-like container with the hyperparameters + metrics: Dictionary with metric names as keys and measured quantities as values + """ + params = self._convert_params(params) - params = self._flatten_dict(params) - sanitized_params = self._sanitize_params(params) - - if parse_version(torch.__version__) < parse_version("1.3.0"): - warn( - f"Hyperparameter logging is not available for Torch version {torch.__version__}." - " Skipping log_hyperparams. Upgrade to Torch 1.3.0 or above to enable" - " hyperparameter logging." - ) + + # store params to output + if _OMEGACONF_AVAILABLE and isinstance(params, Container): + self.hparams = OmegaConf.merge(self.hparams, params) else: - from torch.utils.tensorboard.summary import hparams - if metrics is None: - metrics = {} - exp, ssi, sei = hparams(sanitized_params, metrics) + self.hparams.update(params) + + # format params into the suitable for tensorboard + params = self._flatten_dict(params) + params = self._sanitize_params(params) + + if metrics is None: + if self._default_hp_metric: + metrics = {"hp_metric": -1} + elif not isinstance(metrics, dict): + metrics = {"hp_metric": metrics} + + if metrics: + self.log_metrics(metrics, 0) + exp, ssi, sei = hparams(params, metrics) writer = self.experiment._get_file_writer() writer.add_summary(exp) writer.add_summary(ssi) writer.add_summary(sei) - # some alternative should be added - self.tags.update(sanitized_params) - @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' + + metrics = self._add_prefix(metrics) + for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() - self.experiment.add_scalar(k, v, step) + + if isinstance(v, dict): + self.experiment.add_scalars(k, v, step) + else: + try: + self.experiment.add_scalar(k, v, step) + # todo: specify the possible exception + except Exception as ex: + m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.' + type(ex)(ex.message + m) + + @rank_zero_only + def log_graph(self, model: LightningModule, input_array=None): + if self._log_graph: + if input_array is None: + input_array = model.example_input_array + + if input_array is not None: + input_array = model._apply_batch_transfer_handler(input_array) + self.experiment.add_graph(model, input_array) + else: + rank_zero_warn( + 'Could not log computational graph since the' + ' `model.example_input_array` attribute is not set' + ' or `input_array` was not given', UserWarning + ) @rank_zero_only def save(self) -> None: super().save() - try: - self.experiment.flush() - except AttributeError: - # you are using PT version ( None: + self.experiment.flush() + self.experiment.close() self.save() @property @@ -174,16 +256,23 @@ def version(self) -> int: def _get_next_version(self): root_dir = os.path.join(self.save_dir, self.name) - if not os.path.isdir(root_dir): + if not self._fs.isdir(root_dir): log.warning('Missing logger folder: %s', root_dir) return 0 existing_versions = [] - for d in os.listdir(root_dir): - if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): - existing_versions.append(int(d.split("_")[1])) - + for listing in self._fs.listdir(root_dir): + d = listing["name"] + bn = os.path.basename(d) + if self._fs.isdir(d) and bn.startswith("version_"): + dir_ver = bn.split("_")[1].replace('/', '') + existing_versions.append(int(dir_ver)) if len(existing_versions) == 0: return 0 return max(existing_versions) + 1 + + def __getstate__(self): + state = self.__dict__.copy() + state["_experiment"] = None + return state diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 7f382c3fb98756..84f231b0f16d7f 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -1,46 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Test Tube ---------- +Test Tube Logger +---------------- """ from argparse import Namespace -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union -try: - from test_tube import Experiment -except ImportError: # pragma: no-cover - raise ImportError('You want to use `test_tube` logger which is not installed yet,' # pragma: no-cover - ' install it with `pip install test-tube`.') +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn + +_TESTTUBE_AVAILABLE = _module_available("test_tube") -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities.distributed import rank_zero_only +if _TESTTUBE_AVAILABLE: + from test_tube import Experiment +else: + Experiment = None class TestTubeLogger(LightningLoggerBase): r""" Log to local file system in `TensorBoard `_ format but using a nicer folder structure (see `full docs `_). + Install it with pip: .. code-block:: bash pip install test_tube - Example: - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.loggers import TestTubeLogger - >>> logger = TestTubeLogger("tt_logs", name="my_exp_name") - >>> trainer = Trainer(logger=logger) + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.loggers import TestTubeLogger + logger = TestTubeLogger("tt_logs", name="my_exp_name") + trainer = Trainer(logger=logger) Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows: - >>> from pytorch_lightning import LightningModule - >>> class LitModel(LightningModule): - ... def training_step(self, batch, batch_idx): - ... # example - ... self.logger.experiment.whatever_method_summary_writer_supports(...) - ... - ... def any_lightning_module_function_or_hook(self): - ... self.logger.experiment.add_histogram(...) + .. code-block:: python + + from pytorch_lightning import LightningModule + class LitModel(LightningModule): + def training_step(self, batch, batch_idx): + # example + self.logger.experiment.whatever_method_summary_writer_supports(...) + + def any_lightning_module_function_or_hook(self): + self.logger.experiment.add_histogram(...) Args: save_dir: Save directory @@ -50,28 +70,48 @@ class TestTubeLogger(LightningLoggerBase): version: Experiment version. If version is not specified the logger inspects the save directory for existing versions, then automatically assigns the next available version. create_git_tag: If ``True`` creates a git tag to save the code used in this experiment. - + log_graph: Adds the computational graph to tensorboard. This requires that + the user has defined the `self.example_input_array` attribute in their + model. + prefix: A string to put at the beginning of metric keys. + + Raises: + ImportError: + If required TestTube package is not installed on the device. """ __test__ = False - - def __init__(self, - save_dir: str, - name: str = "default", - description: Optional[str] = None, - debug: bool = False, - version: Optional[int] = None, - create_git_tag: bool = False): + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + save_dir: str, + name: str = "default", + description: Optional[str] = None, + debug: bool = False, + version: Optional[int] = None, + create_git_tag: bool = False, + log_graph: bool = False, + prefix: str = '', + ): + if Experiment is None: + raise ImportError( + 'You want to use `test_tube` logger which is not installed yet,' + ' install it with `pip install test-tube`.' + ) super().__init__() - self.save_dir = save_dir + self._save_dir = save_dir self._name = name self.description = description self.debug = debug self._version = version self.create_git_tag = create_git_tag + self._log_graph = log_graph + self._prefix = prefix self._experiment = None @property + @rank_zero_experiment def experiment(self) -> Experiment: r""" @@ -108,9 +148,25 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: # TODO: HACK figure out where this is being set to true + metrics = self._add_prefix(metrics) self.experiment.debug = self.debug self.experiment.log(metrics, global_step=step) + @rank_zero_only + def log_graph(self, model: LightningModule, input_array=None): + if self._log_graph: + if input_array is None: + input_array = model.example_input_array + + if input_array is not None: + self.experiment.add_graph(model, model._apply_batch_transfer_handler(input_array)) + else: + rank_zero_warn( + 'Could not log computational graph since neither the' + ' `model.example_input_array` attribute is set nor' + ' `input_array` was given', UserWarning + ) + @rank_zero_only def save(self) -> None: super().save() @@ -135,19 +191,23 @@ def close(self) -> None: exp = self.experiment exp.close() + @property + def save_dir(self) -> Optional[str]: + return self._save_dir + @property def name(self) -> str: if self._experiment is None: return self._name - else: - return self.experiment.name + + return self.experiment.name @property def version(self) -> int: if self._experiment is None: return self._version - else: - return self.experiment.version + + return self.experiment.version # Test tube experiments are not pickleable, so we need to override a few # methods to get DDP working. See diff --git a/pytorch_lightning/loggers/trains.py b/pytorch_lightning/loggers/trains.py deleted file mode 100644 index ca4f38c6340f89..00000000000000 --- a/pytorch_lightning/loggers/trains.py +++ /dev/null @@ -1,391 +0,0 @@ -""" -TRAINS ------- -""" -from argparse import Namespace -from os import environ -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import numpy as np -import torch -from PIL.Image import Image - -try: - import trains - from trains import Task -except ImportError: # pragma: no-cover - raise ImportError('You want to use `TRAINS` logger which is not installed yet,' # pragma: no-cover - ' install it with `pip install trains`.') - -from pytorch_lightning import _logger as log -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only - - -class TrainsLogger(LightningLoggerBase): - """ - Log using `allegro.ai TRAINS `_. Install it with pip: - - .. code-block:: bash - - pip install trains - - Example: - >>> from pytorch_lightning import Trainer - >>> from pytorch_lightning.loggers import TrainsLogger - >>> trains_logger = TrainsLogger( - ... project_name='pytorch lightning', - ... task_name='default', - ... output_uri='.', - ... ) # doctest: +ELLIPSIS - TRAINS Task: ... - TRAINS results page: ... - >>> trainer = Trainer(logger=trains_logger) - - Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows: - - >>> from pytorch_lightning import LightningModule - >>> class LitModel(LightningModule): - ... def training_step(self, batch, batch_idx): - ... # example - ... self.logger.experiment.whatever_trains_supports(...) - ... - ... def any_lightning_module_function_or_hook(self): - ... self.logger.experiment.whatever_trains_supports(...) - - Args: - project_name: The name of the experiment's project. Defaults to ``None``. - task_name: The name of the experiment. Defaults to ``None``. - task_type: The name of the experiment. Defaults to ``'training'``. - reuse_last_task_id: Start with the previously used task id. Defaults to ``True``. - output_uri: Default location for output models. Defaults to ``None``. - auto_connect_arg_parser: Automatically grab the :class:`~argparse.ArgumentParser` - and connect it with the task. Defaults to ``True``. - auto_connect_frameworks: If ``True``, automatically patch to trains backend. Defaults to ``True``. - auto_resource_monitoring: If ``True``, machine vitals will be - sent along side the task scalars. Defaults to ``True``. - - Examples: - >>> logger = TrainsLogger("pytorch lightning", "default", output_uri=".") # doctest: +ELLIPSIS - TRAINS Task: ... - TRAINS results page: ... - >>> logger.log_metrics({"val_loss": 1.23}, step=0) - >>> logger.log_text("sample test") - sample test - >>> import numpy as np - >>> logger.log_artifact("confusion matrix", np.ones((2, 3))) - >>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8)) - """ - - _bypass = None - - def __init__( - self, - project_name: Optional[str] = None, - task_name: Optional[str] = None, - task_type: str = 'training', - reuse_last_task_id: bool = True, - output_uri: Optional[str] = None, - auto_connect_arg_parser: bool = True, - auto_connect_frameworks: bool = True, - auto_resource_monitoring: bool = True - ) -> None: - super().__init__() - if self.bypass_mode(): - self._trains = None - print('TRAINS Task: running in bypass mode') - print('TRAINS results page: disabled') - - class _TaskStub(object): - def __call__(self, *args, **kwargs): - return self - - def __getattr__(self, attr): - if attr in ('name', 'id'): - return '' - return self - - def __setattr__(self, attr, val): - pass - - self._trains = _TaskStub() - else: - self._trains = Task.init( - project_name=project_name, - task_name=task_name, - task_type=task_type, - reuse_last_task_id=reuse_last_task_id, - output_uri=output_uri, - auto_connect_arg_parser=auto_connect_arg_parser, - auto_connect_frameworks=auto_connect_frameworks, - auto_resource_monitoring=auto_resource_monitoring - ) - - @property - def experiment(self) -> Task: - r""" - Actual TRAINS object. To use TRAINS features in your - :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. - - Example:: - - self.logger.experiment.some_trains_function() - - """ - return self._trains - - @property - def id(self) -> Union[str, None]: - """ - ID is a uuid (string) representing this specific experiment in the entire system. - """ - if not self._trains: - return None - - return self._trains.id - - @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - """ - Log hyperparameters (numeric values) in TRAINS experiments. - - Args: - params: The hyperparameters that passed through the model. - """ - if not self._trains: - return - if not params: - return - - params = self._convert_params(params) - params = self._flatten_dict(params) - self._trains.connect(params) - - @rank_zero_only - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - """ - Log metrics (numeric values) in TRAINS experiments. - This method will be called by Trainer. - - Args: - metrics: The dictionary of the metrics. - If the key contains "/", it will be split by the delimiter, - then the elements will be logged as "title" and "series" respectively. - step: Step number at which the metrics should be recorded. Defaults to ``None``. - """ - if not self._trains: - return - - if not step: - step = self._trains.get_last_iteration() - - for k, v in metrics.items(): - if isinstance(v, str): - log.warning("Discarding metric with string value {}={}".format(k, v)) - continue - if isinstance(v, torch.Tensor): - v = v.item() - parts = k.split('/') - if len(parts) <= 1: - series = title = k - else: - title = parts[0] - series = '/'.join(parts[1:]) - self._trains.get_logger().report_scalar( - title=title, series=series, value=v, iteration=step) - - @rank_zero_only - def log_metric(self, title: str, series: str, value: float, step: Optional[int] = None) -> None: - """ - Log metrics (numeric values) in TRAINS experiments. - This method will be called by the users. - - Args: - title: The title of the graph to log, e.g. loss, accuracy. - series: The series name in the graph, e.g. classification, localization. - value: The value to log. - step: Step number at which the metrics should be recorded. Defaults to ``None``. - """ - if not self._trains: - return - - if not step: - step = self._trains.get_last_iteration() - - if isinstance(value, torch.Tensor): - value = value.item() - self._trains.get_logger().report_scalar( - title=title, series=series, value=value, iteration=step) - - @rank_zero_only - def log_text(self, text: str) -> None: - """Log console text data in TRAINS experiment. - - Args: - text: The value of the log (data-point). - """ - if self.bypass_mode(): - print(text) - return - - if not self._trains: - return - - self._trains.get_logger().report_text(text) - - @rank_zero_only - def log_image( - self, title: str, series: str, - image: Union[str, np.ndarray, Image, torch.Tensor], - step: Optional[int] = None) -> None: - """ - Log Debug image in TRAINS experiment - - Args: - title: The title of the debug image, i.e. "failed", "passed". - series: The series name of the debug image, i.e. "Image 0", "Image 1". - image: Debug image to log. If :class:`numpy.ndarray` or :class:`torch.Tensor`, - the image is assumed to be the following: - - - shape: CHW - - color space: RGB - - value range: [0., 1.] (float) or [0, 255] (uint8) - - step: Step number at which the metrics should be recorded. Defaults to None. - """ - if not self._trains: - return - - if not step: - step = self._trains.get_last_iteration() - - if isinstance(image, str): - self._trains.get_logger().report_image( - title=title, series=series, local_path=image, iteration=step) - else: - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - if isinstance(image, np.ndarray): - image = image.transpose(1, 2, 0) - self._trains.get_logger().report_image( - title=title, series=series, image=image, iteration=step) - - @rank_zero_only - def log_artifact( - self, name: str, - artifact: Union[str, Path, Dict[str, Any], np.ndarray, Image], - metadata: Optional[Dict[str, Any]] = None, delete_after_upload: bool = False) -> None: - """ - Save an artifact (file/object) in TRAINS experiment storage. - - Args: - name: Artifact name. Notice! it will override the previous artifact - if the name already exists. - artifact: Artifact object to upload. Currently supports: - - - string / :class:`pathlib.Path` are treated as path to artifact file to upload - If a wildcard or a folder is passed, a zip file containing the - local files will be created and uploaded. - - dict will be stored as .json file and uploaded - - :class:`pandas.DataFrame` will be stored as .csv.gz (compressed CSV file) and uploaded - - :class:`numpy.ndarray` will be stored as .npz and uploaded - - :class:`PIL.Image.Image` will be stored to .png file and uploaded - - metadata: - Simple key/value dictionary to store on the artifact. Defaults to ``None``. - delete_after_upload: - If ``True``, the local artifact will be deleted (only applies if ``artifact`` is a - local file). Defaults to ``False``. - """ - if not self._trains: - return - - self._trains.upload_artifact( - name=name, artifact_object=artifact, metadata=metadata, - delete_after_upload=delete_after_upload - ) - - @rank_zero_only - def finalize(self, status: str = None) -> None: - # super().finalize(status) - if self.bypass_mode() or not self._trains: - return - - self._trains.close() - self._trains = None - - @property - def name(self) -> Union[str, None]: - """ - Name is a human readable non-unique name (str) of the experiment. - """ - if not self._trains: - return '' - - return self._trains.name - - @property - def version(self) -> Union[str, None]: - if not self._trains: - return None - - return self._trains.id - - @classmethod - def set_credentials(cls, api_host: str = None, web_host: str = None, files_host: str = None, - key: str = None, secret: str = None) -> None: - """ - Set new default TRAINS-server host and credentials. - These configurations could be overridden by either OS environment variables - or trains.conf configuration file. - - Note: - Credentials need to be set *prior* to Logger initialization. - - Args: - api_host: Trains API server url, example: ``host='http://localhost:8008'`` - web_host: Trains WEB server url, example: ``host='http://localhost:8080'`` - files_host: Trains Files server url, example: ``host='http://localhost:8081'`` - key: user key/secret pair, example: ``key='thisisakey123'`` - secret: user key/secret pair, example: ``secret='thisisseceret123'`` - """ - Task.set_credentials(api_host=api_host, web_host=web_host, files_host=files_host, - key=key, secret=secret) - - @classmethod - def set_bypass_mode(cls, bypass: bool) -> None: - """ - Will bypass all outside communication, and will drop all logs. - Should only be used in "standalone mode", when there is no access to the *trains-server*. - - Args: - bypass: If ``True``, all outside communication is skipped. - """ - cls._bypass = bypass - - @classmethod - def bypass_mode(cls) -> bool: - """ - Returns the bypass mode state. - - Note: - `GITHUB_ACTIONS` env will automatically set bypass_mode to ``True`` - unless overridden specifically with ``TrainsLogger.set_bypass_mode(False)``. - - Return: - If True, all outside communication is skipped. - """ - return cls._bypass if cls._bypass is not None else bool(environ.get('CI')) - - def __getstate__(self) -> Union[str, None]: - if self.bypass_mode() or not self._trains: - return '' - - return self._trains.id - - def __setstate__(self, state: str) -> None: - self._rank = 0 - self._trains = None - if state: - self._trains = Task.get_task(task_id=state) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 0d5ff9855a40df..285388d6c67653 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -1,27 +1,48 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Weights and Biases ------------------- +Weights and Biases Logger +------------------------- """ import os from argparse import Namespace -from typing import Optional, List, Dict, Union, Any +from typing import Any, Dict, Optional, Union import torch.nn as nn +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities import _module_available, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() + +_WANDB_AVAILABLE = _module_available("wandb") + try: import wandb from wandb.wandb_run import Run -except ImportError: # pragma: no-cover - raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover - ' install it with `pip install wandb`.') - -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only +except ImportError: + # needed for test mocks, these tests shall be updated + wandb, Run = None, None class WandbLogger(LightningLoggerBase): - """ - Log using `Weights and Biases `_. Install it with pip: + r""" + Log using `Weights and Biases `_. + + Install it with pip: .. code-block:: bash @@ -29,53 +50,88 @@ class WandbLogger(LightningLoggerBase): Args: name: Display name for the run. - save_dir: Path where data is saved. + save_dir: Path where data is saved (wandb dir by default). offline: Run offline (data can be streamed later to wandb servers). id: Sets the version, mainly used to resume a previous run. + version: Same as id. anonymous: Enables or explicitly disables anonymous logging. - version: Sets the version, mainly used to resume a previous run. project: The name of the project to which this run will belong. - tags: Tags associated with this run. log_model: Save checkpoints in wandb dir to upload on W&B servers. - experiment: WandB experiment object - entity: The team posting this run (default: your username or your default team) + prefix: A string to put at the beginning of metric keys. + experiment: WandB experiment object. Automatically set when creating a run. + \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by + :func:`wandb.init` can be passed as keyword arguments in this logger. - Example: - >>> from pytorch_lightning.loggers import WandbLogger - >>> from pytorch_lightning import Trainer - >>> wandb_logger = WandbLogger() - >>> trainer = Trainer(logger=wandb_logger) + Raises: + ImportError: + If required WandB package is not installed on the device. + MisconfigurationException: + If both ``log_model`` and ``offline``is set to ``True``. + + Example:: + + from pytorch_lightning.loggers import WandbLogger + from pytorch_lightning import Trainer + wandb_logger = WandbLogger() + trainer = Trainer(logger=wandb_logger) + + Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`, + make sure to use `commit=False` so the logging step does not increase. See Also: - - `Tutorial `__ - on how to use W&B with Pytorch Lightning. + - `Tutorial `__ + on how to use W&B with PyTorch Lightning + - `W&B Documentation `__ """ - def __init__(self, - name: Optional[str] = None, - save_dir: Optional[str] = None, - offline: bool = False, - id: Optional[str] = None, - anonymous: bool = False, - version: Optional[str] = None, - project: Optional[str] = None, - tags: Optional[List[str]] = None, - log_model: bool = False, - experiment=None, - entity=None): + LOGGER_JOIN_CHAR = '-' + + def __init__( + self, + name: Optional[str] = None, + save_dir: Optional[str] = None, + offline: Optional[bool] = False, + id: Optional[str] = None, + anonymous: Optional[bool] = False, + version: Optional[str] = None, + project: Optional[str] = None, + log_model: Optional[bool] = False, + experiment=None, + prefix: Optional[str] = '', + sync_step: Optional[bool] = None, + **kwargs + ): + if wandb is None: + raise ImportError( + 'You want to use `wandb` logger which is not installed yet,' # pragma: no-cover + ' install it with `pip install wandb`.' + ) + + if offline and log_model: + raise MisconfigurationException( + f'Providing log_model={log_model} and offline={offline} is an invalid configuration' + ' since model checkpoints cannot be uploaded in offline mode.\n' + 'Hint: Set `offline=False` to log your model.' + ) + + if sync_step is not None: + warning_cache.warn( + "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5." + " Metrics are now logged separately and automatically synchronized.", DeprecationWarning + ) + super().__init__() self._name = name self._save_dir = save_dir - self._anonymous = 'allow' if anonymous else None + self._offline = offline self._id = version or id - self._tags = tags + self._anonymous = 'allow' if anonymous else None self._project = project - self._experiment = experiment - self._offline = offline - self._entity = entity self._log_model = log_model + self._prefix = prefix + self._experiment = experiment + self._kwargs = kwargs def __getstate__(self): state = self.__dict__.copy() @@ -87,6 +143,7 @@ def __getstate__(self): return state @property + @rank_zero_experiment def experiment(self) -> Run: r""" @@ -102,11 +159,24 @@ def experiment(self) -> Run: if self._offline: os.environ['WANDB_MODE'] = 'dryrun' self._experiment = wandb.init( - name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous, - reinit=True, id=self._id, resume='allow', tags=self._tags, entity=self._entity) + name=self._name, + dir=self._save_dir, + project=self._project, + anonymous=self._anonymous, + id=self._id, + resume='allow', + **self._kwargs + ) if wandb.run is None else wandb.run + # save checkpoints in wandb dir to upload on W&B servers - if self._log_model: - self.save_dir = self._experiment.dir + if self._save_dir is None: + self._save_dir = self._experiment.dir + + # define default x-axis (for latest wandb versions) + if getattr(self._experiment, "define_metric", None): + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True) + return self._experiment def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100): @@ -115,19 +185,36 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100): @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = self._convert_params(params) + params = self._flatten_dict(params) + params = self._sanitize_callable_params(params) self.experiment.config.update(params, allow_val_change=True) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - self.experiment.log(metrics, step=step) + assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' + + metrics = self._add_prefix(metrics) + if step is not None: + self.experiment.log({**metrics, 'trainer/global_step': step}) + else: + self.experiment.log(metrics) + + @property + def save_dir(self) -> Optional[str]: + return self._save_dir @property - def name(self) -> str: + def name(self) -> Optional[str]: # don't create an experiment if we don't have one - name = self._experiment.project_name() if self._experiment else None - return name + return self._experiment.project_name() if self._experiment else self._name @property - def version(self) -> str: + def version(self) -> Optional[str]: # don't create an experiment if we don't have one - return self._experiment.id if self._experiment else None + return self._experiment.id if self._experiment else self._id + + @rank_zero_only + def finalize(self, status: str) -> None: + # upload all checkpoints from saving dir + if self._log_model: + wandb.save(os.path.join(self.save_dir, "*.ckpt")) diff --git a/pytorch_lightning/logging/__init__.py b/pytorch_lightning/logging/__init__.py deleted file mode 100644 index 9d027b34615590..00000000000000 --- a/pytorch_lightning/logging/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `logging` package has been renamed to `loggers` since v0.7.0. - The deprecated package name will be removed in v0.9.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`logging` package has been renamed to `loggers` since v0.7.0" - " The deprecated package name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.loggers import * # noqa: F403 diff --git a/pytorch_lightning/logging/comet.py b/pytorch_lightning/logging/comet.py deleted file mode 100644 index ce854292441982..00000000000000 --- a/pytorch_lightning/logging/comet.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0 -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`logging.comet` module has been renamed to `loggers.comet` since v0.7.0." - " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.loggers.comet import CometLogger # noqa: F403 diff --git a/pytorch_lightning/logging/comet_logger.py b/pytorch_lightning/logging/comet_logger.py deleted file mode 100644 index 83360b36f8bc97..00000000000000 --- a/pytorch_lightning/logging/comet_logger.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `comet_logger` module has been renamed to `comet` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`comet_logger` module has been renamed to `comet` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.loggers.comet import CometLogger # noqa: E402 diff --git a/pytorch_lightning/logging/mlflow.py b/pytorch_lightning/logging/mlflow.py deleted file mode 100644 index 15b7fd81ca257f..00000000000000 --- a/pytorch_lightning/logging/mlflow.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0 -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`logging.mlflow` module has been renamed to `loggers.mlflow` since v0.7.0." - " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.loggers.mlflow import MLFlowLogger # noqa: F403 diff --git a/pytorch_lightning/logging/mlflow_logger.py b/pytorch_lightning/logging/mlflow_logger.py deleted file mode 100644 index 2e1b52126ecdf0..00000000000000 --- a/pytorch_lightning/logging/mlflow_logger.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `mlflow_logger` module has been renamed to `mlflow` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`mlflow_logger` module has been renamed to `mlflow` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.loggers.mlflow import MLFlowLogger # noqa: E402 diff --git a/pytorch_lightning/logging/neptune.py b/pytorch_lightning/logging/neptune.py deleted file mode 100644 index af6e18c12cbd08..00000000000000 --- a/pytorch_lightning/logging/neptune.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0 -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`logging.neptune` module has been renamed to `loggers.neptune` since v0.7.0." - " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F403 diff --git a/pytorch_lightning/logging/test_tube.py b/pytorch_lightning/logging/test_tube.py deleted file mode 100644 index 3648db6186a225..00000000000000 --- a/pytorch_lightning/logging/test_tube.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0 -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`logging.test_tube` module has been renamed to `loggers.test_tube` since v0.7.0." - " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.loggers.test_tube import TestTubeLogger # noqa: F403 diff --git a/pytorch_lightning/logging/test_tube_logger.py b/pytorch_lightning/logging/test_tube_logger.py deleted file mode 100644 index 3280ac8dce6328..00000000000000 --- a/pytorch_lightning/logging/test_tube_logger.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `test_tube_logger` module has been renamed to `test_tube` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`test_tube_logger` module has been renamed to `test_tube` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.loggers.test_tube import TestTubeLogger # noqa: E402 diff --git a/pytorch_lightning/logging/wandb.py b/pytorch_lightning/logging/wandb.py deleted file mode 100644 index 98a753c0aa1e4b..00000000000000 --- a/pytorch_lightning/logging/wandb.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -.. warning:: `logging` package has been renamed to `loggers` since v0.7.0 and will be removed in v0.9.0 -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`logging.wandb` module has been renamed to `loggers.wandb` since v0.7.0." - " The deprecated module name will be removed in v0.9.0.", DeprecationWarning) - -from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F403 diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py new file mode 100644 index 00000000000000..9b27fdf0cb253d --- /dev/null +++ b/pytorch_lightning/metrics/__init__.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.metrics.classification import ( # noqa: F401 + Accuracy, + AUC, + AUROC, + AveragePrecision, + ConfusionMatrix, + F1, + FBeta, + HammingDistance, + IoU, + Precision, + PrecisionRecallCurve, + Recall, + ROC, + StatScores, +) +from pytorch_lightning.metrics.metric import Metric, MetricCollection # noqa: F401 +from pytorch_lightning.metrics.regression import ( # noqa: F401 + ExplainedVariance, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, + PSNR, + R2Score, + SSIM, +) +from pytorch_lightning.utilities import rank_zero_deprecation + +rank_zero_deprecation( + "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package" + " (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5" +) diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py new file mode 100644 index 00000000000000..3aeffe3bbc6931 --- /dev/null +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -0,0 +1,25 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401 +from pytorch_lightning.metrics.classification.auc import AUC # noqa: F401 +from pytorch_lightning.metrics.classification.auroc import AUROC # noqa: F401 +from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401 +from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 +from pytorch_lightning.metrics.classification.f_beta import F1, FBeta # noqa: F401 +from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401 +from pytorch_lightning.metrics.classification.iou import IoU # noqa: F401 +from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401 +from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 +from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401 +from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401 diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py new file mode 100644 index 00000000000000..1a9febe0c831c6 --- /dev/null +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import Accuracy as _Accuracy + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class Accuracy(_Accuracy): + + @deprecated_metrics(target=_Accuracy) + def __init__( + self, + threshold: float = 0.5, + top_k: Optional[int] = None, + subset_accuracy: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.Accuracy`. + + .. deprecated:: + Use :class:`~torchmetrics.Accuracy`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py new file mode 100644 index 00000000000000..05bc7b27d7e686 --- /dev/null +++ b/pytorch_lightning/metrics/classification/auc.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import AUC as _AUC + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class AUC(_AUC): + + @deprecated_metrics(target=_AUC) + def __init__( + self, + reorder: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.AUC`. + + .. deprecated:: + Use :class:`~torchmetrics.AUC`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py new file mode 100644 index 00000000000000..e10b094fd5a2e2 --- /dev/null +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import AUROC as _AUROC + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class AUROC(_AUROC): + + @deprecated_metrics(target=_AUROC) + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = 'macro', + max_fpr: Optional[float] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.AUROC`. + + .. deprecated:: + Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py new file mode 100644 index 00000000000000..6c8cdbd52891d3 --- /dev/null +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchmetrics import AveragePrecision as _AveragePrecision + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class AveragePrecision(_AveragePrecision): + + @deprecated_metrics(target=_AveragePrecision) + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.AveragePrecision`. + + .. deprecated:: + Use :class:`~torchmetrics.AveragePrecision`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py new file mode 100644 index 00000000000000..2995f668380deb --- /dev/null +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchmetrics import ConfusionMatrix as _ConfusionMatrix + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class ConfusionMatrix(_ConfusionMatrix): + + @deprecated_metrics(target=_ConfusionMatrix) + def __init__( + self, + num_classes: int, + normalize: Optional[str] = None, + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.ConfusionMatrix`. + + .. deprecated:: + Use :class:`~torchmetrics.ConfusionMatrix`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py new file mode 100644 index 00000000000000..a3f4172f054008 --- /dev/null +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchmetrics import F1 as _F1 +from torchmetrics import FBeta as _FBeta + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class FBeta(_FBeta): + + @deprecated_metrics(target=_FBeta) + def __init__( + self, + num_classes: int, + beta: float = 1.0, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.FBeta`. + + .. deprecated:: + Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0. + """ + + +class F1(_F1): + + @deprecated_metrics(target=_F1) + def __init__( + self, + num_classes: int, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.F1`. + + .. deprecated:: + Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py new file mode 100644 index 00000000000000..d66b0c2d9cfa84 --- /dev/null +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import HammingDistance as _HammingDistance + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class HammingDistance(_HammingDistance): + + @deprecated_metrics(target=_HammingDistance) + def __init__( + self, + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.HammingDistance`. + + .. deprecated:: + Use :class:`~torchmetrics.HammingDistance`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/iou.py b/pytorch_lightning/metrics/classification/iou.py new file mode 100644 index 00000000000000..f1d9d0945511a8 --- /dev/null +++ b/pytorch_lightning/metrics/classification/iou.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchmetrics import IoU as _IoU + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class IoU(_IoU): + + @deprecated_metrics(target=_IoU) + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + reduction: str = 'elementwise_mean', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.IoU`. + + .. deprecated:: + Use :class:`~torchmetrics.IoU`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py new file mode 100644 index 00000000000000..7b95d21dae97c1 --- /dev/null +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -0,0 +1,71 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import Precision as _Precision +from torchmetrics import Recall as _Recall + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class Precision(_Precision): + + @deprecated_metrics(target=_Precision) + def __init__( + self, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False, + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + is_multiclass: Optional[bool] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.Precision`. + + .. deprecated:: + Use :class:`~torchmetrics.Precision`. Will be removed in v1.5.0. + """ + + +class Recall(_Recall): + + @deprecated_metrics(target=_Recall) + def __init__( + self, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False, + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + is_multiclass: Optional[bool] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.Recall`. + + .. deprecated:: + Use :class:`~torchmetrics.Recall`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py new file mode 100644 index 00000000000000..285cb2fb78ccc1 --- /dev/null +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class PrecisionRecallCurve(_PrecisionRecallCurve): + + @deprecated_metrics(target=_PrecisionRecallCurve) + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.PrecisionRecallCurve`. + + .. deprecated:: + Use :class:`~torchmetrics.PrecisionRecallCurve`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py new file mode 100644 index 00000000000000..3f6cf50803c869 --- /dev/null +++ b/pytorch_lightning/metrics/classification/roc.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from torchmetrics import ROC as _ROC + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class ROC(_ROC): + + @deprecated_metrics(target=_ROC) + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.ROC`. + + .. deprecated:: + Use :class:`~torchmetrics.ROC`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py new file mode 100644 index 00000000000000..1eed815d4b4cd1 --- /dev/null +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import StatScores as _StatScores + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class StatScores(_StatScores): + + @deprecated_metrics(target=_StatScores) + def __init__( + self, + threshold: float = 0.5, + top_k: Optional[int] = None, + reduce: str = "micro", + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + mdmc_reduce: Optional[str] = None, + is_multiclass: Optional[bool] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.StatScores`. + + .. deprecated:: + Use :class:`~torchmetrics.StatScores`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py new file mode 100644 index 00000000000000..56bb1912e48e66 --- /dev/null +++ b/pytorch_lightning/metrics/compositional.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Union + +import torch +from torchmetrics import Metric +from torchmetrics.metric import CompositionalMetric as _CompositionalMetric + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class CompositionalMetric(_CompositionalMetric): + + @deprecated_metrics(target=_CompositionalMetric) + def __init__( + self, + operator: Callable, + metric_a: Union[Metric, int, float, torch.Tensor], + metric_b: Union[Metric, int, float, torch.Tensor, None], + ): + """ + .. deprecated:: + Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py new file mode 100644 index 00000000000000..3b31dad5d3411d --- /dev/null +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401 +from pytorch_lightning.metrics.functional.auc import auc # noqa: F401 +from pytorch_lightning.metrics.functional.auroc import auroc # noqa: F401 +from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401 +from pytorch_lightning.metrics.functional.classification import ( # noqa: F401 + dice_score, + get_num_classes, + multiclass_auroc, + stat_scores_multiple_classes, + to_categorical, +) +from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401 +from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401 +from pytorch_lightning.metrics.functional.f_beta import f1, fbeta # noqa: F401 +from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401 +from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401 +from pytorch_lightning.metrics.functional.iou import iou # noqa: F401 +from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401 +from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401 +from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401 +from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401 +from pytorch_lightning.metrics.functional.precision_recall import precision, precision_recall, recall # noqa: F401 +from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401 +from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401 +from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401 +from pytorch_lightning.metrics.functional.roc import roc # noqa: F401 +from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401 +from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401 +from pytorch_lightning.metrics.functional.stat_scores import stat_scores # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py new file mode 100644 index 00000000000000..69fa9d75590e0e --- /dev/null +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torchmetrics.functional import accuracy as _accuracy + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_accuracy) +def accuracy( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + top_k: Optional[int] = None, + subset_accuracy: bool = False, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py new file mode 100644 index 00000000000000..7cc6aa458d397a --- /dev/null +++ b/pytorch_lightning/metrics/functional/auc.py @@ -0,0 +1,25 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torchmetrics.functional import auc as _auc + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_auc) +def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.auc`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py new file mode 100644 index 00000000000000..c49aa1a8fdc48d --- /dev/null +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Sequence + +import torch +from torchmetrics.functional import auroc as _auroc + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_auroc) +def auroc( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = 'macro', + max_fpr: Optional[float] = None, + sample_weights: Optional[Sequence] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py new file mode 100644 index 00000000000000..017b34739a0f40 --- /dev/null +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Sequence, Union + +import torch +from torchmetrics.functional import average_precision as _average_precision + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_average_precision) +def average_precision( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[List[torch.Tensor], torch.Tensor]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.average_precision`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py new file mode 100644 index 00000000000000..be1fec196a3465 --- /dev/null +++ b/pytorch_lightning/metrics/functional/classification.py @@ -0,0 +1,352 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import wraps +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torchmetrics.utilities import class_reduce, reduce +from torchmetrics.utilities.data import get_num_classes, to_categorical + +from pytorch_lightning.metrics.functional.auc import auc as __auc +from pytorch_lightning.metrics.functional.auroc import auroc as __auroc +from pytorch_lightning.metrics.functional.iou import iou as __iou +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn + + +def stat_scores( + pred: torch.Tensor, + target: torch.Tensor, + class_index: int, + argmax_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. + """ + if pred.ndim == target.ndim + 1: + pred = to_categorical(pred, argmax_dim=argmax_dim) + + tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum() + fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum() + tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum() + fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum() + sup = (target == class_index).to(torch.long).sum() + + return tp, fp, tn, fn, sup + + +# todo: remove in 1.4 +def stat_scores_multiple_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + argmax_dim: int = 1, + reduction: str = 'none', +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `stat_scores_multiple_classes` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional import stat_scores`." + " It will be removed in v1.4.0" + ) + if pred.ndim == target.ndim + 1: + pred = to_categorical(pred, argmax_dim=argmax_dim) + + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) + + if pred.dtype != torch.bool: + pred = pred.clamp_max(max=num_classes) + if target.dtype != torch.bool: + target = target.clamp_max(max=num_classes) + + possible_reductions = ('none', 'sum', 'elementwise_mean') + if reduction not in possible_reductions: + raise ValueError("reduction type %s not supported" % reduction) + + if reduction == 'none': + pred = pred.view((-1, )).long() + target = target.view((-1, )).long() + + tps = torch.zeros((num_classes + 1, ), device=pred.device) + fps = torch.zeros((num_classes + 1, ), device=pred.device) + fns = torch.zeros((num_classes + 1, ), device=pred.device) + sups = torch.zeros((num_classes + 1, ), device=pred.device) + + match_true = (pred == target).float() + match_false = 1 - match_true + + tps.scatter_add_(0, pred, match_true) + fps.scatter_add_(0, pred, match_false) + fns.scatter_add_(0, target, match_false) + tns = pred.size(0) - (tps + fps + fns) + sups.scatter_add_(0, target, torch.ones_like(match_true)) + + tps = tps[:num_classes] + fps = fps[:num_classes] + tns = tns[:num_classes] + fns = fns[:num_classes] + sups = sups[:num_classes] + + elif reduction == 'sum' or reduction == 'elementwise_mean': + count_match_true = (pred == target).sum().float() + oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim) + + tps = count_match_true - oob_tp + fps = pred.nelement() - count_match_true - oob_fp + fns = pred.nelement() - count_match_true - oob_fn + tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn) + sups = pred.nelement() - oob_sup.float() + + if reduction == 'elementwise_mean': + tps /= num_classes + fps /= num_classes + fns /= num_classes + tns /= num_classes + sups /= num_classes + + return tps.float(), fps.float(), tns.float(), fns.float(), sups.float() + + +def _confmat_normalize(cm): + """ Normalization function for confusion matrix """ + cm = cm / cm.sum(-1, keepdim=True) + nan_elements = cm[torch.isnan(cm)].nelement() + if nan_elements != 0: + cm[torch.isnan(cm)] = 0 + rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') + return cm + + +# todo: remove in 1.4 +def precision_recall( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', + return_support: bool = False, + return_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `precision_recall` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrcs.functional import precision_recall`." + " It will be removed in v1.4.0" + ) + + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) + + precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) + recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction) + if return_state: + return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups} + if return_support: + return precision, recall, sups + return precision, recall + + +# todo: remove in 1.4 +def precision( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.precision`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `precision` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional import precision`." + " It will be removed in v1.4.0" + ) + + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] + + +# todo: remove in 1.4 +def recall( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.recall`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `recall` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional import recall`." + " It will be removed in v1.4.0" + ) + + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] + + +# todo: remove in 1.4 +def auc( + x: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.auc`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `auc` was deprecated in v1.2.0 in favor of" + " `pytorch_lightning.metrics.functional.auc import auc`." + " It will be removed in v1.4.0" + ) + return __auc(x, y) + + +# todo: remove in 1.4 +def _auc_decorator() -> Callable: + + def wrapper(func_to_decorate: Callable) -> Callable: + + @wraps(func_to_decorate) + def new_func(*args, **kwargs) -> torch.Tensor: + x, y = func_to_decorate(*args, **kwargs)[:2] + + return auc(x, y) + + return new_func + + return wrapper + + +# todo: remove in 1.4 +def _multiclass_auc_decorator() -> Callable: + + def wrapper(func_to_decorate: Callable) -> Callable: + + @wraps(func_to_decorate) + def new_func(*args, **kwargs) -> torch.Tensor: + results = [] + for class_result in func_to_decorate(*args, **kwargs): + x, y = class_result[:2] + results.append(auc(x, y)) + + return torch.stack(results) + + return new_func + + return wrapper + + +# todo: remove in 1.4 +def auroc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., + max_fpr: float = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `auroc` was deprecated in v1.2.0 in favor of `pytorch_lightning.metrics.functional.auroc import auroc`." + " It will be removed in v1.4.0" + ) + return __auroc( + preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, max_fpr=max_fpr, num_classes=1 + ) + + +# todo: remove in 1.4 +def multiclass_auroc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `multiclass_auroc` was deprecated in v1.2.0 in favor of" + " `pytorch_lightning.metrics.functional.auroc import auroc`." + " It will be removed in v1.4.0" + ) + + return __auroc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) + + +def dice_score( + pred: torch.Tensor, + target: torch.Tensor, + bg: bool = False, + nan_score: float = 0.0, + no_fg_score: float = 0.0, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.dice_score`. Will be removed in v1.4.0. + """ + num_classes = pred.shape[1] + bg = (1 - int(bool(bg))) + scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32) + for i in range(bg, num_classes): + if not (target == i).any(): + # no foreground class + scores[i - bg] += no_fg_score + continue + + tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i) + denom = (2 * tp + fp + fn).to(torch.float) + # nan result + score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score + + scores[i - bg] += score_cls + return reduce(scores, reduction=reduction) + + +# todo: remove in 1.4 +def iou( + pred: torch.Tensor, + target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.iou`. Will be removed in v1.4.0. + """ + rank_zero_deprecation( + "This `iou` was deprecated in v1.2.0 in favor of `from pytorch_lightning.metrics.functional.iou import iou`." + " It will be removed in v1.4.0" + ) + return __iou( + pred=pred, + target=target, + ignore_index=ignore_index, + absent_score=absent_score, + threshold=0.5, + num_classes=num_classes, + reduction=reduction + ) diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py new file mode 100644 index 00000000000000..038bd8b49b7306 --- /dev/null +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torchmetrics.functional import confusion_matrix as _confusion_matrix + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_confusion_matrix) +def confusion_matrix( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + normalize: Optional[str] = None, + threshold: float = 0.5 +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.confusion_matrix`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py new file mode 100644 index 00000000000000..233a0851b8d56d --- /dev/null +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Union + +import torch +from torchmetrics.functional import explained_variance as _explained_variance + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_explained_variance) +def explained_variance( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py new file mode 100644 index 00000000000000..f994c9a8a3271b --- /dev/null +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -0,0 +1,49 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torchmetrics.functional import f1 as _f1 +from torchmetrics.functional import fbeta as _fbeta + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_fbeta) +def fbeta( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + beta: float = 1.0, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_f1) +def f1( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.f1`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py new file mode 100644 index 00000000000000..6a390e776f1116 --- /dev/null +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -0,0 +1,25 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torchmetrics.functional import hamming_distance as _hamming_distance + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_hamming_distance) +def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.hamming_distance`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/image_gradients.py b/pytorch_lightning/metrics/functional/image_gradients.py new file mode 100644 index 00000000000000..e2151c5fc1d938 --- /dev/null +++ b/pytorch_lightning/metrics/functional/image_gradients.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch +from torchmetrics.functional import image_gradients as _image_gradients + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_image_gradients) +def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.image_gradients`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/iou.py b/pytorch_lightning/metrics/functional/iou.py new file mode 100644 index 00000000000000..76f59854ad4bf5 --- /dev/null +++ b/pytorch_lightning/metrics/functional/iou.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torchmetrics.functional import iou as _iou + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_iou) +def iou( + pred: torch.Tensor, + target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.iou`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py new file mode 100644 index 00000000000000..219284d79d6234 --- /dev/null +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics.functional import mean_absolute_error as _mean_absolute_error + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_mean_absolute_error) +def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.mean_absolute_error`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py new file mode 100644 index 00000000000000..329fe040ebc7db --- /dev/null +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_mean_relative_error) +def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.regression.mean_relative_error`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py new file mode 100644 index 00000000000000..5bbc0bb1c6a83d --- /dev/null +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics.functional import mean_squared_error as _mean_squared_error + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_mean_squared_error) +def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_error`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py new file mode 100644 index 00000000000000..29786529381d52 --- /dev/null +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_mean_squared_log_error) +def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_log_error`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py new file mode 100644 index 00000000000000..c59d7cf2b8976f --- /dev/null +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# referenced from +# Library Name: torchtext +# Authors: torchtext authors and @sluks +# Date: 2020-07-18 +# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +from typing import Sequence + +import torch +from torchmetrics.functional import bleu_score as _bleu_score + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_bleu_score) +def bleu_score( + translate_corpus: Sequence[str], + reference_corpus: Sequence[str], + n_gram: int = 4, + smooth: bool = False +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.bleu_score`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py new file mode 100644 index 00000000000000..7b6c8641b58295 --- /dev/null +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torchmetrics.functional import precision as _precision +from torchmetrics.functional import precision_recall as _precision_recall +from torchmetrics.functional import recall as _recall + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_precision) +def precision( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.precision`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_recall) +def recall( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_precision_recall) +def precision_recall( + preds: torch.Tensor, + target: torch.Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py new file mode 100644 index 00000000000000..dc9863cbb47c49 --- /dev/null +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +from torchmetrics.functional import precision_recall_curve as _precision_recall_curve + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_precision_recall_curve) +def precision_recall_curve( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], + List[torch.Tensor]], ]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py new file mode 100644 index 00000000000000..51be9d47b91f95 --- /dev/null +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +from torchmetrics.functional import psnr as _psnr + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_psnr) +def psnr( + preds: torch.Tensor, + target: torch.Tensor, + data_range: Optional[float] = None, + base: float = 10.0, + reduction: str = 'elementwise_mean', + dim: Optional[Union[int, Tuple[int, ...]]] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.psnr`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py new file mode 100644 index 00000000000000..fe4b5419893588 --- /dev/null +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics.functional import r2score as _r2score + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_r2score) +def r2score( + preds: torch.Tensor, + target: torch.Tensor, + adjusted: int = 0, + multioutput: str = "uniform_average", +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.r2score`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py new file mode 100644 index 00000000000000..928a0b40fca549 --- /dev/null +++ b/pytorch_lightning/metrics/functional/roc.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Sequence, Tuple, Union + +from torch import Tensor +from torchmetrics.functional import roc as _roc + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_roc) +def roc( + preds: Tensor, + target: Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.roc`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py new file mode 100644 index 00000000000000..65dec211e938a7 --- /dev/null +++ b/pytorch_lightning/metrics/functional/self_supervised.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torchmetrics.functional import embedding_similarity as _embedding_similarity + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_embedding_similarity) +def embedding_similarity( + batch: torch.Tensor, + similarity: str = 'cosine', + reduction: str = 'none', + zero_diagonal: bool = True +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.embedding_similarity`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py new file mode 100644 index 00000000000000..31cff7fcfb9b4b --- /dev/null +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Sequence + +import torch +from torchmetrics.functional import ssim as _ssim + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_ssim) +def ssim( + preds: torch.Tensor, + target: torch.Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.ssim`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py new file mode 100644 index 00000000000000..30c03da237fe60 --- /dev/null +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torchmetrics.functional import stat_scores as _stat_scores + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +@deprecated_metrics(target=_stat_scores) +def stat_scores( + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + num_classes: Optional[int] = None, + top_k: Optional[int] = None, + threshold: float = 0.5, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py new file mode 100644 index 00000000000000..ee0fcdb8a92e13 --- /dev/null +++ b/pytorch_lightning/metrics/metric.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from torchmetrics import Metric as _Metric +from torchmetrics.collections import MetricCollection as _MetricCollection + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class Metric(_Metric): + + @deprecated_metrics(target=_Metric) + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + r""" + .. deprecated:: + Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. + """ + + +class MetricCollection(_MetricCollection): + + @deprecated_metrics(target=_MetricCollection) + def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): + """ + .. deprecated:: + Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py new file mode 100644 index 00000000000000..4696dbe57dafd5 --- /dev/null +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401 +from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401 +from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError # noqa: F401 +from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401 +from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401 +from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401 +from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401 diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py new file mode 100644 index 00000000000000..0f94ae2fb37542 --- /dev/null +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import ExplainedVariance as _ExplainedVariance + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class ExplainedVariance(_ExplainedVariance): + + @deprecated_metrics(target=_ExplainedVariance) + def __init__( + self, + multioutput: str = 'uniform_average', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.ExplainedVariance`. + + .. deprecated:: + Use :class:`~torchmetrics.ExplainedVariance`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py new file mode 100644 index 00000000000000..57c7db420445bb --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class MeanAbsoluteError(_MeanAbsoluteError): + + @deprecated_metrics(target=_MeanAbsoluteError) + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.MeanAbsoluteError`. + + .. deprecated:: + Use :class:`~torchmetrics.MeanAbsoluteError`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py new file mode 100644 index 00000000000000..c8e9c151c99d9b --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import MeanSquaredError as _MeanSquaredError + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class MeanSquaredError(_MeanSquaredError): + + @deprecated_metrics(target=_MeanSquaredError) + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.MeanSquaredError`. + + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredError`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py new file mode 100644 index 00000000000000..c8ee8a70691156 --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class MeanSquaredLogError(_MeanSquaredLogError): + + @deprecated_metrics(target=_MeanSquaredLogError) + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.MeanSquaredLogError`. + + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredLogError`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py new file mode 100644 index 00000000000000..f972e9a8e2b5e5 --- /dev/null +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Tuple, Union + +from torchmetrics import PSNR as _PSNR + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class PSNR(_PSNR): + + @deprecated_metrics(target=_PSNR) + def __init__( + self, + data_range: Optional[float] = None, + base: float = 10.0, + reduction: str = 'elementwise_mean', + dim: Optional[Union[int, Tuple[int, ...]]] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.PSNR`. + + .. deprecated:: + Use :class:`~torchmetrics.PSNR`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py new file mode 100644 index 00000000000000..ad5f7f3bd8d070 --- /dev/null +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +from torchmetrics import R2Score as _R2Score + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class R2Score(_R2Score): + + @deprecated_metrics(target=_R2Score) + def __init__( + self, + num_outputs: int = 1, + adjusted: int = 0, + multioutput: str = "uniform_average", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + """ + This implementation refers to :class:`~torchmetrics.R2Score`. + + .. deprecated:: + Use :class:`~torchmetrics.R2Score`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py new file mode 100644 index 00000000000000..cf5571f3e68f47 --- /dev/null +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence + +from torchmetrics import SSIM as _SSIM + +from pytorch_lightning.metrics.utils import deprecated_metrics + + +class SSIM(_SSIM): + + @deprecated_metrics(target=_SSIM) + def __init__( + self, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + """ + This implementation refers to :class:`~torchmetrics.SSIM`. + + .. deprecated:: + Use :class:`~torchmetrics.SSIM`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py new file mode 100644 index 00000000000000..4adc88a37ba213 --- /dev/null +++ b/pytorch_lightning/metrics/utils.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Optional + +import torch +from deprecate import deprecated +from torchmetrics.utilities.data import dim_zero_cat as _dim_zero_cat +from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean +from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum +from torchmetrics.utilities.data import get_num_classes as _get_num_classes +from torchmetrics.utilities.data import select_topk as _select_topk +from torchmetrics.utilities.data import to_categorical as _to_categorical +from torchmetrics.utilities.data import to_onehot as _to_onehot +from torchmetrics.utilities.distributed import class_reduce as _class_reduce +from torchmetrics.utilities.distributed import reduce as _reduce + +from pytorch_lightning.utilities import rank_zero_deprecation + +deprecated_metrics = partial(deprecated, deprecated_in="1.3.0", remove_in="1.5.0", stream=rank_zero_deprecation) + + +@deprecated_metrics(target=_dim_zero_cat) +def dim_zero_cat(x): + pass + + +@deprecated_metrics(target=_dim_zero_sum) +def dim_zero_sum(x): + pass + + +@deprecated_metrics(target=_dim_zero_mean) +def dim_zero_mean(x): + pass + + +@deprecated_metrics(target=_to_onehot) +def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_onehot`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_select_topk) +def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.select_topk`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_to_categorical) +def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.to_categorical`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_get_num_classes) +def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.data.get_num_classes`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_reduce) +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.reduce`. Will be removed in v1.5.0. + """ + + +@deprecated_metrics(target=_class_reduce) +def class_reduce( + num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" +) -> torch.Tensor: + """ + .. deprecated:: + Use :func:`torchmetrics.utilities.class_reduce`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/overrides/__init__.py b/pytorch_lightning/overrides/__init__.py index e69de29bb2d1d6..ca97a63649389c 100644 --- a/pytorch_lightning/overrides/__init__.py +++ b/pytorch_lightning/overrides/__init__.py @@ -0,0 +1,2 @@ +from pytorch_lightning.overrides.data_parallel import LightningParallelModule # noqa: F401 +from pytorch_lightning.overrides.distributed import LightningDistributedModule # noqa: F401 diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py new file mode 100644 index 00000000000000..0c1ac7b359fd0a --- /dev/null +++ b/pytorch_lightning/overrides/base.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch.nn import DataParallel +from torch.nn.parallel import DistributedDataParallel + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin + + +class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): + + def __init__(self, pl_module: LightningModule): + """ + Wraps the user's LightningModule and redirects the forward call to the appropriate + method, either ``training_step``, ``validation_step`` or ``test_step``. + If the LightningModule is in none of the states `training`, `testing` or `validation`, + the inputs will be redirected to the + :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method. + Inheriting classes may also modify the inputs or outputs of forward. + + Args: + pl_module: the model to wrap + """ + super().__init__() + self.module = pl_module + + def forward(self, *inputs, **kwargs): + trainer = self.module.trainer + + if trainer and trainer.training: + output = self.module.training_step(*inputs, **kwargs) + + # In manual_optimization, we need to prevent DDP reducer as + # it is done manually in ``LightningModule.manual_backward`` + # `require_backward_grad_sync` will be reset in the + # ddp_plugin ``post_training_step`` hook + if not self.module.automatic_optimization: + trainer.model.require_backward_grad_sync = False + elif trainer and trainer.testing: + output = self.module.test_step(*inputs, **kwargs) + elif trainer and (trainer.sanity_checking or trainer.validating): + output = self.module.validation_step(*inputs, **kwargs) + elif trainer and trainer.predicting: + output = self.module.predict_step(*inputs, **kwargs) + else: + output = self.module(*inputs, **kwargs) + + return output + + def on_post_move_to_device(self): + pass + + +def unwrap_lightning_module(wrapped_model) -> LightningModule: + model = wrapped_model + if isinstance(model, (DistributedDataParallel, DataParallel)): + model = model.module + if isinstance(model, _LightningModuleWrapperBase): + model = model.module + return model diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index b2f7816ec0ac20..b027502f99e8ac 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -1,226 +1,100 @@ -import itertools -import threading -from itertools import chain +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +import warnings +from typing import Any import torch -from torch.cuda._utils import _get_device_index from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel - -def _find_tensors(obj): # pragma: no-cover - r""" - Recursively find all tensors contained in the specified object. - """ - if isinstance(obj, torch.Tensor): - return [obj] - if isinstance(obj, (list, tuple)): - return itertools.chain(*map(_find_tensors, obj)) - if isinstance(obj, dict): - return itertools.chain(*map(_find_tensors, obj.values())) - return [] - - -def get_a_var(obj): # pragma: no-cover - if isinstance(obj, torch.Tensor): - return obj - - if isinstance(obj, (list, tuple)): - for result in map(get_a_var, obj): - if isinstance(result, torch.Tensor): - return result - if isinstance(obj, dict): - for result in map(get_a_var, obj.items()): - if isinstance(result, torch.Tensor): - return result - return None +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.distributed import LightningDistributedModule +from pytorch_lightning.utilities.apply_func import apply_to_collection class LightningDataParallel(DataParallel): - """ - Override the forward call in lightning so it goes to training and validation step respectively - """ - def forward(self, *inputs, **kwargs): - if not self.device_ids: - return self.module(*inputs, **kwargs) + def __init__(self, module: LightningModule, *args, **kwargs): + warnings.warn( + "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4." + " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", DeprecationWarning + ) + super().__init__(LightningParallelModule(module), *args, **kwargs) - for t in chain(self.module.parameters(), self.module.buffers()): - if t.device != self.src_device_obj: - raise RuntimeError("module must have its parameters and buffers " - "on device {} (device_ids[0]) but found one of " - "them on device: {}".format(self.src_device_obj, t.device)) - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - if len(self.device_ids) == 1: - # lightning - if self.module.training: - return self.module.training_step(*inputs[0], **kwargs[0]) - if self.module.testing: - return self.module.test_step(*inputs[0], **kwargs[0]) +class LightningDistributedDataParallel(DistributedDataParallel): - return self.module.validation_step(*inputs[0], **kwargs[0]) + def __init__(self, module: LightningModule, *args, **kwargs): + warnings.warn( + "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4." + " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.", + DeprecationWarning + ) + super().__init__(LightningDistributedModule(module), *args, **kwargs) - replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) - outputs = self.parallel_apply(replicas, inputs, kwargs) - return self.gather(outputs, self.output_device) - def parallel_apply(self, replicas, inputs, kwargs): - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) +class LightningParallelModule(_LightningModuleWrapperBase): + """ + Wraps the user's LightningModule and redirects the forward call to the appropriate + method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. + This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as + shown in the example. It also takes care of converting Python scalars to Tensors and + un-squeezes 0-dimensional Tensors as it is required by :class:`~torch.nn.parallel.DataParallel`. + Example: + + dp_model = torch.nn.DataParallel( + module=LightningParallelModule(lightning_module), + device_ids=[3, 4], + ... + ) + + Args: + pl_module: the model to wrap -class LightningDistributedDataParallel(DistributedDataParallel): - """ - Override the forward call in lightning so it goes to training and validation step respectively """ - def parallel_apply(self, replicas, inputs, kwargs): - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) - - def forward(self, *inputs, **kwargs): # pragma: no-cover - self._sync_params() - if self.device_ids: - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - if len(self.device_ids) == 1: - # -------------- - # LIGHTNING MOD - # -------------- - # normal - # output = self.module(*inputs[0], **kwargs[0]) - # lightning - if self.module.training: - output = self.module.training_step(*inputs[0], **kwargs[0]) - elif self.module.testing: - output = self.module.test_step(*inputs[0], **kwargs[0]) - else: - output = self.module.validation_step(*inputs[0], **kwargs[0]) - else: - outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) - output = self.gather(outputs, self.output_device) - else: - # normal - # output = self.module(*inputs, **kwargs) - # lightning (ddp_cpu) - if self.module.training: - output = self.module.training_step(*inputs, **kwargs) - elif self.module.testing: - output = self.module.test_step(*inputs, **kwargs) - else: - output = self.module.validation_step(*inputs, **kwargs) - - if torch.is_grad_enabled(): - # We'll return the output object verbatim since it is a freeform - # object. We need to find any tensors in this object, though, - # because we need to figure out which parameters were used during - # this forward pass, to ensure we short circuit reduction for any - # unused parameters. Only if `find_unused_parameters` is set. - if self.find_unused_parameters: - self.reducer.prepare_for_backward(list(_find_tensors(output))) - else: - self.reducer.prepare_for_backward([]) + def __init__(self, pl_module: LightningModule): + super().__init__(pl_module) + + def forward(self, *inputs, **kwargs): + output = super().forward(*inputs, **kwargs) + + def output_transform(data: Any): + data = python_scalar_to_tensor(data, self.module.device) + data = unsqueeze_scalar_tensor(data) + return data + + output = apply_to_collection( + output, + dtype=(numbers.Number, torch.Tensor), + function=output_transform, + ) return output -def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no-cover - r"""Applies each `module` in :attr:`modules` in parallel on arguments - contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) - on each of :attr:`devices`. +def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any: + """ Converts a Python scalar number to a torch tensor and places it on the given device. """ + if isinstance(data, numbers.Number): + data = torch.tensor([data], device=device) + return data - Args: - modules (Module): modules to be parallelized - inputs (tensor): inputs to the modules - devices (list of int or torch.device): CUDA devices - - :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and - :attr:`devices` (if given) should all have same length. Moreover, each - element of :attr:`inputs` can either be a single object as the only argument - to a module, or a collection of positional arguments. - """ - assert len(modules) == len(inputs) - if kwargs_tup is not None: - assert len(modules) == len(kwargs_tup) - else: - kwargs_tup = ({},) * len(modules) - if devices is not None: - assert len(modules) == len(devices) - else: - devices = [None] * len(modules) - devices = list(map(lambda x: _get_device_index(x, True), devices)) - lock = threading.Lock() - results = {} - grad_enabled = torch.is_grad_enabled() - - def _worker(i, module, input, kwargs, device=None): - torch.set_grad_enabled(grad_enabled) - if device is None: - device = get_a_var(input).get_device() - try: - with torch.cuda.device(device): - # this also avoids accidental slicing of `input` if it is a Tensor - if not isinstance(input, (list, tuple)): - input = (input,) - - # --------------- - # CHANGE - if module.training: - output = module.training_step(*input, **kwargs) - - elif module.testing: - output = module.test_step(*input, **kwargs) - - else: - output = module.validation_step(*input, **kwargs) - - if module.use_dp or module.use_ddp2: - auto_squeeze_dim_zeros(output) - # --------------- - - with lock: - results[i] = output - except Exception as e: - with lock: - results[i] = e - - # TODO: fix hack (maybe not a hack) - # make sure each module knows what training state it's in... - # fixes weird bug where copies are out of sync - root_m = modules[0] - for m in modules[1:]: - m.training = root_m.training - m.testing = root_m.testing - - if len(modules) > 1: - threads = [threading.Thread(target=_worker, - args=(i, module, input, kwargs, device)) - for i, (module, input, kwargs, device) in - enumerate(zip(modules, inputs, kwargs_tup, devices))] - - for thread in threads: - thread.start() - for thread in threads: - thread.join() - else: - _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) - - outputs = [] - for i in range(len(inputs)): - output = results[i] - if isinstance(output, Exception): - raise output - outputs.append(output) - return outputs - - -def auto_squeeze_dim_zeros(output): - """ - In DP or DDP2 we need to unsqueeze dim 0 - :param output: - :return: - """ - for k, v in output.items(): - if not isinstance(v, torch.Tensor): - continue - is_scalar = v.dim() == 0 - if is_scalar: - output[k] = output[k].unsqueeze(0) +def unsqueeze_scalar_tensor(data: Any) -> Any: + """ Un-squeezes a 0-dim tensor. """ + if isinstance(data, torch.Tensor) and data.dim() == 0: + data = data.unsqueeze(0) + return data diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py new file mode 100644 index 00000000000000..c934e422a4308b --- /dev/null +++ b/pytorch_lightning/overrides/distributed.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from typing import Any + +import torch +from torch.nn.parallel import DistributedDataParallel + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase + + +class LightningDistributedModule(_LightningModuleWrapperBase): + + def __init__(self, pl_module: LightningModule): + """ + Wraps the user's LightningModule and redirects the forward call to the appropriate + method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. + This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as + shown in the example. + + Example: + + ddp_model = torch.nn.parallel.DistributedDataParallel( + module=LightningDistributedModule(lightning_module), + device_ids=[local_rank], + ... + ) + + Args: + pl_module: the model to wrap + + """ + super().__init__(pl_module) + + +def _find_tensors(obj): # pragma: no-cover + r""" + Recursively find all tensors contained in the specified object. + """ + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain(*map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain(*map(_find_tensors, obj.values())) + return [] + + +# In manual_optimization, we need to call reducer prepare_for_backward. +# Note: Keep track of Pytorch DDP and update if there is a change +# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 +def prepare_for_backward(model: DistributedDataParallel, output: Any): + if torch.is_grad_enabled() and model.require_backward_grad_sync: + model.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if model.find_unused_parameters: + model.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + model.reducer.prepare_for_backward([]) + else: + model.require_forward_param_sync = False diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py new file mode 100644 index 00000000000000..f7c3b8d5fd5755 --- /dev/null +++ b/pytorch_lightning/overrides/fairscale.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE + +LightningShardedDataParallel = None +if _FAIRSCALE_AVAILABLE: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + + class LightningShardedDataParallel(_LightningModuleWrapperBase): + # Just do this for later docstrings + pass + + def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: + model = wrapped_model + if isinstance(model, ShardedDataParallel): + model = model.module + + return unwrap_lightning_module(model) diff --git a/pytorch_lightning/overrides/override_data_parallel.py b/pytorch_lightning/overrides/override_data_parallel.py deleted file mode 100644 index bf08b1a528953b..00000000000000 --- a/pytorch_lightning/overrides/override_data_parallel.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -.. warning:: `override_data_parallel` module has been renamed to `data_parallel` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`override_data_parallel` module has been renamed to `data_parallel` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.overrides.data_parallel import ( # noqa: E402 - get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel) diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py new file mode 100644 index 00000000000000..67b64c046dc188 --- /dev/null +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -0,0 +1,94 @@ +import logging +import pickle + +import torch + +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 + +log = logging.getLogger(__name__) + +if torch.distributed.is_available(): + from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember + +# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` +# and enable broadcasting for PyTorch 1.6 and lower. + + +# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 +def _rank_not_in_group(group): + """ + Helper that checks if the current process's rank is not in a given group. + """ + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 +def _object_to_tensor(obj): + buffer = pickle.dumps(obj) + byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] + byte_tensor = torch.ByteTensor(byte_storage) + local_size = torch.LongTensor([byte_tensor.numel()]) + return byte_tensor, local_size + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py +def _tensor_to_object(tensor, tensor_size): + buf = tensor.numpy().tobytes()[:tensor_size] + out = pickle.loads(buf) + return out + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 +def _broadcast_object_list(object_list, src=0, group=None): + if _rank_not_in_group(group): + return + + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.LongTensor(len(object_list)) + + group_backend = get_backend(group) + is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + broadcast(object_tensor, src=src, group=group) + + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) + + +if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): + from torch.distributed.distributed_c10d import broadcast_object_list +else: + broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py new file mode 100644 index 00000000000000..a67235baa47679 --- /dev/null +++ b/pytorch_lightning/plugins/__init__.py @@ -0,0 +1,49 @@ +from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 +from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 + +__all__ = [ + "ApexMixedPrecisionPlugin", + "DataParallelPlugin", + "DDP2Plugin", + "DDPPlugin", + "DDPSpawnPlugin", + "DeepSpeedPlugin", + "DeepSpeedPrecisionPlugin", + "DoublePrecisionPlugin", + "HorovodPlugin", + "NativeMixedPrecisionPlugin", + "PrecisionPlugin", + "ShardedNativeMixedPrecisionPlugin", + "SingleDevicePlugin", + "SingleTPUPlugin", + "TPUHalfPrecisionPlugin", + "TPUSpawnPlugin", + 'RPCPlugin', + 'RPCSequentialPlugin', + 'TrainingTypePlugin', + 'ParallelPlugin', + 'Plugin', + 'DDPShardedPlugin', + 'DDPSpawnShardedPlugin', +] diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py new file mode 100644 index 00000000000000..f89bdf7a8aa723 --- /dev/null +++ b/pytorch_lightning/plugins/base_plugin.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from abc import ABC +from typing import Generator + + +class Plugin(ABC): + """Basic Plugin class to derive precision and training type plugins from.""" + + def pre_dispatch(self) -> None: + """Hook to do something before the training/evaluation/prediction starts.""" + + def post_dispatch(self) -> None: + """Hook to do something after the training/evaluation/prediction finishes.""" + + @contextlib.contextmanager + def train_step_context(self) -> Generator: + """A contextmanager for the trainstep""" + yield + + @contextlib.contextmanager + def val_step_context(self) -> Generator: + """A contextmanager for the validation step""" + yield + + @contextlib.contextmanager + def test_step_context(self) -> Generator: + """A contextmanager for the teststep""" + yield + + @contextlib.contextmanager + def predict_context(self) -> Generator: + """A contextmanager for the predict step""" + yield diff --git a/pytorch_lightning/plugins/environments/__init__.py b/pytorch_lightning/plugins/environments/__init__.py new file mode 100644 index 00000000000000..70c1f8da90f13d --- /dev/null +++ b/pytorch_lightning/plugins/environments/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401 diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py new file mode 100644 index 00000000000000..f3fb2fbeabaa2b --- /dev/null +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Optional + + +class ClusterEnvironment(ABC): + """ Specification of a cluster environment. """ + + @abstractmethod + def creates_children(self) -> bool: + """ Whether the environment creates the subprocesses or not. """ + + @abstractmethod + def master_address(self) -> str: + """ The master address through which all processes connect and communicate. """ + + @abstractmethod + def master_port(self) -> int: + """ An open and configured port in the master node through which all processes communicate. """ + + @abstractmethod + def world_size(self) -> Optional[int]: + """ The number of processes across all devices and nodes. """ + + @abstractmethod + def local_rank(self) -> int: + """ The rank (index) of the currently running process inside of the current node. """ + + @abstractmethod + def node_rank(self) -> int: + """ The rank (index) of the node on which the current process runs. """ diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py new file mode 100644 index 00000000000000..6b71122b065bf2 --- /dev/null +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -0,0 +1,71 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import socket +from typing import Optional + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment + + +class LightningEnvironment(ClusterEnvironment): + """ + The default environment used by Lightning for a single node or free cluster (not managed). + + The master process must be launched by the user and Lightning will spawn new + worker processes for distributed training, either in a single node or across multiple nodes. + + If the master address and port are not provided, the default environment will choose them + automatically. It is recommended to use this default environment for single-node distributed + training as it provides the most convenient way to launch the training script. + """ + + def __init__(self): + super().__init__() + self._master_port = None + + def creates_children(self) -> bool: + return False + + def master_address(self) -> str: + return os.environ.get("MASTER_ADDR", "127.0.0.1") + + def master_port(self) -> int: + if self._master_port is None: + self._master_port = os.environ.get("MASTER_PORT", find_free_network_port()) + return int(self._master_port) + + def world_size(self) -> Optional[int]: + return None + + def local_rank(self) -> int: + return int(os.environ.get("LOCAL_RANK", 0)) + + def node_rank(self) -> int: + group_rank = os.environ.get("GROUP_RANK", 0) + return int(os.environ.get("NODE_RANK", group_rank)) + + +def find_free_network_port() -> int: + """ + Finds a free port on localhost. + It is useful in single-node training when we don't want to connect to a real master node but + have to set the `MASTER_PORT` environment variable. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + s.close() + return port diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py new file mode 100644 index 00000000000000..3cba5d101a1598 --- /dev/null +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import re + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment + +log = logging.getLogger(__name__) + + +class SLURMEnvironment(ClusterEnvironment): + + def __init__(self): + super().__init__() + + def creates_children(self) -> bool: + return True + + def master_address(self) -> str: + # figure out the root node addr + slurm_nodelist = os.environ.get("SLURM_NODELIST") + if slurm_nodelist: + root_node = slurm_nodelist.split(" ")[0].split(",")[0] + else: + root_node = "127.0.0.1" + + root_node = self.resolve_root_node_address(root_node) + os.environ["MASTER_ADDR"] = root_node + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + return root_node + + def master_port(self) -> int: + # ----------------------- + # SLURM JOB = PORT number + # ----------------------- + # this way every process knows what port to use + default_port = os.environ.get("SLURM_JOB_ID") + if default_port: + # use the last 4 numbers in the job id as the id + default_port = default_port[-4:] + # all ports should be in the 10k+ range + default_port = int(default_port) + 15000 + else: + default_port = 12910 + + # ----------------------- + # PORT NUMBER = MASTER_PORT + # ----------------------- + # in case the user passed it in + if "MASTER_PORT" in os.environ: + default_port = os.environ["MASTER_PORT"] + else: + os.environ["MASTER_PORT"] = str(default_port) + + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + return int(default_port) + + def world_size(self): + return None + + def local_rank(self) -> int: + return int(os.environ['SLURM_LOCALID']) + + def node_rank(self) -> int: + return int(os.environ['SLURM_NODEID']) + + def resolve_root_node_address(self, root_node: str) -> str: + if '[' in root_node: + name, numbers = root_node.split('[', maxsplit=1) + number = numbers.split(',', maxsplit=1)[0] + if '-' in number: + number = number.split('-')[0] + + number = re.sub('[^0-9]', '', number) + root_node = name + number + + return root_node diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py new file mode 100644 index 00000000000000..c3a59fbfd75bc1 --- /dev/null +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Optional + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities import rank_zero_warn + +log = logging.getLogger(__name__) + + +class TorchElasticEnvironment(ClusterEnvironment): + + def __init__(self): + super().__init__() + + def creates_children(self) -> bool: + return True + + def master_address(self) -> str: + if "MASTER_ADDR" not in os.environ: + rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost") + os.environ["MASTER_ADDR"] = "127.0.0.1" + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + master_address = os.environ.get('MASTER_ADDR') + return master_address + + def master_port(self) -> int: + if "MASTER_PORT" not in os.environ: + rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910") + os.environ["MASTER_PORT"] = "12910" + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + port = int(os.environ.get('MASTER_PORT')) + return port + + def world_size(self) -> Optional[int]: + world_size = os.environ.get('WORLD_SIZE') + return int(world_size) if world_size is not None else world_size + + def local_rank(self) -> int: + return int(os.environ['LOCAL_RANK']) + + def node_rank(self) -> int: + return int(os.environ.get('GROUP_RANK', 0)) diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py new file mode 100644 index 00000000000000..d32aac829a13d8 --- /dev/null +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -0,0 +1,8 @@ +from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py new file mode 100644 index 00000000000000..b600eca5e6bc29 --- /dev/null +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -0,0 +1,172 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING + +import torch + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn + +if _APEX_AVAILABLE: + from apex import amp + +if TYPE_CHECKING: + from torch.optim import Optimizer + + +class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): + """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" + + def __init__(self, amp_level: str = "O2") -> None: + self.backend = AMPType.APEX + self.amp_level = amp_level + + def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]: + return amp.master_params(optimizer) + + def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]: + """Connects the precision plugin to the training process, + configures apex and reinits the schedulers + """ + if model.device.type != "cuda": + return model, optimizers, lr_schedulers + model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level) + self.reinit_scheduler_properties(optimizers, lr_schedulers) + return model, optimizers, lr_schedulers + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: 'Optimizer', + opt_idx: int, + should_accumulate: bool, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """performs the actual backpropagation + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: the optimizer to perform the step lateron + opt_idx: the optimizer's index + should_accumulate: whether to accumulate gradients or not + + """ + closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer) + + # enter apex context + context = closure_loss + closure_loss = closure_loss.__enter__() + + # do backward pass + # TODO: not entirely sure, why we need this + if model is not None and isinstance(model, LightningModule): + model.backward(closure_loss, optimizer, opt_idx, **kwargs) + + # TODO: avoid dev_debugger and track these calls with mock + model.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX)) + + else: + closure_loss.backward(*args, **kwargs) + + # exit amp context + a, b, c = None, None, None + error = context.__exit__(a, b, c) + if error: + rank_zero_warn(a, b, c) + raise Exception("apex unscale error") + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + return closure_loss + + def configure_apex( + self, + amp: Type, + model: LightningModule, + optimizers: List['Optimizer'], + amp_level: str, + ) -> Tuple[LightningModule, List['Optimizer']]: + r""" + Override to init AMP your own way. + Must return a model and list of optimizers. + + Args: + amp: pointer to amp library object. + model: pointer to current :class:`LightningModule`. + optimizers: list of optimizers passed in :meth:`configure_optimizers`. + amp_level: AMP mode chosen ('O1', 'O2', etc...) + + Return: + Apex wrapped model and optimizers + + Examples: + .. code-block:: python + + # Default implementation used by Trainer. + def configure_apex(self, amp, model, optimizers, amp_level): + model, optimizers = amp.initialize( + model, optimizers, opt_level=amp_level, + ) + + return model, optimizers + """ + model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) + return model, optimizers + + @staticmethod + def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None: + """Reinitializes schedulers with correct properties""" + # Reinitialize optimizer.step properties added by schedulers + for scheduler in schedulers: + scheduler = scheduler['scheduler'] + state = None + + for optimizer in optimizers: + # check that we dont mix users optimizers and schedulers + if scheduler.optimizer == optimizer: + # Find the mro belonging to the base lr scheduler class + for i, mro in enumerate(scheduler.__class__.__mro__): + if mro in (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + state = scheduler.state_dict() + scheduler.__class__.__mro__[i].__init__(scheduler, optimizer) + scheduler.load_state_dict(state) + break + + if state is not None: + break + + def pre_optimizer_step( + self, + pl_module: LightningModule, + optimizer: 'Optimizer', + optimizer_idx: int, + lambda_closure: Callable, + **kwargs: Any, + ) -> bool: + """ + always called before the optimizer step. + """ + # apex amp does not support closures. + lambda_closure() + + if not pl_module.automatic_optimization: + pl_module.trainer.call_hook("on_after_backward") + + optimizer.step(**kwargs) + return False diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py new file mode 100644 index 00000000000000..6bcbb5ad851dc6 --- /dev/null +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -0,0 +1,82 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, TYPE_CHECKING, Union + +import torch + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +if TYPE_CHECKING: + from torch.optim import Optimizer + + from pytorch_lightning.core.lightning import LightningModule + +warning_cache = WarningCache() + + +class DeepSpeedPrecisionPlugin(PrecisionPlugin): + + def __init__(self, precision: int) -> None: + super().__init__() + self.precision = precision + + def pre_optimizer_step( + self, + pl_module: 'LightningModule', + optimizer: 'Optimizer', + optimizer_idx: int, + lambda_closure: Callable, + **kwargs: Any, + ) -> bool: + deepspeed_engine = pl_module.trainer.model + # DeepSpeed not support closures. + lambda_closure() + + if not pl_module.automatic_optimization: + pl_module.trainer.call_hook("on_after_backward") + + deepspeed_engine.step() + + return False + + def backward( + self, + model: 'LightningModule', + closure_loss: torch.Tensor, + optimizer: 'Optimizer', + opt_idx: int, + should_accumulate: bool, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + if is_overridden('backward', model): + warning_cache.warn( + "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles" + "backward logic outside of the LightningModule" + ) + # todo: hack around for deepspeed engine to call backward + deepspeed_engine = model.trainer.model + deepspeed_engine.backward(closure_loss, *args, **kwargs) + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + + return closure_loss + + def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + """ + DeepSpeed handles clipping gradients via the training type plugin. + """ + pass diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py new file mode 100644 index 00000000000000..4720f0f874fd09 --- /dev/null +++ b/pytorch_lightning/plugins/precision/double.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import wraps +from typing import Any, Sequence, Tuple, TYPE_CHECKING, List + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection + +if TYPE_CHECKING: + from torch.nn import Module + from torch.optim import Optimizer + + +class _DoublePrecisionPatch: + """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown.""" + + def __init__(self, model: 'Module', method_name: str, old_method: Any) -> None: + self.model = model + self.method_name = method_name + self.old_method = old_method + + def teardown(self) -> None: + setattr(self.model, self.method_name, self.old_method) + + @staticmethod + def _to_double_precision(data: torch.Tensor) -> torch.Tensor: + if data.is_floating_point(): + return data.double() + return data + + @staticmethod + def _move_float_tensors_to_double(collection: Any) -> Any: + return apply_to_collection( + collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision + ) + + @classmethod + def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch': + old_method = getattr(model, method_name) + + @wraps(old_method) + def new_method(*args: Any, **kwargs: Any) -> Any: + return old_method( + *_DoublePrecisionPatch._move_float_tensors_to_double(args), + **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs) + ) + + setattr(model, method_name, new_method if callable(old_method) else old_method) + return cls(model, method_name, old_method) + + +class DoublePrecisionPlugin(PrecisionPlugin): + """Plugin for training with double (``torch.float64``) precision.""" + + precision: int = 64 + + def __init__(self) -> None: + self.patches: List[_DoublePrecisionPatch] = [] + + def connect( + self, + model: 'Module', + optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any], + ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]: + """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`, + `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter + `optimizers` or `lr_schedulers`.""" + model = model.to(dtype=torch.float64) + if isinstance(model, LightningModule): + self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step')) + self.patches.append(_DoublePrecisionPatch.patch(model, 'forward')) + + return super().connect(model, optimizers, lr_schedulers) + + def post_dispatch(self) -> None: + while len(self.patches) > 0: + self.patches.pop().teardown() diff --git a/pytorch_lightning/plugins/precision/mixed.py b/pytorch_lightning/plugins/precision/mixed.py new file mode 100644 index 00000000000000..1a84b5eae96689 --- /dev/null +++ b/pytorch_lightning/plugins/precision/mixed.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Union + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + +if TYPE_CHECKING: + from pytorch_lightning.utilities import AMPType + + +class MixedPrecisionPlugin(PrecisionPlugin): + """Base Class for mixed precision""" + + EPSILON: float = 1e-5 + backend: 'AMPType' + precision: Union[str, int] = "mixed" diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py new file mode 100644 index 00000000000000..3c83945c8a1b77 --- /dev/null +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -0,0 +1,123 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from typing import Any, Callable, Generator, TYPE_CHECKING + +import torch +from torch.optim import LBFGS + +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if TYPE_CHECKING: + from torch.optim import Optimizer + + from pytorch_lightning.core import LightningModule + + +class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): + + def __init__(self) -> None: + if not _NATIVE_AMP_AVAILABLE: + raise MisconfigurationException( + "You have asked for native AMP but your PyTorch version does not support it." + " Consider upgrading with `pip install torch>=1.6`." + ) + + self.backend = AMPType.NATIVE + self.scaler = torch.cuda.amp.GradScaler() + + def backward( + self, + model: 'LightningModule', + closure_loss: torch.Tensor, + optimizer: 'Optimizer', + opt_idx: int, + should_accumulate: bool, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """performs the actual backpropagation + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: the optimizer to perform the step lateron + opt_idx: the optimizer's index + should_accumulate: whether to accumulate gradients or not + + """ + closure_loss = self.scaler.scale(closure_loss) + + closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs) + + # unscale gradient to allow analyze within `on_after_backward` + if not should_accumulate and model.automatic_optimization: + self.scaler.unscale_(optimizer) + + return closure_loss + + def pre_optimizer_step( + self, + pl_module: 'LightningModule', + optimizer: 'Optimizer', + optimizer_idx: int, + lambda_closure: Callable, + **kwargs: Any, + ) -> bool: + """always called before the optimizer step. + Checks that the optimizer is not LBFGS, as this one is not supported by native amp + """ + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." + " To request, please file a Github issue in PyTorch and tag @mcarilli" + ) + lambda_closure() + + if not pl_module.automatic_optimization: + self.scaler.unscale_(optimizer) + pl_module.trainer.call_hook("on_after_backward") + + return False + + def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None: + """Updates the GradScaler""" + self.scaler.step(optimizer) + self.scaler.update() + + @contextmanager + def train_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def val_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def test_step_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield + + @contextmanager + def predict_context(self) -> Generator[None, None, None]: + """Enable autocast context""" + with torch.cuda.amp.autocast(): + yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py new file mode 100644 index 00000000000000..7172d82391bd3b --- /dev/null +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -0,0 +1,134 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Callable, Generator, Sequence, Tuple, TYPE_CHECKING, Union + +import torch + +from pytorch_lightning.plugins.base_plugin import Plugin + +if TYPE_CHECKING: + from torch.nn import Module + from torch.optim import Optimizer + + from pytorch_lightning.core import LightningModule + + +class PrecisionPlugin(Plugin): + """ Plugin handling the precision-specific parts of the training. + The static classattributes EPSILON and precision must be overwritten in child-classes and their + default values reflect fp32 training. + """ + EPSILON: float = 1e-6 + precision: Union[str, int] = 32 + + def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]: + """The master params of the model. Returns the plain model params here. + Maybe different in other precision plugins. + + """ + for group in optimizer.param_groups: + for p in group["params"]: + yield p + + def connect( + self, + model: 'Module', + optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any], + ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]: + """Connects this plugin to the accelerator and the training process""" + return model, optimizers, lr_schedulers + + def backward( + self, + model: 'LightningModule', + closure_loss: torch.Tensor, + optimizer: 'Optimizer', + opt_idx: int, + should_accumulate: bool, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """performs the actual backpropagation + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: the optimizer to perform the step lateron + opt_idx: the optimizer's index + should_accumulate: whether to accumulate gradients or not + + """ + automatic_optimization = model.automatic_optimization + + # do backward pass + if automatic_optimization: + model.backward(closure_loss, optimizer, opt_idx) + else: + closure_loss.backward(*args, **kwargs) + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + + return closure_loss + + def pre_optimizer_step( + self, + pl_module: 'LightningModule', + optimizer: 'Optimizer', + optimizer_idx: int, + lambda_closure: Callable, + **kwargs: Any, + ) -> bool: + """Hook to do something before each optimizer step.""" + return True + + def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None: + """Hook to do something after each optimizer step.""" + + def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + """Clips the gradients to a specific value""" + if clip_val is None: + return + + grad_clip_val = float(clip_val) + + if grad_clip_val <= 0: + return + + parameters = list(self.master_params(optimizer)) + + max_norm = grad_clip_val + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + device = parameters[0].device + + if norm_type == math.inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + else: + out = torch.empty(len(parameters), device=device) + for i, p in enumerate(parameters): + torch.norm(p.grad.data.to(device), norm_type, out=out[i]) + total_norm = torch.norm(out, norm_type) + + eps = self.EPSILON + + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) + for p in parameters: + p.grad.data.mul_(clip_coef.to(p.grad.data.device)) diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py new file mode 100644 index 00000000000000..39dc01f97df112 --- /dev/null +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import cast, TYPE_CHECKING, Union + +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE + +if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + +if TYPE_CHECKING: + from torch.optim import Optimizer + + +class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + """Mixed Precision for Sharded Training + """ + + def __init__(self) -> None: + super().__init__() + self.scaler = ShardedGradScaler() + + def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + if clip_val <= 0: + return + + optimizer = cast(OSS, optimizer) + optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py new file mode 100644 index 00000000000000..e7d7507d8257b1 --- /dev/null +++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Sequence, Tuple, TYPE_CHECKING + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + +if TYPE_CHECKING: + from torch.nn import Module + from torch.optim import Optimizer + + +class TPUHalfPrecisionPlugin(PrecisionPlugin): + """Plugin that enables bfloats on TPUs""" + + precision: int = 16 + + def connect( + self, + model: 'Module', + optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any], + ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]: + os.environ["XLA_USE_BF16"] = str(1) + return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py new file mode 100644 index 00000000000000..30723d67da3f41 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -0,0 +1,15 @@ +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py new file mode 100644 index 00000000000000..58e26e7db32d85 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -0,0 +1,306 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import subprocess +import sys +from time import sleep +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.distributed as torch_distrib +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import Optimizer + +from pytorch_lightning.distributed import LightningDistributed +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import seed_everything + +if _HYDRA_AVAILABLE: + from hydra.core.hydra_config import HydraConfig + from hydra.utils import get_original_cwd, to_absolute_path + +log = logging.getLogger(__name__) + + +class DDPPlugin(ParallelPlugin): + """ + Plugin for multi-process single-device training on one or multiple nodes. + + The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, + where N is the number of devices (e.g. GPU) per node. + It is very similar to how :mod:`torch.distributed.launch` launches processes. + """ + + distributed_backend = "ddp" + + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + num_nodes: int = 1, + cluster_environment: ClusterEnvironment = None, + sync_batchnorm: bool = False, + **kwargs: Union[Any, Dict[str, Any]], + ) -> None: + super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + self.interactive_ddp_procs = [] + self.num_nodes = num_nodes + self.sync_batchnorm = sync_batchnorm + self.dist = LightningDistributed() + self._ddp_kwargs = kwargs + self._has_spawned_children = False + self.task_idx = None + self.node_rank = 0 + self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + return distributed_sampler_kwargs + + def setup_environment(self): + # start the other scripts + if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": + self._call_children_scripts() + + # set the task idx + self.task_idx = self.cluster_environment.local_rank() + + self.setup_distributed() + + def _call_children_scripts(self): + + # bookkeeping of spawned processes + assert self.global_rank == 0 + self._check_can_spawn_children() + self._has_spawned_children = True + + # DDP Environment variables + os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + + # allow the user to pass the node rank + os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) + os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) + + # when user is using hydra find the absolute path + path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path + + # pull out the commands used to run the script and resolve the abs file path + command = sys.argv + try: + full_path = path_lib(command[0]) + except Exception: + full_path = os.path.abspath(command[0]) + + command[0] = full_path + # use the same python interpreter and actually running + command = [sys.executable] + command + + # the visible devices tell us how many GPUs we want to use. + # when the trainer script was called the device has already been scoped by the time + # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone + # but forward the GPUs selected via environment variables + if self.parallel_devices is None: + raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") + + os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices]) + os.environ["PL_IN_DDP_SUBPROCESS"] = "1" + + if self.lightning_module.logger is not None: + os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version) + + num_gpus = len(self.parallel_devices) + os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}" + + self.interactive_ddp_procs = [] + + for local_rank in range(1, self.num_processes): + env_copy = os.environ.copy() + env_copy["LOCAL_RANK"] = f"{local_rank}" + + # remove env var if global seed not set + if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: + del env_copy["PL_GLOBAL_SEED"] + + # start process + # if hydra is available and initialized, make sure to set the cwd correctly + cwd: Optional[str] = None + if _HYDRA_AVAILABLE: + if HydraConfig.initialized(): + cwd = get_original_cwd() + os_cwd = f'"{os.getcwd()}"' + command += [f'hydra.run.dir={os_cwd}', f'hydra.job.name=train_ddp_process_{local_rank}'] + proc = subprocess.Popen(command, env=env_copy, cwd=cwd) + self.interactive_ddp_procs.append(proc) + + # starting all processes at once can cause issues + # with dataloaders delay between 1-10 seconds + delay = np.random.uniform(1, 5, 1)[0] + sleep(delay) + + def setup_distributed(self): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # determine which process we are and world size + self.set_world_ranks() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + + def _check_can_spawn_children(self): + if self._has_spawned_children: + raise RuntimeError( + "You tried to run `.fit` or `.test` multiple times in the same script." + " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." + ) + + def set_world_ranks(self): + self.local_rank = self.task_idx + self.node_rank = self.cluster_environment.node_rank() + self.global_rank = self.node_rank * self.num_processes + self.local_rank + self.world_size = self.num_nodes * self.num_processes + + def pre_configure_ddp(self): + # if unset, default `find_unused_parameters` `True` + # Many models require setting this parameter to True, as there are corner cases + # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. + # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization + if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( + "find_unused_parameters", False + ): + rank_zero_warn( + "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " + "to properly work with DDP." + ) + self._ddp_kwargs["find_unused_parameters"] = True + + def configure_ddp(self): + self.pre_configure_ddp() + self._model = DistributedDataParallel( + LightningDistributedModule(self.model), + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + + def determine_ddp_device_ids(self): + if self.root_device.type == "cpu": + return None + return [self.root_device.index] + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) + + if not torch.distributed.is_initialized(): + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) + + def pre_dispatch(self): + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + + # move the model to the correct device + self.model_to_device() + + self.configure_ddp() + + self.barrier() + + def post_dispatch(self): + if "WORLD_SIZE" in os.environ: + del os.environ["WORLD_SIZE"] + + def barrier(self, *args, **kwargs): + if torch_distrib.is_initialized(): + torch_distrib.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + return self.dist.broadcast(obj) + + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + """Run before precision plugin executes backward""" + if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: + prepare_for_backward(self.model, closure_loss) + + def model_to_device(self): + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + return tensor + + def training_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def predict_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def post_training_step(self): + if not self.lightning_module.automatic_optimization: + self.model.require_backward_grad_sync = True diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py new file mode 100644 index 00000000000000..a94bb5459bb1e5 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin + + +class DDP2Plugin(DDPPlugin): + + def setup(self, model): + self._model = model + # set the task idx + self.task_idx = self.cluster_environment.local_rank() + # the difference to DDP is that we don't call children processes here + + def reduce(self, tensor, *args, **kwargs): + """ + Reduces a tensor from all processes to one aggregated tensor. + In DDP2, the reduction here is only across local devices within the node. + + Args: + tensor: the tensor to sync and reduce + *args: ignored for DDP2 + **kwargs: ignored for DDP2 + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Result): + tensor.dp_reduce() + + elif isinstance(tensor, torch.Tensor): + tensor = tensor.mean() + + return tensor + + @property + def root_device(self): + return self.parallel_devices[0] + + def model_to_device(self): + # no need to do anything when model is wrapped in torch.nn.DataParallel + pass + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank) + return distributed_sampler_kwargs + + def set_world_ranks(self): + self.local_rank = self.task_idx + self.node_rank = self.cluster_environment.node_rank() + self.global_rank = self.node_rank + self.world_size = self.num_nodes diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py new file mode 100644 index 00000000000000..87d7fa5faecac5 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -0,0 +1,290 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import re +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as torch_distrib +import torch.multiprocessing as mp +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import Optimizer + +from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available +from pytorch_lightning.utilities.seed import seed_everything + +log = logging.getLogger(__name__) + + +class DDPSpawnPlugin(ParallelPlugin): + + distributed_backend = "ddp_spawn" + + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + num_nodes: int = 1, + cluster_environment: ClusterEnvironment = None, + sync_batchnorm: bool = False, + **kwargs: Union[Any, Dict[str, Any]], + ): + super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + self.num_nodes = num_nodes + self.sync_batchnorm = sync_batchnorm + self._ddp_kwargs = kwargs + self.dist = LightningDistributed() + self.num_processes = len(parallel_devices) + self.node_rank = 0 + self.mp_queue = None + + def __getstate__(self): + """ Makes this plugin pickleable without destroying the queue in the current process. """ + state = self.__dict__.copy() + state["mp_queue"] = None + return state + + def __setstate__(self, state): + self.__dict__ = state + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + return distributed_sampler_kwargs + + def setup(self, model): + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + + # pass in a state q + smp = mp.get_context("spawn") + self.mp_queue = smp.SimpleQueue() + + def set_world_ranks(self, process_idx): + self.local_rank = process_idx + self.node_rank = self.cluster_environment.node_rank() + self.task_idx = self.cluster_environment.local_rank() + self.global_rank = self.node_rank * self.num_processes + self.local_rank + self.world_size = self.num_nodes * self.num_processes + + @property + def mp_spawn_kwargs(self): + return { + "args": (self.lightning_module.trainer, self.mp_queue), + "nprocs": self.num_processes, + } + + def start_training(self, trainer): + mp.spawn(self.new_process, **self.mp_spawn_kwargs) + # reset optimizers, since main process is never used for training and thus does not have a valid optim state + trainer.optimizers = [] + + def start_evaluating(self, trainer): + mp.spawn(self.new_process, **self.mp_spawn_kwargs) + + def start_predicting(self, trainer): + mp.spawn(self.new_process, **self.mp_spawn_kwargs) + + def new_process(self, process_idx, trainer, mp_queue): + self.mp_queue = mp_queue + + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + self.set_world_ranks(process_idx) + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # TODO: we moved it to the trainer.fit after calling pre_dispatch + # ... need to double check that it is the correct place + # self.trainer.call_setup_hook(self.model) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + + # move the model to the correct device + self.model_to_device() + + self.configure_ddp() + + self.barrier() + + results = trainer.run_stage() + + # persist info in ddp_spawn + self.transfer_distrib_spawn_state_on_fit_end(results) + + def post_dispatch(self): + # restore main state with best weights + best_path = self.mp_queue.get() + last_path = self.mp_queue.get() + self._results = self.mp_queue.get() + + # recover the weights of the processes trained in the children + self.__recover_child_process_weights(best_path, last_path) + + def pre_configure_ddp(self): + # if unset, default `find_unused_parameters` `True` + # Many models require setting this parameter to True, as there are corner cases + # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. + # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization + if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( + "find_unused_parameters", False + ): + rank_zero_warn( + "From PyTorch 1.7.0, Lightning ``manual_optimization`` needs to set ``find_unused_parameters=True`` " + "to properly work with DDP." + ) + self._ddp_kwargs["find_unused_parameters"] = True + + def configure_ddp(self): + self.pre_configure_ddp() + self._model = DistributedDataParallel( + LightningDistributedModule(self.model), + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + # TODO: this code is duplicated in DDP and DDPSpawn, make this a function + os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) + + if not torch.distributed.is_initialized(): + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) + + def determine_ddp_device_ids(self): + if self.root_device.type == "cpu": + return None + return [self.root_device.index] + + def on_save(self, checkpoint: dict) -> dict: + return checkpoint + + def transfer_distrib_spawn_state_on_fit_end(self, results): + checkpoint_callback = self.lightning_module.trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + if self.global_rank == 0 and self.mp_queue is not None: + rank_zero_warn("cleaning up ddp environment...") + + # save the last weights + last_path = None + if ( + self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None + and len(best_model_path) > 0 + ): + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + atomic_save(self.on_save(self.lightning_module.state_dict()), last_path) + + # todo, pass complete checkpoint as state dictionary + self.mp_queue.put(best_model_path) + self.mp_queue.put(last_path) + self.mp_queue.put(results) + + def __recover_child_process_weights(self, best_path, last_path): + # transfer back the best path to the trainer + if self.lightning_module.trainer.checkpoint_callback: + self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path + # todo, pass also best score + + # load last weights + if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING: + ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) + self.lightning_module.load_state_dict(ckpt) + + def barrier(self, *args, **kwargs): + if torch_distrib.is_initialized(): + torch_distrib.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + return self.dist.broadcast(obj) + + def model_to_device(self): + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + """Run before precision plugin executes backward""" + if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: + prepare_for_backward(self.model, closure_loss) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + return tensor + + def training_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def predict_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def post_training_step(self): + if not self.lightning_module.automatic_optimization: + self.model.require_backward_grad_sync = True diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py new file mode 100644 index 00000000000000..b196044937414b --- /dev/null +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -0,0 +1,340 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch.nn.parallel import DistributedDataParallel + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE + +if _DEEPSPEED_AVAILABLE: + import deepspeed + + +class LightningDeepSpeedModule(_LightningModuleWrapperBase): + + def __init__(self, pl_module: LightningModule, precision: int): + super().__init__(pl_module) + self.precision = precision + + def forward(self, *inputs, **kwargs): + if self.precision == 16: + inputs = self._move_float_tensors_to_half(inputs) + + return super().forward(*inputs, **kwargs) + + @staticmethod + def batch_to(data): + return data.half() + + def _move_float_tensors_to_half(self, batch: Any): + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) + return batch + + +class DeepSpeedPlugin(DDPPlugin): + distributed_backend = "deepspeed" + DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" + + def __init__( + self, + zero_optimization: bool = True, + stage: int = 2, + cpu_offload: bool = False, + contiguous_gradients: bool = True, + overlap_comm: bool = True, + allgather_partitions: bool = True, + reduce_scatter: bool = True, + allgather_bucket_size: int = 2e8, + reduce_bucket_size: int = 2e8, + zero_allow_untested_optimizer: bool = True, + config: Optional[Union[Path, str, dict]] = None, + logging_level: int = logging.WARN, + num_nodes: int = 1, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + loss_scale: float = 0, + initial_scale_power: int = 32, + loss_scale_window: int = 1000, + hysteresis: int = 2, + min_loss_scale: int = 1 + ) -> None: + """ + + Provides capabilities to run training using the DeepSpeed library, + with training optimizations for large billion parameter models. + `For more information: https://www.deepspeed.ai/`. + + .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change. + + Defaults have been set to enable ZeRO-Offload and some have been taken from the link below. + These defaults have been set generally, but may require tuning for optimum performance based on your model size. + `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`. + + Arguments: + + zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True) + + stage: Different stages of the ZeRO Optimizer. 0 is disabled, + 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2) + + cpu_offload: Enable offloading optimizer memory and computation to CPU + + contiguous_gradients: Copies gradients to a continuous buffer as they are produced. + Avoids memory fragmentation during backwards. Useful when training large models. (default: True) + + overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation. + This is a speed optimization when training across multiple GPUs/machines. (default: True) + + allgather_partitions: All gather updated parameters at the end of training step, + instead of using a series of broadcast collectives (default: True) + + reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default:True) + + allgather_bucket_size: Number of elements to allgather at once. + Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8) + + reduce_bucket_size: Number of elements to reduce at once. + Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8) + + zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a + DeepSpeed supported optimizer when using ZeRO (default: True) + + config: Pass in a deepspeed formatted config dict, + or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json. + All defaults will be ignored if a config is passed in. (Default: ``None``) + + logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``) + + loss_scale: Loss scaling value for FP16 training. + 0.0 results in dynamic loss scaling, otherwise static (Default: 0) + + initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed + by ``2^initial_scale_power`` (Default: 32) + + loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000) + + hysteresis: FP16 Delay shift in Dynamic Loss scaling (Default: 2) + + min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1000) + + """ + if not _DEEPSPEED_AVAILABLE: + raise MisconfigurationException( + "To use the DeepSpeed plugin, you must have DeepSpeed installed." + " pip install deepspeed mpi4py" + ) + super().__init__( + parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment + ) + self.config = self._load_config(config) + if self.config is None: + # User has not overridden config, set defaults + self.config = self._create_default_config( + zero_optimization, + zero_allow_untested_optimizer, + stage=stage, + cpu_offload=cpu_offload, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size + ) + self._config_initialized = False + deepspeed.utils.logging.logger.setLevel(logging_level) + + # default FP16 parameters. + self.loss_scale = loss_scale + self.initial_scale_power = initial_scale_power + self.loss_scale_window = loss_scale_window + self.hysteresis = hysteresis + self.min_loss_scale = min_loss_scale + + def _load_config(self, config): + if config is None and self.DEEPSPEED_ENV_VAR in os.environ: + rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") + config = os.environ[self.DEEPSPEED_ENV_VAR] + if isinstance(config, str) or isinstance(config, Path): + if not os.path.isfile(config): + raise MisconfigurationException( + f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" + ) + with open(config) as f: + config = json.load(f) + return config + + def pre_dispatch(self): + self.init_deepspeed() + self.barrier() + + def init_deepspeed(self): + if not self._config_initialized: + self._format_config() + self._config_initialized = True + + precision = self.lightning_module.trainer.accelerator.precision + model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + + if self.lightning_module.trainer and self.lightning_module.trainer.training: + self._initialize_deepspeed_train(model) + else: + self._initialize_deepspeed_inference(model) + + def _init_scheduler_optimizer(self): + optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers( + self.lightning_module + ) + if len(optimizers) > 1 or len(schedulers) > 1: + raise MisconfigurationException( + "DeepSpeed currently only supports single optimizer, single optional scheduler." + ) + scheduler = schedulers[0]['scheduler'] if len(schedulers) == 1 else None + optimizer = optimizers[0] + return optimizer, scheduler, optimizer_frequencies + + def _initialize_deepspeed_train(self, model): + if self.on_gpu: + torch.cuda.set_device(self.root_device) + optimizer, lightning_scheduler, optimizer_frequencies = None, None, None + if "optimizer" not in self.config: + rank_zero_info( + "You have not specified an optimizer or scheduler within the DeepSpeed config." + "Using `configure_optimizers` to define optimizer and scheduler." + ) + optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer() + model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) + model, optimizer, _, lr_scheduler = deepspeed.initialize( + args=SimpleNamespace(local_rank=self.local_rank), + model=model, + model_parameters=model_parameters, + optimizer=optimizer, + lr_scheduler=lightning_scheduler, + config_params=self.config, + ) + + # set optimizer for save/load, but deepspeed manages the specific optimizer logic + self.lightning_module.trainer.optimizers = [optimizer] + self.model = model + + def _initialize_deepspeed_inference(self, model): + # move the model to the correct device + self.model_to_device() + + self.pre_configure_ddp() + self.model = DistributedDataParallel( + model, + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + + def configure_scheduler(self, lr_scheduler): + scheduler = _get_default_scheduler_config() + scheduler["scheduler"] = lr_scheduler + return [scheduler] + + @property + def lightning_module(self): + # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early + module = getattr(self.model, "module", self.model) + return module.module if isinstance(module, LightningDeepSpeedModule) else module + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) + return distributed_sampler_kwargs + + def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]: + # Skip initializing optimizers here as DeepSpeed handles optimizers via config. + # User may have specified config options instead in configure_optimizers, but this is handled + # via `_initialize_deepspeed_train` + return [], [], [] # empty optimizers, schedulers and frequencies + + def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + # note: We rely on the deepspeed engine to carry out the step rather than the optimizer. + # internally, the engine has a reference to the optimizer already. + self.model.step(**kwargs) + + def _format_config(self): + if self.config is None: + raise MisconfigurationException( + "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." + " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed" + ) + self._format_batch_size_and_grad_accum_config() + self._format_precision_config() + + def _format_batch_size_and_grad_accum_config(self): + if "gradient_accumulation_steps" in self.config: + raise MisconfigurationException( + "Within the DeepSpeed config, do not set gradient_accumulation_steps" + " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." + ) + if "train_micro_batch_size_per_gpu" not in self.config: + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed + batch_size = self.lightning_module.train_dataloader().batch_size + self.config["train_micro_batch_size_per_gpu"] = batch_size + self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches + if "gradient_clipping" not in self.config: + self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val + + def _format_precision_config(self): + + amp_type = self.lightning_module.trainer.accelerator_connector.amp_type + amp_level = self.lightning_module.trainer.accelerator_connector.amp_level + precision = self.lightning_module.trainer.accelerator_connector.precision + if precision == 16: + if "fp16" not in self.config and amp_type == AMPType.NATIVE: + # FP16 is a DeepSpeed standalone AMP implementation + rank_zero_info("Enabling DeepSpeed FP16.") + self.config["fp16"] = { + "enabled": True, + "loss_scale": self.loss_scale, + "initial_scale_power": self.initial_scale_power, + "loss_scale_window": self.loss_scale_window, + "hysteresis": self.hysteresis, + "min_loss_scale": self.min_loss_scale + } + elif "amp" not in self.config and amp_type == AMPType.APEX: + rank_zero_only("Enabling DeepSpeed APEX Implementation.") + self.config["amp"] = { + "enabled": True, + "opt_level": amp_level, + } + if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config): + raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.") + + def _create_default_config( + self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs + ) -> Dict: + if zero_optimization: + return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs} + return {} diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py new file mode 100644 index 00000000000000..a8e42e0fa747af --- /dev/null +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional + +import torch +from torch.nn import DataParallel + +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.overrides.data_parallel import LightningParallelModule +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +class DataParallelPlugin(ParallelPlugin): + + def __init__(self, parallel_devices: Optional[List[torch.device]]): + super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + + def setup(self, model): + # model needs to be moved to the device before it is wrapped + model.to(self.root_device) + self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) + + def reduce(self, tensor, *args, **kwargs): + """ + Reduces a tensor from all parallel processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + *args: ignored for DP + **kwargs: ignored for DP + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Result): + tensor.dp_reduce() + + else: + + def _reduce(t: torch.Tensor): + dtype_tensor = t.dtype + return t.float().mean().type(dtype_tensor) + + tensor = apply_to_collection(tensor, torch.Tensor, _reduce) + + return tensor + + @property + def root_device(self): + return self.parallel_devices[0] + + def model_to_device(self): + # no need to do anything when model is wrapped in torch.nn.DataParallel + pass + + def barrier(self, *args, **kwargs): + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + return obj + + def reduce_boolean_decision(self, decision: bool) -> bool: + return decision + + def training_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def predict_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def training_step_end(self, output): + return self.reduce(output) + + def validation_step_end(self, output): + return self.reduce(output) + + def test_step_end(self, output): + return self.reduce(output) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py new file mode 100644 index 00000000000000..8d0add27cbb29c --- /dev/null +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -0,0 +1,187 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import ExitStack +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as torch_distrib +from torch.optim.lr_scheduler import _LRScheduler, Optimizer + +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp + +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd + + +class HorovodPlugin(ParallelPlugin): + + def __init__(self, parallel_devices: Optional[List[torch.device]] = None): + super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + return distributed_sampler_kwargs + + def setup(self, model): + self._model = model + + self.global_rank = hvd.rank() + self.local_rank = hvd.local_rank() + self.world_size = hvd.size() + rank_zero_only.rank = self.global_rank + + self.model_to_device() + + def pre_dispatch(self): + + def _unpack_lightning_optimizer(opt): + return opt._optimizer if isinstance(opt, LightningOptimizer) else opt + + optimizers = self.lightning_module.trainer.optimizers + optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers] + + # Horovod: scale the learning rate by the number of workers to account for + # increased total batch size + for optimizer in optimizers: + for param_group in optimizer.param_groups: + param_group["lr"] *= hvd.size() + + # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR + lr_schedulers = self.lightning_module.trainer.lr_schedulers + for scheduler in lr_schedulers: + scheduler = scheduler["scheduler"] + if isinstance(scheduler, _LRScheduler): + scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] + + # Horovod: broadcast parameters & optimizer state to ensure consistent initialization + hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0) + for optimizer in optimizers: + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + def _filter_named_parameters(model, optimizer): + opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])]) + return [(name, p) for name, p in model.named_parameters() if p in opt_params] + + # Horovod: wrap optimizers to perform gradient aggregation via allreduce + optimizers = [ + hvd.DistributedOptimizer( + optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) + ) for optimizer in optimizers + ] + self.lightning_module.trainer.accelerator.optimizers = optimizers + + def start_training(self, trainer): + with ExitStack() as stack: + for optimizer in trainer.optimizers: + # Synchronization will be performed explicitly following backward() + stack.enter_context(optimizer.skip_synchronize()) + + # set up training routine + self._results = trainer.run_stage() + + # Make sure all workers have finished training before returning to the user + hvd.join() + + def start_evaluating(self, trainer): + with ExitStack(): + self._results = trainer.run_stage() + + # Make sure all workers have finished training before returning to the user + hvd.join() + + def start_predicting(self, trainer): + with ExitStack(): + # set up training routine + self._results = trainer.run_stage() + + # Make sure all workers have finished training before returning to the user + hvd.join() + + def barrier(self, *args, **kwargs): + if torch_distrib.is_initialized(): + hvd.join() + + def broadcast(self, obj: object, src: int = 0) -> object: + obj = hvd.broadcast_object(obj, src) + return obj + + def model_to_device(self): + if self.on_gpu: + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if group is not None: + raise ValueError( + "Horovod does not support allreduce using a subcommunicator at this time. " + "Unset `group`." + ) + + if reduce_op in (None, "avg", "mean"): + reduce_op = hvd.Average + elif reduce_op == "sum": + reduce_op = hvd.Sum + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + hvd.join() + return hvd.allreduce(tensor, op=reduce_op) + + def all_gather( + self, + result: Union[torch.Tensor], + group: Optional[Any] = group.WORLD, + sync_grads: bool = False + ) -> torch.Tensor: + if group is not None and group != group.WORLD: + raise ValueError( + "Horovod does not support allgather using a subcommunicator at this time. " + "Unset `group`." + ) + + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + hvd.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result + + def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + # synchronize all horovod optimizers. + for optimizer in self.lightning_module.trainer.optimizers: + optimizer.synchronize() diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py new file mode 100644 index 00000000000000..d9a8e70588c437 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -0,0 +1,110 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, List, Optional + +import torch +from torch.nn.parallel import DistributedDataParallel + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp + + +class ParallelPlugin(TrainingTypePlugin, ABC): + + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + ): + super().__init__() + self.parallel_devices = parallel_devices + self.cluster_environment = cluster_environment + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + + @property + @abstractmethod + def root_device(self): + raise NotImplementedError + + @property + def on_gpu(self): + return self.root_device.type == "cuda" and torch.cuda.is_available() + + @property + def lightning_module(self): + return unwrap_lightning_module(self._model) + + @property + def is_global_zero(self) -> bool: + return self.global_rank == 0 + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank) + return distributed_sampler_kwargs + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op=ReduceOp.SUM) + decision = bool(decision == self.world_size) + return decision + + @property + def torch_distributed_backend(self): + torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND") + if torch_backend is None: + torch_backend = "nccl" if self.on_gpu else "gloo" + return torch_backend + + @staticmethod + def configure_sync_batchnorm(model: LightningModule) -> LightningModule: + """ + Add global batchnorm for a model spread across multiple GPUs and nodes. + + Override to synchronize batchnorm between specific process groups instead + of the whole world or use a different sync_bn like `apex`'s version. + + Args: + model: pointer to current :class:`LightningModule`. + + Return: + LightningModule with batchnorm layers synchronized between process groups + """ + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + return model + + @contextmanager + def block_backward_sync(self): + """ + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + Returns: context manager with sync behaviour off + """ + if isinstance(self.model, DistributedDataParallel): + with self.model.no_sync(): + yield None + else: + yield None diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py new file mode 100644 index 00000000000000..3e0f57daef0010 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -0,0 +1,85 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from contextlib import suppress +from typing import Callable, List, Optional + +import torch + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities import _RPC_AVAILABLE + +DEFAULT_RPC_TIMEOUT_SEC = 60. +if _RPC_AVAILABLE: + from torch.distributed import rpc + + with suppress(ModuleNotFoundError, ImportError): + from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC + + +class RPCPlugin(DDPPlugin): + """ + Backbone for RPC Plugins built on top of DDP. + RPC introduces different communication behaviour than DDP. Unlike DDP, processes potentially are not + required to run the same code as the main process. + This leads to edge cases where logic needs to be re-defined. This class contains special cases + that need to be addressed when using RPC communication when building custom RPC Plugins. + """ + + def __init__( + self, + rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, + parallel_devices: Optional[List[torch.device]] = None, + num_nodes: Optional[int] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + sync_batchnorm: Optional[bool] = None, + **kwargs + ): + self.rpc_timeout_sec = rpc_timeout_sec + self._is_rpc_initialized = False + super().__init__( + parallel_devices=parallel_devices, + num_nodes=num_nodes, + cluster_environment=cluster_environment, + sync_batchnorm=sync_batchnorm, + **kwargs + ) + + def init_rpc_connection(self, global_rank: int, world_size: int) -> None: + os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') + rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + rpc._set_rpc_timeout(self.rpc_timeout_sec) + self._is_rpc_initialized = True + + def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None: + """ + Override to save model to disk. + This is required as the main process will be required to handle aggregating model states from RPC processes. + + Args: + trainer: The trainer object. + save_model_fn: The saving function to save final model. + filepath: The filepath to save the model to. + """ + raise NotImplementedError + + def exit_rpc_process(self): + if self._is_rpc_initialized: + torch.distributed.rpc.shutdown() + self._is_rpc_initialized = False + + @property + def rpc_enabled(self) -> bool: + return True diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py new file mode 100644 index 00000000000000..ba26fc9f58ec54 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -0,0 +1,408 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +import logging +import os +from typing import Callable, List, Optional + +import torch +import torch.distributed as torch_distrib +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.distributed import LightningDistributedModule +from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _FAIRSCALE_PIPE_AVAILABLE: + import fairscale.nn.model_parallel as mpu + from fairscale.nn import PipeRPCWrapper + from fairscale.nn.pipe import balance as pipe_balance + from fairscale.nn.pipe import rpc as rpc_pipe + from fairscale.nn.pipe.pipeline import PipelineStyle + +log = logging.getLogger(__name__) + + +class RPCSequentialPlugin(RPCPlugin): + + def __init__( + self, + balance: Optional[List[int]] = None, + microbatches: int = 8, + checkpoint: str = 'except_last', + balance_mode: str = "balance_by_size", + pipelined_backward: Optional[bool] = True, + rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, + **kwargs + ): + """ + Provides sequential model parallelism for :class:`nn.Sequential ` module. + If the module requires lots of memory, Pipe can be used to reduce this by leveraging multiple GPUs. + + .. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965 + + Pipeline parallelism comes with with checkpointing to reduce peak + memory required to train while minimizing device under-utilization. + This is turned on by default and can be turned off via the checkpoint argument. + + You should determine the balance when defining the plugin, + or you can pass an example input array via the LightningModule to infer a balance. + The module will be partitioned into multiple devices according to the given balance. You may also rely on + your own heuristics to find your own optimal configuration. + + Args: + balance: The balance of the model, i.e [2, 2] (two layers on each GPU). + If not provided assumes user provides an input example array to find a balance on all GPUs. + + microbatches: Allows for parallelization to reduce device utilization + by splitting the batch into further smaller batches. + + checkpoint: Enables gradient checkpointing. ['always', 'except_last', 'never'] + + balance_mode: Type of balance heuristic to use if balance to be inferred. + + - 'balance_by_size': checks memory usage of each layer and determines balance + + - 'balance_by_time': checks time of each layer and determines balance + + pipelined_backward: if True, call torch.autograd.backward once per microbatch on the + + backward pass (instead of once for the whole batch). This works + around a potential deadlock in pytorch when using tensor parallelism + at the same time. Defaults to `True` if + `get_model_parallel_world_size() > 1` + """ + self._check_pipe_available() + super().__init__(rpc_timeout_sec=rpc_timeout_sec, **kwargs) + + self.balance = balance + + self.microbatches = microbatches + self.checkpoint = checkpoint + self.balance_mode = balance_mode + self.pipelined_backward = pipelined_backward + self._main_rpc_process = True + + def init_ddp_connection( + self, + global_rank: int, + world_size: int, + ) -> None: + if self.lightning_module.trainer.amp_backend is not None: + raise MisconfigurationException( + '`RPCSequentialPlugin` is currently not supported in Automatic Mixed Precision' + ) + + if self._skip_init_connections(): + return + super().init_ddp_connection( + global_rank=global_rank, + world_size=world_size, + ) + super().init_rpc_connection(global_rank=global_rank, world_size=world_size) + model = self.lightning_module + self.gpus_per_model = self._infer_check_num_gpus() + self.init_model_parallel_groups() + self.set_main_rpc_process() + + self._check_sequential_model_exists(model) + + # check if user given balance is valid + if self.balance is not None: + self._assert_valid_model_balance() + + if self.main_rpc_process: + if self.balance is None: + self._infer_model_balance() + self.init_pipe_module() + else: + self.handle_transferred_pipe_module() + self.exit_rpc_process() + + def _infer_model_balance(self): + log.info(f'Inferring model balance using {self.balance_mode} mode') + model = self.lightning_module + if model.example_input_array is None: + raise MisconfigurationException( + 'Please set example_input_array to your model, so we can infer the right model balance for you' + ) + balance_func = getattr(pipe_balance, self.balance_mode) + self.balance = balance_func(self.gpus_per_model, model.sequential_module, model.example_input_array) + self._sync_balance_to_all_parallel_groups() + + log.info(f'The following model balance {self.balance.tolist()} was inferred using {self.balance_mode} mode') + + def _sync_balance_to_all_parallel_groups(self, main_rank=0): + """ + Ensures that we sync the balance to all main processes, so that the balance is the same per replica. + Args: + main_rank: The rank with the balance we'd like to replicate. + """ + self.balance = torch.tensor(self.balance, dtype=torch.int, device='cuda') + # Ensure we sync to all processes within the main data parallel group + # We use the data parallel group as all main processes are found within the same group + torch_distrib.broadcast(self.balance, src=main_rank, group=mpu.get_data_parallel_group()) + self.balance = self.balance.cpu() + + def _check_sequential_model_exists(self, model): + if not hasattr(model, "sequential_module") or not isinstance(model.sequential_module, nn.Sequential): + raise MisconfigurationException( + 'Could not find a PipeLightningModule within the model. ' + 'Did you set your sequential model as the `sequential_module` attribute of your model?' + ) + + def _find_and_init_pipe_module(self, model): + if hasattr(model, "sequential_module") and isinstance(model.sequential_module, LightningPipeModule): + # model has been wrapped already + return + elif hasattr(model, "sequential_module") and isinstance(model.sequential_module, nn.Sequential): + # try to wrap model for the user + model.sequential_module = LightningPipeModule( + model.sequential_module, + balance=self.balance, + microbatches=self.microbatches, + checkpoint=self.checkpoint, + ) + # Update references for workers to access correct lightning functions when calling RPC + model.sequential_module.trainer = model.trainer + model.sequential_module.configure_optimizers = model.configure_optimizers + + # Update references for main process to access correct lightning functions when calling RPC + model.sequential_module.module.model.trainer = model.trainer + model.sequential_module.module.model.configure_optimizers = model.configure_optimizers + + self.model = model + + else: + raise MisconfigurationException( + 'Could not find a PipeLightningModule within the model. ' + 'Did you defined set your sequential model as a `sequential_module` attribute of your model?' + ) + + def _assert_valid_model_balance(self): + model = self.lightning_module + if sum(self.balance) != len(model.sequential_module): + raise MisconfigurationException( + f'The provided balance sum: {sum(self.balance)} does not' + f' match your Sequential length: {len(model.sequential_module)}' + ) + + def _skip_init_connections(self): + """ + Skip initialization if torch is already initialized and we're in testing. + Returns: Whether to skip initialization + + """ + return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TrainerState.FITTING + + def init_model_parallel_groups(self): + num_model_parallel = 1 # TODO currently no support for vertical model parallel + mpu.initialize_model_parallel(model_parallel_size_=num_model_parallel, pipeline_length=self.gpus_per_model) + + def _infer_check_num_gpus(self): + """ + Infer the number of GPUs per model. + + Returns: The appropriate balance for the model + """ + if isinstance(self.balance, list): + if len(self.balance) != (self.world_size / self.num_nodes): + raise MisconfigurationException( + "Pipe currently only supports splitting the module onto all available GPUs" + ) + # User has defined a balance for his model + return len(self.balance) + # Assume that the user wants to balance his model on all GPUs + return self.world_size + + def handle_transferred_pipe_module(self) -> None: + if self.lightning_module.trainer.state == TrainerState.FITTING: + torch_distrib.barrier() # Ensure we await main process initialization + # Add trainer/configure_optimizers to the pipe model for access in all worker processes + rpc_pipe.PipeModel.trainer = self.lightning_module.trainer + del rpc_pipe.PipeModel.trainer.model.sequential_module + rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel + rpc_pipe.PipeModel.configure_optimizers = self.lightning_module.configure_optimizers + + def init_pipe_module(self) -> None: + # Create pipe_module + model = self.lightning_module + self._find_and_init_pipe_module(model) + if self.lightning_module.trainer.state == TrainerState.FITTING: + torch_distrib.barrier() # Ensure we join main process initialization + model.sequential_module.foreach_worker(register_optimizers, include_self=True) + + # TODO: Move this to the connector + + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + """Run before precision plugin executes backward""" + + def configure_ddp(self): + if self.main_rpc_process: + self.pre_configure_ddp() + + self._model = DistributedDataParallel( + LightningDistributedModule(self.model), + device_ids=self.determine_ddp_device_ids(), + process_group=mpu.get_data_parallel_group(), + **self._ddp_kwargs, + ) + # Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel + self._model.require_backward_grad_sync = False + + @rank_zero_only + def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None: + model = self.lightning_module + if not hasattr(model.sequential_module, "foreach_worker"): + return + current_layers = model.sequential_module + model.sequential_module.foreach_worker( + save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True + ) + model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model) + save_model_fn(trainer, filepath) + model.sequential_module = current_layers + + def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: + model.sequential_module.foreach_worker( + run_optimizer, { + "opt_idx": opt_idx, + "args": args, + "kwargs": kwargs + }, include_self=False + ) + + @property + def distributed_sampler_kwargs(self): + return dict( + num_replicas=mpu.get_data_parallel_world_size(), + rank=mpu.get_data_parallel_rank(), + ) + + @property + def data_parallel_group(self): + return mpu.get_data_parallel_group() + + def set_main_rpc_process(self): + self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0 + + @property + def main_rpc_process(self) -> bool: + return self._main_rpc_process + + @main_rpc_process.setter + def main_rpc_process(self, is_main_process): + self._main_rpc_process = is_main_process + + def barrier(self, name: Optional[str] = None) -> None: + if torch_distrib.is_initialized() and self.main_rpc_process: + torch_distrib.barrier(group=self.data_parallel_group) + + def _check_pipe_available(self): + if not _FAIRSCALE_PIPE_AVAILABLE: + raise MisconfigurationException( + 'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.' + ) + + def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: + """Hook to do something after each optimizer step.""" + if self.rpc_enabled and self.main_rpc_process: + # Initialize optimizer step on main process + self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs) + + def post_training_step(self): + if self.main_rpc_process: + super().post_training_step() + + def start_training(self, trainer) -> None: + if self.main_rpc_process: + super().start_training(trainer) + + def start_evaluating(self, trainer) -> None: + if self.main_rpc_process: + super().start_evaluating(trainer) + + +class LightningPipeModule(nn.Module): + """ + This class wraps Fairscale Pipe and PipeRCPWrapper class. + """ + + def __init__(self, module: nn.Sequential, balance: List[int], microbatches: int = 8, checkpoint='never'): + super().__init__() + self.module = module + self.balance = balance + self.microbatches = microbatches + self.checkpoint = checkpoint + self._init_pipe() + + def _init_pipe(self): + device = torch.device("cuda", torch_distrib.get_rank()) + + self.module = PipeRPCWrapper( + module=self.module, + balance=self.balance, + chunks=self.microbatches, + style=PipelineStyle.MultiProcess, + input_device=device, + worker_map=self.get_worker_map(), + checkpoint=self.checkpoint, + ) + + def foreach_worker(self, *args, **kwargs): + self.module.foreach_worker(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + def get_worker_map(self): + # TODO, is this correct with multinodes? We also assume "worker" is the same as defined in the RPCPlugin + return {rank: f"worker{rank}" for rank in range(torch_distrib.get_world_size())} + + +def register_optimizers(ctx, model): + optimizers, lr_schedulers, optimizer_frequencies = model.trainer.init_optimizers(model) + model.trainer.optimizers = optimizers + model.trainer.lr_schedulers = lr_schedulers + model.trainer.optimizer_frequencies = optimizer_frequencies + + +def run_optimizer(ctx, model): + trainer = model.trainer + opt_idx = ctx["opt_idx"] + optimizer = trainer.optimizers[opt_idx] + optimizer.step(*ctx["args"], **ctx["kwargs"]) + + +def save_layers_on_all_rank_zero_workers(ctx, model): + gpus_per_model = ctx["gpus_per_model"] + rank = torch_distrib.get_rank() + if rank in range(gpus_per_model): + seq = list(model.children())[0] + torch.save(seq, f"seq_{rank}.pt") + + +def load_sequential_from_saved_layers(gpus_per_model): + partial_seqs = [torch.load(f"seq_{rank}.pt", map_location='cpu') for rank in range(gpus_per_model)] + seq = nn.Sequential() + for p_seq in partial_seqs: + for name, child in p_seq.named_children(): + seq.add_module(name, child) + # delete tmp files + [os.remove(f"seq_{rank}.pt") for rank in range(gpus_per_model)] + return seq diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py new file mode 100644 index 00000000000000..7536ef9b1d856d --- /dev/null +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.optimizer import is_lightning_optimizer +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only + +if _FAIRSCALE_AVAILABLE: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + from fairscale.optim import OSS + + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded + + +class DDPShardedPlugin(DDPPlugin): + + def configure_ddp(self): + self._wrap_optimizers() + self._model = ShardedDataParallel( + LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers + ) + + def _reinit_optimizers_with_oss(self): + optimizers = self.lightning_module.trainer.optimizers + for x, optimizer in enumerate(optimizers): + if is_lightning_optimizer(optimizer): + optimizer = optimizer._optimizer + if not isinstance(optimizer, OSS): + optim_class = type(optimizer) + zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) + optimizers[x] = zero_optimizer + del optimizer + trainer = self.lightning_module.trainer + trainer.optimizers = optimizers + trainer.convert_to_lightning_optimizers() + + def _wrap_optimizers(self): + if self.model.trainer.state != TrainerState.FITTING: + return + self._reinit_optimizers_with_oss() + + def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: + if is_lightning_optimizer(optimizer): + optimizer = optimizer._optimizer + optimizer.consolidate_state_dict() + return self._optim_state_dict(optimizer) + + @rank_zero_only + def _optim_state_dict(self, optimizer): + """ + Retrieves state dict only on rank 0, which contains the entire optimizer state after calling + :meth:`consolidate_state_dict`. + """ + return optimizer.state_dict() + + @property + def lightning_module(self) -> LightningModule: + return unwrap_lightning_module_sharded(self._model) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py new file mode 100644 index 00000000000000..7aadf797e160ac --- /dev/null +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -0,0 +1,67 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only + +if _FAIRSCALE_AVAILABLE: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + from fairscale.optim import OSS + + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded + + +class DDPSpawnShardedPlugin(DDPSpawnPlugin): + + def configure_ddp(self): + self._wrap_optimizers() + self._model = ShardedDataParallel( + LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers + ) + + def _reinit_optimizers_with_oss(self): + optimizers = self.lightning_module.trainer.optimizers + for x, optimizer in enumerate(optimizers): + if not isinstance(optimizer, OSS): + optim_class = type(optimizer) + zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) + optimizers[x] = zero_optimizer + del optimizer + trainer = self.lightning_module.trainer + trainer.optimizers = optimizers + + def _wrap_optimizers(self): + if self.model.trainer.state != TrainerState.FITTING: + return + self._reinit_optimizers_with_oss() + + def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: + if isinstance(optimizer, OSS): + optimizer.consolidate_state_dict() + return self._optim_state_dict(optimizer) + + @rank_zero_only + def _optim_state_dict(self, optimizer): + """ + Retrieves state dict only on rank 0, which contains the entire optimizer state after calling + :meth:`consolidate_state_dict`. + """ + return optimizer.state_dict() + + @property + def lightning_module(self) -> LightningModule: + return unwrap_lightning_module_sharded(self._model) diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py new file mode 100644 index 00000000000000..d70779adf3ba1b --- /dev/null +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch + +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin + + +class SingleDevicePlugin(TrainingTypePlugin): + + def __init__(self, device: torch.device): + super().__init__() + self.device: torch.device = device + self.global_rank = 0 + self.local_rank = 0 + self.world_size = 1 + + @property + def on_tpu(self) -> bool: + return False + + @property + def on_gpu(self) -> bool: + return self.device.type == "cuda" and torch.cuda.is_available() + + def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + As this plugin only operates with a single device, the reduction is simply the identity. + + Args: + tensor: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ + return tensor + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return tensor + + @property + def root_device(self) -> torch.device: + return self.device + + def model_to_device(self) -> None: + if self.on_gpu: + torch.cuda.set_device(self.root_device) + + self._model.to(self.root_device) + + def setup(self, model: torch.nn.Module) -> torch.nn.Module: + self.model_to_device() + return self.model + + @property + def is_global_zero(self) -> bool: + return True + + def barrier(self, *args, **kwargs) -> None: + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + return obj diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py new file mode 100644 index 00000000000000..b8d670ff16881c --- /dev/null +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional, Union + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin +from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle +from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities.apply_func import move_data_to_device + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + + +class SingleTPUPlugin(SingleDevicePlugin): + + def __init__(self, device: Union[torch.device, int]): + if isinstance(device, int): + device = xm.xla_device(device) + super().__init__(device) + + self.tpu_local_core_rank = 0 + self.tpu_global_core_rank = 0 + + def on_tpu(self) -> bool: + return True + + def model_to_device(self) -> None: + self.model.to(self.root_device) + + def pre_dispatch(self) -> None: + if isinstance(self.device, int): + self.device = xm.xla_device(self.device) + + self.tpu_local_core_rank = xm.get_local_ordinal() + self.tpu_global_core_rank = xm.get_ordinal() + + def post_dispatch(self) -> None: + model = self.lightning_module + + if on_colab_kaggle(): + rank_zero_warn("cleaning up... please do not interrupt") + self.save_spawn_weights(model) + + def save_spawn_weights(self, model: LightningModule) -> Optional[str]: + """ + Dump a temporary checkpoint after ddp ends to get weights out of the process + """ + path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt") + model.trainer.save_checkpoint(path) + return path + + def on_save(self, checkpoint: dict) -> dict: + """ + Move XLA tensors to CPU before saving + Recommended on XLA Guide: + https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors + """ + return move_data_to_device(checkpoint, torch.device("cpu")) + + @property + def is_distributed(self): + return False diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py new file mode 100644 index 00000000000000..ba074e7cfb2069 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -0,0 +1,307 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os +import re +from typing import Any, Dict, Iterable, List, Optional, Union + +import torch +import torch.multiprocessing as mp + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import seed_everything + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as xla_pl + import torch_xla.distributed.xla_multiprocessing as xmp + from torch_xla.core.xla_model import rendezvous + from torch_xla.distributed.parallel_loader import ParallelLoader +else: + xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5 + + +class TPUSpawnPlugin(DDPSpawnPlugin): + + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + num_nodes: int = 1, + **kwargs: Dict[str, Any] + ) -> None: + super().__init__( + parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs + ) + self.tpu_local_core_rank = 0 + self.start_method = None + + def setup(self, model: torch.nn.Module) -> torch.nn.Module: + self.create_mp_queue() + return self.model + + def create_mp_queue(self): + self.start_method = 'fork' + smp = mp.get_context(self.start_method) + self.mp_queue = smp.SimpleQueue() + + @property + def distributed_sampler_kwargs(self) -> dict: + return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + @property + def is_distributed(self): + return self.world_size != 1 + + def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader: + device = xm.xla_device() + dataloader = xla_pl.ParallelLoader(dataloader, [device]) + dataloader = dataloader.per_device_loader(device) + return dataloader + + def configure_ddp(self) -> None: + pass + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + pass + + def set_world_ranks(self, process_idx: int) -> None: + self.tpu_local_core_rank = xm.get_local_ordinal() + self.tpu_global_core_rank = xm.get_ordinal() + self.global_rank = self.tpu_local_core_rank + self.world_size = self.num_nodes * self.num_processes + + def new_process(self, process_idx: int, trainer, mp_queue) -> None: + self.mp_queue = mp_queue + + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + self.set_world_ranks(process_idx) + + # set warning rank + rank_zero_only.rank = self.global_rank + + if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: + trainer.progress_bar_callback.disable() + + self.model_to_device() + trainer.accelerator.setup_optimizers(trainer) + trainer.precision_plugin.connect(self._model, None, None) + + self.barrier("pre-run-stage") + + results = trainer.run_stage() + + self.__save_end_of_training_weights(self.lightning_module) + self.transfer_distrib_spawn_state_on_fit_end(results) + + self.barrier("end-process") + + def __save_end_of_training_weights(self, model: LightningModule) -> None: + # when training ends on these platforms dump weights to get out of the main process + if on_colab_kaggle(): + rank_zero_warn("cleaning up... please do not interrupt") + self.save_spawn_weights(model) + + def model_to_device(self) -> None: + self._model.to(xm.xla_device()) + + def barrier(self, name: Optional[str] = None) -> None: + rendezvous(name) + + def transfer_distrib_spawn_state_on_fit_end(self, results): + checkpoint_callback = self.lightning_module.trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + if self.mp_queue is not None: + rank_zero_warn("cleaning up ddp environment...") + + # save the last weights + last_path = None + if ( + self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None + and len(best_model_path) > 0 + ): + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.save(self.lightning_module.state_dict(), last_path) + + if self.global_rank == 0: + # todo, pass complete checkpoint as state dictionary + self.mp_queue.put(best_model_path) + self.mp_queue.put(last_path) + self.mp_queue.put(results) + + def save(self, state_dict: Dict, path: str) -> None: + """ + Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``. + The rendez-vous doesn't affect directly saving. + We can ignore the ``RuntimeError`` to reduce friction with TPUs. + """ + try: + xm.save(state_dict, path) + except RuntimeError as e: + if "Failed to meet rendezvous" not in str(e): + raise e + + def broadcast(self, obj: object, src: int = 0) -> object: + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float) + data = xm.all_gather(data_tensor) + buffer = io.BytesIO(data.cpu().byte().numpy()) + obj = torch.load(buffer) + return obj + + def load_spawn_weights(self, original_model: LightningModule) -> LightningModule: + """ + Load the temp weights saved in the process + To recover the trained model from the ddp process we load the saved weights + """ + + loaded_model = original_model + + if self.is_global_zero: + # load weights saved in ddp + path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt") + loaded_model = original_model.__class__.load_from_checkpoint(path) + + # copy loaded weights to old model + original_model.load_state_dict(loaded_model.state_dict()) + + # remove ddp weights + os.remove(path) + + return loaded_model + + def save_spawn_weights(self, model: LightningModule) -> Optional[str]: + """ + Dump a temporary checkpoint after ddp ends to get weights out of the process + """ + if model.trainer.is_global_zero: + path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt") + model.trainer.save_checkpoint(path) + return path + + def reduce_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.device) + decision = self.reduce(decision, "sum") + decision = bool(decision == self.world_size) + return decision + + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, device=self.device) + + _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if _invalid_reduce_op or _invalid_reduce_op_str: + raise MisconfigurationException( + "Currently, TPUSpawn TrainingTypePlugin only support `sum`, `mean`, `avg` reduce operation." + ) + + output = xm.mesh_reduce('reduce', output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output + + def post_dispatch(self) -> None: + # TODO: Check if trainer references can be resolved otherwise + model = self.lightning_module + + # restore main state with best weights + best_path = self.mp_queue.get() + last_path = self.mp_queue.get() + self._results = self.mp_queue.get() + + # transfer back the best path to the trainer + if self.lightning_module.trainer.checkpoint_callback is not None: + self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path + # todo, pass also bets score + + # load last weights + if last_path and model.trainer.state == TrainerState.FITTING: + ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt) + + self._model = model + + # when training completes, load the weights back in main process + self.__load_weights_on_main_process() + + def __load_weights_on_main_process(self) -> None: + model = self.lightning_module + + # load weights if not interrupted + if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING: + self.load_spawn_weights(model) + + self._model = model + + def _close_logger(self, trainer) -> None: + if trainer.logger is not None: + trainer.logger.finalize("success") + + @property + def xmp_spawn_kwargs(self): + return { + "args": (self.lightning_module.trainer, self.mp_queue), + "nprocs": len(self.parallel_devices), + "start_method": self.start_method + } + + def start_training(self, trainer) -> None: + # todo: precision pluging is call in accelerator setup and should be moved + if 'XLA_USE_BF16' in os.environ: + del os.environ["XLA_USE_BF16"] + self._close_logger(trainer) + xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + + def start_evaluating(self, trainer) -> None: + self._close_logger(trainer) + xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + + def start_predicting(self, trainer) -> None: + xmp.spawn(self.new_process, **self.xmp_spawn_kwargs) + + def training_step(self, *args, **kwargs): + return self.lightning_module.training_step(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.lightning_module.validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.lightning_module.test_step(*args, **kwargs) + + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) + + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ + # Todo: TypeError: 'mappingproxy' object does not support item assignment + self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py new file mode 100644 index 00000000000000..1eac88212e0fbf --- /dev/null +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -0,0 +1,244 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union + +import torch +from torch.nn import Module +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins.base_plugin import Plugin +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.cloud_io import atomic_save + +if TYPE_CHECKING: + from pytorch_lightning.trainer.trainer import Trainer + + +class TrainingTypePlugin(Plugin, ABC): + """A Plugin to change the behaviour of the training, validation and test-loop.""" + + def __init__(self) -> None: + self._model = None + self._results = None + self._call_configure_sharded_model_hook = True + + def connect(self, model: 'Module') -> None: + """Called by the accelerator to connect the accelerator and the model with this plugin""" + self.model = model + + def setup_environment(self) -> None: + """ + Setup any processes or distributed connections. + This is called before the LightningModule/DataModule setup hook + which allows the user to access the accelerator environment before setup is complete. + """ + + def setup(self, model: 'Module') -> None: + """Called by the accelerator to finish setup.""" + + @property + @abstractmethod + def on_gpu(self) -> bool: + """Returns whether the current process is done on GPU""" + + @property + @abstractmethod + def root_device(self) -> torch.device: + """Returns the root device""" + + @abstractmethod + def model_to_device(self) -> None: + """Moves the model to the correct device""" + + @property + @abstractmethod + def is_global_zero(self) -> bool: + """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" + + @abstractmethod + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """ + Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ + + @abstractmethod + def barrier(self, name: Optional[str] = None) -> None: + """Forces all possibly joined processes to wait for each other""" + + @abstractmethod + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes""" + + @abstractmethod + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes""" + return decision + + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + """Run before precision plugin executes backward""" + + def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + """Run after precision plugin executes backward""" + + def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: + """Hook to do something after each optimizer step.""" + + @property + def model(self) -> Module: + """Returns the potentially wrapped LightningModule""" + return self._model + + @model.setter + def model(self, new_model: Module) -> None: + self._model = new_model + + @property + def lightning_module(self) -> LightningModule: + """Returns the pure LightningModule without potential wrappers""" + return unwrap_lightning_module(self._model) + + @property + def results(self) -> Any: + """ + The results of the last training/testing run will be cached here. + In distributed training, we make sure to transfer the results to the appropriate master process. + """ + # TODO: improve these docs + return self._results + + @property + def rpc_enabled(self) -> bool: + return False + + def start_training(self, trainer: 'Trainer') -> None: + # double dispatch to initiate the training loop + self._results = trainer.run_stage() + + def start_evaluating(self, trainer: 'Trainer') -> None: + # double dispatch to initiate the test loop + self._results = trainer.run_stage() + + def start_predicting(self, trainer: 'Trainer') -> None: + # double dispatch to initiate the predicting loop + self._results = trainer.run_stage() + + def training_step(self, *args, **kwargs): + return self.lightning_module.training_step(*args, **kwargs) + + def post_training_step(self): + pass + + def validation_step(self, *args, **kwargs): + return self.lightning_module.validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.lightning_module.test_step(*args, **kwargs) + + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) + + def training_step_end(self, output): + return output + + def validation_step_end(self, output): + return output + + def test_step_end(self, output): + return output + + def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]: + return checkpoint + + def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + """Wraps the dataloader if necessary + + Args: + dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ + return dataloader + + def init_optimizers(self, trainer: "Trainer", model: LightningModule): + return trainer.init_optimizers(model) + + def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + optimizer.step(closure=lambda_closure, **kwargs) + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + """ + Override to delay setting optimizers and schedulers till after dispatch. + This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. + However this may break certain precision plugins such as APEX which require optimizers to be set. + Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + """ + return False + + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ + # dump states as a checkpoint dictionary object + if self.is_global_zero: + checkpoint = self.on_save(checkpoint) + try: + # write the checkpoint dictionary on the file + atomic_save(checkpoint, filepath) + except AttributeError as err: + if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + rank_zero_warn( + 'Warning, `hyper_parameters` dropped from checkpoint.' + f' An attribute is not picklable {err}' + ) + atomic_save(checkpoint, filepath) + + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: + """ + Provide hook to create modules in a distributed aware context. This is useful for when we'd like to + shard the model instantly, which is useful for extremely large models which can save memory and + initialization time. + + Returns: Model parallel context. + """ + yield + + @property + def call_configure_sharded_model_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return self._call_configure_sharded_model_hook + + @call_configure_sharded_model_hook.setter + def call_configure_sharded_model_hook(self, mode: bool) -> None: + self._call_configure_sharded_model_hook = mode diff --git a/pytorch_lightning/plugins/training_type/utils.py b/pytorch_lightning/plugins/training_type/utils.py new file mode 100644 index 00000000000000..eddb9077116dcd --- /dev/null +++ b/pytorch_lightning/plugins/training_type/utils.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + + +def on_colab_kaggle() -> bool: + return bool(os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE")) diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index 683baccafa8589..6ac6e16c185290 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -22,12 +22,12 @@ Enable simple profiling ----------------------- -If you only wish to profile the standard actions, you can set `profiler=True` when constructing -your `Trainer` object. +If you only wish to profile the standard actions, you can set `profiler="simple"` +when constructing your `Trainer` object. .. code-block:: python - trainer = Trainer(..., profiler=True) + trainer = Trainer(..., profiler="simple") The profiler's results will be printed at the completion of a training `fit()`. @@ -50,7 +50,7 @@ Advanced Profiling --------------------- +------------------ If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. @@ -59,6 +59,10 @@ .. code-block:: python + trainer = Trainer(..., profiler="advanced") + + or + profiler = AdvancedProfiler() trainer = Trainer(..., profiler=profiler) @@ -98,8 +102,7 @@ from pytorch_lightning.profiler import Profiler, PassThroughProfiler class MyModel(LightningModule): - def __init__(self, hparams, profiler=None): - self.hparams = hparams + def __init__(self, profiler=None): self.profiler = profiler or PassThroughProfiler() def custom_processing_step(self, data): @@ -108,16 +111,97 @@ def custom_processing_step(self, data): return data profiler = Profiler() - model = MyModel(hparams, profiler) + model = MyModel(profiler) trainer = Trainer(profiler=profiler, max_epochs=1) + +PyTorch Profiling +----------------- + +Autograd includes a profiler that lets you inspect the cost of different operators +inside your model - both on the CPU and GPU. + +To read more about the PyTorch Profiler and all its options, +have a look at its `docs `__ + +.. code-block:: python + + trainer = Trainer(..., profiler="pytorch") + + or + + profiler = PyTorchProfiler(...) + trainer = Trainer(..., profiler=profiler) + + +This profiler works with PyTorch ``DistributedDataParallel``. +If ``filename`` is provided, each rank will save their profiled operation to their own file. The profiler +report can be quite long, so you setting a ``filename`` will save the report instead of logging it to the +output in your terminal. If no filename is given, it will be logged only on rank 0. + +The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``. + +This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``, +``validation_step``, ``test_step``, and ``predict_step`` by default. +The output below shows the profiling for the action ``training_step_and_backward``. +The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions. + +.. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501 + +.. code-block:: python + + Profiler Report + + Profile stats for: training_step_and_backward + --------------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg + --------------------- --------------- --------------- --------------- --------------- --------------- + t 62.10% 1.044ms 62.77% 1.055ms 1.055ms + addmm 32.32% 543.135us 32.69% 549.362us 549.362us + mse_loss 1.35% 22.657us 3.58% 60.105us 60.105us + mean 0.22% 3.694us 2.05% 34.523us 34.523us + div_ 0.64% 10.756us 1.90% 32.001us 16.000us + ones_like 0.21% 3.461us 0.81% 13.669us 13.669us + sum_out 0.45% 7.638us 0.74% 12.432us 12.432us + transpose 0.23% 3.786us 0.68% 11.393us 11.393us + as_strided 0.60% 10.060us 0.60% 10.060us 3.353us + to 0.18% 3.059us 0.44% 7.464us 7.464us + empty_like 0.14% 2.387us 0.41% 6.859us 6.859us + empty_strided 0.38% 6.351us 0.38% 6.351us 3.175us + fill_ 0.28% 4.782us 0.33% 5.566us 2.783us + expand 0.20% 3.336us 0.28% 4.743us 4.743us + empty 0.27% 4.456us 0.27% 4.456us 2.228us + copy_ 0.15% 2.526us 0.15% 2.526us 2.526us + broadcast_tensors 0.15% 2.492us 0.15% 2.492us 2.492us + size 0.06% 0.967us 0.06% 0.967us 0.484us + is_complex 0.06% 0.961us 0.06% 0.961us 0.481us + stride 0.03% 0.517us 0.03% 0.517us 0.517us + --------------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 1.681ms + +When running with `PyTorchProfiler(emit_nvtx=True)`. You should run as following:: + + nvprof --profile-from-start off -o trace_name.prof -- + +To visualize the profiled operation, you can either: + +Use:: + + nvvp trace_name.prof + +Or:: + + python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' + """ -from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler +from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler +from pytorch_lightning.profiler.pytorch import PyTorchProfiler __all__ = [ 'BaseProfiler', 'SimpleProfiler', 'AdvancedProfiler', 'PassThroughProfiler', + "PyTorchProfiler", ] diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 6f6aa959ac451f..78327fa0a91d8b 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -1,33 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Profiler to check if there are any bottlenecks in your code.""" import cProfile import io +import logging import os import pstats import time from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Union import numpy as np -from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.cloud_io import get_filesystem +log = logging.getLogger(__name__) -class BaseProfiler(ABC): - """ - If you wish to write a custom profiler, you should inhereit from this class. - """ - def __init__(self, output_streams: list = None): - """ - Params: - stream_out: callable - """ - if output_streams: - if not isinstance(output_streams, (list, tuple)): - output_streams = [output_streams] - else: - output_streams = [] - self.write_streams = output_streams +class AbstractProfiler(ABC): + """Specification of a profiler.""" @abstractmethod def start(self, action_name: str) -> None: @@ -37,6 +43,48 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" + + @abstractmethod + def setup(self, **kwargs: Any) -> None: + """Execute arbitrary pre-profiling set-up steps as defined by subclass.""" + + @abstractmethod + def teardown(self, **kwargs: Any) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + + +class BaseProfiler(AbstractProfiler): + """ + If you wish to write a custom profiler, you should inherit from this class. + """ + + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + output_filename: Optional[str] = None, + ) -> None: + self.dirpath = dirpath + self.filename = filename + if output_filename is not None: + rank_zero_warn( + "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" + " favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5", + DeprecationWarning + ) + filepath = Path(output_filename) + self.dirpath = filepath.parent + self.filename = filepath.stem + + self._output_file: Optional[TextIO] = None + self._write_stream: Optional[Callable] = None + self._local_rank: Optional[int] = None + self._log_dir: Optional[str] = None + self._stage: Optional[str] = None + @contextmanager def profile(self, action_name: str) -> None: """ @@ -68,14 +116,94 @@ def profile_iterable(self, iterable, action_name: str) -> None: self.stop(action_name) break + def _rank_zero_info(self, *args, **kwargs) -> None: + if self._local_rank in (None, 0): + log.info(*args, **kwargs) + + def _prepare_filename(self, extension: str = ".txt") -> str: + filename = "" + if self._stage is not None: + filename += f"{self._stage}-" + filename += str(self.filename) + if self._local_rank is not None: + filename += f"-{self._local_rank}" + filename += extension + return filename + + def _prepare_streams(self) -> None: + if self._write_stream is not None: + return + if self.filename: + filepath = os.path.join(self.dirpath, self._prepare_filename()) + fs = get_filesystem(filepath) + file = fs.open(filepath, "a") + self._output_file = file + self._write_stream = file.write + else: + self._write_stream = self._rank_zero_info + def describe(self) -> None: - """Logs a profile report after the conclusion of the training run.""" - for write in self.write_streams: - write(self.summary()) + """Logs a profile report after the conclusion of run.""" + # there are pickling issues with open file handles in Python 3.6 + # so to avoid them, we open and close the files within this function + # by calling `_prepare_streams` and `teardown` + self._prepare_streams() + summary = self.summary() + if summary: + self._write_stream(summary) + if self._output_file is not None: + self._output_file.flush() + self.teardown(stage=self._stage) + + def _stats_to_str(self, stats: Dict[str, str]) -> str: + stage = f"{self._stage.upper()} " if self._stage is not None else "" + output = [stage + "Profiler Report"] + for action, value in stats.items(): + header = f"Profile stats for: {action}" + if self._local_rank is not None: + header += f" rank: {self._local_rank}" + output.append(header) + output.append(value) + return os.linesep.join(output) + + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None, + ) -> None: + """Execute arbitrary pre-profiling set-up steps.""" + self._stage = stage + self._local_rank = local_rank + self._log_dir = log_dir + self.dirpath = self.dirpath or log_dir + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Execute arbitrary post-profiling tear-down steps. + + Closes the currently open file and stream. + """ + self._write_stream = None + if self._output_file is not None: + self._output_file.close() + self._output_file = None # can't pickle TextIOWrapper + + def __del__(self) -> None: + self.teardown(stage=self._stage) + + def start(self, action_name: str) -> None: + raise NotImplementedError + + def stop(self, action_name: str) -> None: + raise NotImplementedError - @abstractmethod def summary(self) -> str: - """Create profiler summary in text format.""" + raise NotImplementedError + + @property + def local_rank(self) -> int: + return 0 if self._local_rank is None else self._local_rank class PassThroughProfiler(BaseProfiler): @@ -84,9 +212,6 @@ class PassThroughProfiler(BaseProfiler): The Trainer uses this class by default. """ - def __init__(self): - super().__init__(output_streams=None) - def start(self, action_name: str) -> None: pass @@ -103,63 +228,95 @@ class SimpleProfiler(BaseProfiler): the mean duration of each action and the total time spent over the entire training run. """ - def __init__(self, output_filename: str = None): - """ - Params: - output_filename (str): optionally save profile results to file instead of printing - to std out when training is finished. + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + extended: bool = True, + output_filename: Optional[str] = None, + ) -> None: """ - self.current_actions = {} - self.recorded_durations = defaultdict(list) + Args: + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. - self.output_fname = output_filename - self.output_file = open(self.output_fname, 'w') if self.output_fname else None + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) + Raises: + ValueError: + If you attempt to start an action which has already started, or + if you attempt to stop recording an action which was never started. + """ + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.current_actions: Dict[str, float] = {} + self.recorded_durations = defaultdict(list) + self.extended = extended + self.start_time = time.monotonic() def start(self, action_name: str) -> None: if action_name in self.current_actions: - raise ValueError( - f"Attempted to start {action_name} which has already started." - ) + raise ValueError(f"Attempted to start {action_name} which has already started.") self.current_actions[action_name] = time.monotonic() def stop(self, action_name: str) -> None: end_time = time.monotonic() if action_name not in self.current_actions: - raise ValueError( - f"Attempting to stop recording an action ({action_name}) which was never started." - ) + raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") start_time = self.current_actions.pop(action_name) duration = end_time - start_time self.recorded_durations[action_name].append(duration) + def _make_report(self) -> Tuple[list, float]: + total_duration = time.monotonic() - self.start_time + report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()] + report.sort(key=lambda x: x[2], reverse=True) + return report, total_duration + def summary(self) -> str: - output_string = "\n\nProfiler Report\n" + sep = os.linesep + output_string = "" + if self._stage is not None: + output_string += f"{self._stage.upper()} " + output_string += f"Profiler Report{sep}" + + if self.extended: + + if len(self.recorded_durations) > 0: + max_key = np.max([len(k) for k in self.recorded_durations.keys()]) + + def log_row(action, mean, num_calls, total, per): + row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|" + row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" + return row + + output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") + output_string_len = len(output_string) + output_string += f"{sep}{'-' * output_string_len}" + report, total_duration = self._make_report() + output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %") + output_string += f"{sep}{'-' * output_string_len}" + for action, durations, duration_per in report: + output_string += log_row( + action, + f"{np.mean(durations):.5}", + f"{len(durations):}", + f"{np.sum(durations):.5}", + f"{duration_per:.5}", + ) + else: - def log_row(action, mean, total): - return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" + def log_row(action, mean, total): + return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}" - output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"{os.linesep}{'-' * 65}" - for action, durations in self.recorded_durations.items(): - output_string += log_row( - action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}", - ) - output_string += os.linesep - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() + output_string += log_row("Action", "Mean duration (s)", "Total time (s)") + output_string += f"{sep}{'-' * 65}" - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() + for action, durations in self.recorded_durations.items(): + output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}") + output_string += sep + return output_string class AdvancedProfiler(BaseProfiler): @@ -169,24 +326,34 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__(self, output_filename: str = None, line_count_restriction: float = 1.0): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + line_count_restriction: float = 1.0, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. + line_count_restriction: this can be used to limit the number of functions reported for each action. either an integer (to select a count of lines), or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) + + Raises: + ValueError: + If you attempt to stop recording an action which was never started. """ - self.profiled_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.profiled_actions: Dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction - self.output_fname = output_filename - self.output_file = open(self.output_fname, 'w') if self.output_fname else None - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -195,9 +362,7 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: pr = self.profiled_actions.get(action_name) if pr is None: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." - ) + raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") pr.disable() def summary(self) -> str: @@ -207,21 +372,16 @@ def summary(self) -> str: ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) recorded_stats[action_name] = s.getvalue() + return self._stats_to_str(recorded_stats) - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" - - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() + def teardown(self, stage: Optional[str] = None) -> None: + super().teardown(stage=stage) + self.profiled_actions = {} - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() + def __reduce__(self): + # avoids `TypeError: cannot pickle 'cProfile.Profile' object` + return ( + self.__class__, + tuple(), + dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction), + ) diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py new file mode 100644 index 00000000000000..fa2c2917f98a2e --- /dev/null +++ b/pytorch_lightning/profiler/pytorch.py @@ -0,0 +1,512 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Profiler to check if there are any bottlenecks in your code.""" +import inspect +import logging +import os +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union + +import torch +from torch import nn, Tensor +from torch.autograd.profiler import record_function + +from pytorch_lightning.profiler.profilers import BaseProfiler +from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE + +if TYPE_CHECKING: + from torch.autograd.profiler import EventList + from torch.utils.hooks import RemovableHandle + + from pytorch_lightning.core.lightning import LightningModule + +if _KINETO_AVAILABLE: + from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler + +log = logging.getLogger(__name__) + +_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] + + +class RegisterRecordFunction: + """ + While profiling autograd operations, this class will add labels for module names around the forward function. + + The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: + + Example:: + from pytorch_lightning.profilers import PyTorchProfiler + profiler = PyTorchProfiler(record_module_names=False) + Trainer(profiler=profiler) + + It can be used outside of Lightning as follows: + + Example:: + from pytorch_lightning import Trainer, seed_everything + with RegisterRecordFunction(model): + out = model(batch) + """ + + def __init__(self, model: nn.Module) -> None: + self._model = model + self._records: Dict[str, record_function] = {} + self._handles: Dict[str, List['RemovableHandle']] = {} + + def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: + record = record_function(record_name) + record.__enter__() + self._records[record_name] = record + return input + + def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: + self._records[record_name].__exit__(None, None, None) + return output + + def __enter__(self) -> None: + for module_name, module in self._model.named_modules(): + if module_name: + full_name = f"{type(module).__module__}.{type(module).__name__}" + record_name = f"{full_name}: {module_name}" + pre_forward_handle = module.register_forward_pre_hook( + partial(self._start_recording_forward, record_name=record_name) + ) + post_forward_handle = module.register_forward_hook( + partial(self._stop_recording_forward, record_name=record_name) + ) + + self._handles[module_name] = [pre_forward_handle, post_forward_handle] + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + for handles in self._handles.values(): + for h in handles: + h.remove() + self._handles = {} + + +class ScheduleWrapper: + """ + This class is used to override the schedule logic from the profiler and perform + recording for both `training_step`, `validation_step`. + """ + + def __init__(self, schedule: Callable) -> None: + if not _KINETO_AVAILABLE: + raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") + self._schedule = schedule + self.reset() + + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + + def reset(self): + self._num_training_step_and_backward = 0 + self._num_validation_step = 0 + self._num_test_step = 0 + self._num_predict_step = 0 + self._training_step_and_backward_reached_end = False + self._validation_step_reached_end = False + self._test_step_reached_end = False + self._predict_step_reached_end = False + # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. + self._current_action: Optional[str] = None + self._start_action_name: Optional[str] = None + + @property + def num_step(self) -> int: + if self._current_action == "training_step_and_backward": + return self._num_training_step_and_backward + elif self._current_action == "validation_step": + return self._num_validation_step + elif self._current_action == "test_step": + return self._num_test_step + elif self._current_action == "predict_step": + return self._num_predict_step + else: + return 0 + + def _step(self) -> None: + if self._current_action == "training_step_and_backward": + self._num_training_step_and_backward += 1 + elif self._current_action == "validation_step": + if self._start_action_name == "on_fit_start": + if self._num_training_step_and_backward > 0: + self._num_validation_step += 1 + else: + self._num_validation_step += 1 + elif self._current_action == "test_step": + self._num_test_step += 1 + elif self._current_action == "predict_step": + self._num_predict_step += 1 + + @property + def has_finished(self) -> bool: + if self._current_action == "training_step_and_backward": + return self._training_step_and_backward_reached_end + elif self._current_action == "validation_step": + return self._validation_step_reached_end + elif self._current_action == "test_step": + return self._test_step_reached_end + elif self._current_action == "predict_step": + return self._predict_step_reached_end + return False + + def __call__(self, num_step: int) -> 'ProfilerAction': + # ignore the provided input. Keep internal state instead. + if self.has_finished: + return ProfilerAction.NONE + + self._step() + action = self._schedule(self.num_step) + if action == ProfilerAction.RECORD_AND_SAVE: + if self._current_action == "training_step_and_backward": + self._training_step_and_backward_reached_end = True + elif self._current_action == "validation_step": + self._validation_step_reached_end = True + elif self._current_action == "test_step": + self._test_step_reached_end = True + elif self._current_action == "predict_step": + self._predict_step_reached_end = True + return action + + +class PyTorchProfiler(BaseProfiler): + + RECORD_FUNCTIONS = { + "training_step_and_backward", + "training_step", + "backward", + "validation_step", + "test_step", + "predict_step", + } + STEP_FUNCTIONS = { + "training_step_and_backward", + "validation_step", + "test_step", + "predict_step", + } + AVAILABLE_SORT_KEYS = { + "cpu_time", + "cuda_time", + "cpu_time_total", + "cuda_time_total", + "cpu_memory_usage", + "cuda_memory_usage", + "self_cpu_memory_usage", + "self_cuda_memory_usage", + "count", + } + START_RECORD_FUNCTIONS = { + 'on_fit_start', + 'on_validation_start', + 'on_test_start', + 'on_predict_start', + } + + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + group_by_input_shapes: bool = False, + emit_nvtx: bool = False, + export_to_chrome: bool = True, + row_limit: int = 20, + sort_by_key: Optional[str] = None, + record_functions: Set[str] = None, + record_module_names: bool = True, + profiled_functions: Optional[List] = None, + output_filename: Optional[str] = None, + **profiler_kwargs: Any, + ) -> None: + """ + This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of + different operators inside your model - both on the CPU and GPU + + Args: + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. + + group_by_input_shapes: Include operator input shapes and group calls by shape. + + emit_nvtx: Context manager that makes every autograd operation emit an NVTX range + Run:: + + nvprof --profile-from-start off -o trace_name.prof -- + + To visualize, you can either use:: + + nvvp trace_name.prof + torch.autograd.profiler.load_nvprof(path) + + export_to_chrome: Whether to export the sequence of profiled operators for Chrome. + It will generate a ``.json`` file which can be read by Chrome. + + row_limit: Limit the number of rows in a table, ``-1`` is a special value that + removes the limit completely. + + sort_by_key: Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, + ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. + + record_functions: Set of profiled functions which will create a context manager on. + Any other will be pass through. + + record_module_names: Whether to add module names while recording autograd operation. + + profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version + + Raises: + MisconfigurationException: + If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. + If arg ``schedule`` is not a ``Callable``. + If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. + """ + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + + record_functions = self.__deprecation_check(profiled_functions, record_functions) + + self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) + self._emit_nvtx = emit_nvtx + self._export_to_chrome = export_to_chrome + self._row_limit = row_limit + self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" + self._user_record_functions = record_functions + self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS + self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS + self._record_module_names = record_module_names + self._profiler_kwargs = profiler_kwargs + + self.profiler: Optional[_PROFILER] = None + self.function_events: Optional['EventList'] = None + self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector + self._register: Optional[RegisterRecordFunction] = None + self._parent_profiler: Optional[_PROFILER] = None + self._recording_map: Dict[str, record_function] = {} + self._start_action_name: Optional[str] = None + self._schedule: Optional[ScheduleWrapper] = None + + if _KINETO_AVAILABLE: + self._init_kineto(profiler_kwargs) + + if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: + raise MisconfigurationException( + f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " + ) + + def _init_kineto(self, profiler_kwargs: Any) -> None: + has_schedule = "schedule" in profiler_kwargs + self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs + + schedule = profiler_kwargs.get("schedule", None) + if schedule is not None: + if not isinstance(schedule, Callable): + raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") + action = schedule(0) + if not isinstance(action, ProfilerAction): + raise MisconfigurationException( + f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}" + ) + schedule = schedule if has_schedule else self._default_schedule() + self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule + self._profiler_kwargs["schedule"] = self._schedule + + activities = profiler_kwargs.get("activities", None) + self._profiler_kwargs["activities"] = activities or self._default_activities() + self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False) + self._metric = profiler_kwargs.get("metric", "self_cpu_time_total") + with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph + self._profiler_kwargs["with_stack"] = with_stack + + def __deprecation_check( + self, + profiled_functions: Optional[List[str]], + record_functions: Optional[Set[str]], + ) -> Set[str]: + if record_functions is None: + record_functions = set() + + if profiled_functions is not None: + rank_zero_warn( + "`PyTorchProfiler.profiled_functions` has been renamed to" + " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning + ) + if not record_functions: + record_functions |= set(profiled_functions) + else: + raise MisconfigurationException( + "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." + " Please use only the later." + ) + + return record_functions + + @staticmethod + def _default_schedule() -> Optional[callable]: + if _KINETO_AVAILABLE: + # Those schedule defaults allow the profiling overhead to be negligible over training time. + return torch.profiler.schedule(wait=1, warmup=1, active=3) + + def _default_activities(self) -> List['ProfilerActivity']: + activities = [] + if not _KINETO_AVAILABLE: + return activities + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + activities.append(ProfilerActivity.CUDA) + return activities + + def start(self, action_name: str) -> None: + if self.profiler is None and action_name in self._record_functions_start: + + # close profiler if it is already opened. might happen if 2 profilers + # are created and the first one did not call `describe` + try: + torch.autograd._disable_profiler() # noqa + except (AttributeError, RuntimeError): + pass + + if self._schedule is not None: + self._schedule.setup(action_name) + + self._create_profilers() + + profiler = self.profiler.__enter__() + if profiler is not None: + self.profiler = profiler + + if self._parent_profiler is not None: + self._parent_profiler.__enter__() + + if self._register is not None: + self._register.__enter__() + + if ( + self.profiler is not None and action_name in self._record_functions + and action_name not in self._recording_map + ): + recording = record_function(action_name) + recording.__enter__() + self._recording_map[action_name] = recording + + def stop(self, action_name: str) -> None: + if action_name in self._recording_map: + self._recording_map[action_name].__exit__(None, None, None) + del self._recording_map[action_name] + + if not _KINETO_AVAILABLE or self._emit_nvtx: + return + + if self.profiler is not None and action_name in self.STEP_FUNCTIONS: + if self._schedule is not None: + self._schedule.pre_step(action_name) + + def on_trace_ready(profiler): + if self.dirpath is not None: + if self._export_to_chrome: + handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension="")) + handler(profiler) + + if self._export_to_flame_graph: + path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack")) + profiler.export_stacks(path, metric=self._metric) + else: + rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") + + if not self._has_on_trace_ready: + self.profiler.on_trace_ready = on_trace_ready + + if self._schedule is not None: + self.profiler.step_num = self._schedule.num_step + self.profiler.step() + + def summary(self) -> str: + if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: + return "" + + self._delete_profilers() + + if not self.function_events: + return "" + + if self._export_to_chrome and not _KINETO_AVAILABLE: + filename = f"{self.local_rank}_trace.json" + path_to_trace = (filename if self.dirpath is None else os.path.join(self.dirpath, filename)) + self.function_events.export_chrome_trace(path_to_trace) + + data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) + table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) + + recorded_stats = {"records": table} + return self._stats_to_str(recorded_stats) + + def _create_profilers(self) -> None: + if self._emit_nvtx: + self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) + self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) + else: + self._parent_profiler = None + self.profiler = self._create_profiler( + torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile + ) + if self._record_module_names and self._lightning_module is not None: + self._register = RegisterRecordFunction(self._lightning_module) + + def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + init_parameters = inspect.signature(profiler.__init__).parameters + kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} + return profiler(**kwargs) + + def _cache_functions_events(self) -> None: + if self._emit_nvtx: + return + self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events + + def _delete_profilers(self) -> None: + if self.profiler is not None: + self.profiler.__exit__(None, None, None) + self._cache_functions_events() + self.profiler = None + + if self._schedule is not None: + self._schedule.reset() + + if self._parent_profiler is not None: + self._parent_profiler.__exit__(None, None, None) + self._parent_profiler = None + + if self._register is not None: + self._register.__exit__(None, None, None) + self._register = None + + def teardown(self, stage: Optional[str] = None) -> None: + self._delete_profilers() + + for k in self._recording_map: + self.stop(k) + self._recording_map = {} + + super().teardown(stage=stage) diff --git a/pytorch_lightning/pt_overrides/__init__.py b/pytorch_lightning/pt_overrides/__init__.py deleted file mode 100644 index 5e2b3ddf02c31b..00000000000000 --- a/pytorch_lightning/pt_overrides/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -.. warning:: `pt_overrides` package has been renamed to `overrides` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`pt_overrides` package has been renamed to `overrides` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py deleted file mode 100644 index 34a65e3c7ff43b..00000000000000 --- a/pytorch_lightning/pt_overrides/override_data_parallel.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -.. warning:: `override_data_parallel` module has been renamed to `data_parallel` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`override_data_parallel` module has been renamed to `data_parallel` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.overrides.data_parallel import ( # noqa: F402 - get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel) diff --git a/pytorch_lightning/py.typed b/pytorch_lightning/py.typed new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/pytorch_lightning/root_module/__init__.py b/pytorch_lightning/root_module/__init__.py deleted file mode 100644 index 41f741de6d4608..00000000000000 --- a/pytorch_lightning/root_module/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -.. warning:: `root_module` package has been renamed to `core` since v0.6.0. - The deprecated package name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module` package has been renamed to `core` since v0.6.0." - " The deprecated package name will be removed in v0.8.0.", DeprecationWarning) diff --git a/pytorch_lightning/root_module/decorators.py b/pytorch_lightning/root_module/decorators.py deleted file mode 100644 index 7031273b1d08f6..00000000000000 --- a/pytorch_lightning/root_module/decorators.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module.decorators` module has been renamed to `core.decorators` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module.decorators` module has been renamed to `core.decorators` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.decorators import * # noqa: F403 diff --git a/pytorch_lightning/root_module/grads.py b/pytorch_lightning/root_module/grads.py deleted file mode 100644 index 818114118d6000..00000000000000 --- a/pytorch_lightning/root_module/grads.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module.grads` module has been renamed to `core.grads` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module.grads` module has been renamed to `core.grads` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.grads import * # noqa: F403 diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py deleted file mode 100644 index 0214a3912d837d..00000000000000 --- a/pytorch_lightning/root_module/hooks.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module.hooks` module has been renamed to `core.hooks` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module.hooks` module has been renamed to `core.hooks` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.hooks import * # noqa: F403 diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py deleted file mode 100644 index 89d3d281172675..00000000000000 --- a/pytorch_lightning/root_module/memory.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module.memory` module has been renamed to `core.memory` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module.memory` module has been renamed to `core.memory` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.memory import * # noqa: F403 diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py deleted file mode 100644 index 67cf6a6adf1fb1..00000000000000 --- a/pytorch_lightning/root_module/model_saving.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module.model_saving` module has been renamed to `core.saving` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module.model_saving` module has been renamed to `core.saving` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.saving import * # noqa: F403 diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py deleted file mode 100644 index 3f3e9fadc733ee..00000000000000 --- a/pytorch_lightning/root_module/root_module.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -.. warning:: `root_module.root_module` module has been renamed to `core.lightning` since v0.6.0. - The deprecated module name will be removed in v0.8.0. -""" - -from pytorch_lightning.utilities import rank_zero_warn - -rank_zero_warn("`root_module.root_module` module has been renamed to `core.lightning` since v0.6.0." - " The deprecated module name will be removed in v0.8.0.", DeprecationWarning) - -from pytorch_lightning.core.lightning import * # noqa: F403 diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py new file mode 100644 index 00000000000000..3362ccb479895e --- /dev/null +++ b/pytorch_lightning/setup_tools.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from typing import List + +_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) + + +def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: + """Load requirements from a file + + >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ['numpy...', 'torch...', ...] + """ + with open(os.path.join(path_dir, file_name), 'r') as file: + lines = [ln.strip() for ln in file.readlines()] + reqs = [] + for ln in lines: + # filer all comments + if comment_char in ln: + ln = ln[:ln.index(comment_char)].strip() + # skip directly installed dependencies + if ln.startswith('http'): + continue + if ln: # if requirement is not empty + reqs.append(ln) + return reqs + + +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: + """Load readme as decribtion + + >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + '
...' + """ + path_readme = os.path.join(path_dir, "README.md") + text = open(path_readme, encoding="utf-8").read() + + # drop images from readme + text = text.replace('![PT to PL](docs/source/_static/images/general/pl_quick_start_full_compressed.gif)', '') + + # https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png + github_source_url = os.path.join(homepage, "raw", version) + # replace relative repository path to absolute link to the release + # do not replace all "docs" as in the readme we reger some other sources with particular path to docs + text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") + + # readthedocs badge + text = text.replace('badge/?version=stable', f'badge/?version={version}') + text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{version}') + # codecov badge + text = text.replace('/branch/master/graph/badge.svg', f'/release/{version}/graph/badge.svg') + # replace github badges for release ones + text = text.replace('badge.svg?branch=master&event=push', f'badge.svg?tag={version}') + # Azure... + text = text.replace('?branchName=master', f'?branchName=refs%2Ftags%2F{version}') + text = re.sub(r'\?definitionId=\d+&branchName=master', f'?definitionId=2&branchName=refs%2Ftags%2F{version}', text) + + skip_begin = r'' + skip_end = r'' + # todo: wrap content as commented description + text = re.sub(rf"{skip_begin}.+?{skip_end}", '', text, flags=re.IGNORECASE + re.DOTALL) + + # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png + # github_release_url = os.path.join(homepage, "releases", "download", version) + # # download badge and replace url with local file + # text = _parse_for_badge(text, github_release_url) + return text diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 42f92979d6430d..98abd994b531dd 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -1,971 +1,21 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Once you've organized your PyTorch code into a LightningModule, -the Trainer automates everything else. - -.. figure:: /_images/lightning_module/pt_trainer.png - :alt: Convert from PyTorch to Lightning - -This abstraction achieves the following: - - 1. You maintain control over all aspects via PyTorch code without an added abstraction. - - 2. The trainer uses best practices embedded by contributors and users - from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc... - - 3. The trainer allows overriding any key part that you don't want automated. - ------------ - -Basic use ---------- - -This is the basic use of the trainer: - -.. code-block:: python - - from pytorch_lightning import Trainer - - model = MyLightningModule() - - trainer = Trainer() - trainer.fit(model) - - --------- - -Best Practices --------------- -For cluster computing, it's recommended you structure your -main.py file this way - -.. code-block:: python - - from argparse import ArgumentParser - - def main(hparams): - model = LightningModule() - trainer = Trainer(gpus=hparams.gpus) - trainer.fit(model) - - if __name__ == '__main__': - parser = ArgumentParser() - parser.add_argument('--gpus', default=None) - args = parser.parse_args() - - main(args) - -So you can run it like so:distributed_backend - -.. code-block:: bash - - python main.py --gpus 2 - - -.. note:: - If you want to stop a training run early, you can press "Ctrl + C" on your keyboard. - The trainer will catch the `KeyboardInterrupt` and attempt a graceful shutdown, including - running callbacks such as `on_train_end`. The trainer object will also set an attribute - `interrupted` to `True` in such cases. If you have a callback which shuts down compute - resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs. - ------------- - -Testing -------- -Once you're done training, feel free to run the test set! -(Only right before publishing your paper or pushing to production) - -.. code-block:: python - - trainer.test() - ------------- - -Deployment / prediction ------------------------ -You just trained a LightningModule which is also just a torch.nn.Module. -Use it to do whatever! - -.. code-block:: python - - # load model - pretrained_model = LightningModule.load_from_checkpoint(PATH) - pretrained_model.freeze() - - # use it for finetuning - def forward(self, x): - features = pretrained_model(x) - classes = classifier(features) - - # or for prediction - out = pretrained_model(x) - api_write({'response': out} - -------- - -Trainer flags -------------- - -accumulate_grad_batches -^^^^^^^^^^^^^^^^^^^^^^^ -Accumulates grads every k batches or as set up in the dict. - -.. code-block:: python - - # default used by the Trainer (no accumulation) - trainer = Trainer(accumulate_grad_batches=1) - -Example:: - - # accumulate every 4 batches (effective batch size is batch*4) - trainer = Trainer(accumulate_grad_batches=4) - - # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that - trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20}) - -amp_level -^^^^^^^^^ -The optimization level to use (O1, O2, etc...) -for 16-bit GPU precision (using NVIDIA apex under the hood). - -Check `NVIDIA apex docs `_ for level - -Example:: - - # default used by the Trainer - trainer = Trainer(amp_level='O1') - -auto_lr_find -^^^^^^^^^^^^ -Runs a learning rate finder algorithm (see this `paper `_) -before any training, to find optimal initial learning rate. - -.. code-block:: python - - # default used by the Trainer (no learning rate finder) - trainer = Trainer(auto_lr_find=False) - -Example:: - - # run learning rate finder, results override hparams.learning_rate - trainer = Trainer(auto_lr_find=True) - - # run learning rate finder, results override hparams.my_lr_arg - trainer = Trainer(auto_lr_find='my_lr_arg') - -.. note:: - See the `learning rate finder guide `_ - -benchmark -^^^^^^^^^ - -If true enables cudnn.benchmark. -This flag is likely to increase the speed of your system if your -input sizes don't change. However, if it does, then it will likely -make your system slower. - -The speedup comes from allowing the cudnn auto-tuner to find the best -algorithm for the hardware `[see discussion here] -`_. - -Example:: - - # default used by the Trainer - trainer = Trainer(benchmark=False) - -callbacks -^^^^^^^^^ - -Add a list of user defined callbacks. These callbacks DO NOT replace the explicit callbacks -(loggers, EarlyStopping or ModelCheckpoint). - -.. note:: Only user defined callbacks (ie: Not EarlyStopping or ModelCheckpoint) - -.. code-block:: python - - # a list of callbacks - callbacks = [PrintCallback()] - trainer = Trainer(callbacks=callbacks) - -Example:: - - from pytorch_lightning.callbacks import Callback - - class PrintCallback(Callback): - def on_train_start(self): - print("Training is started!") - def on_train_end(self): - print(f"Training is done. The logs are: {self.trainer.logs}") - -check_val_every_n_epoch -^^^^^^^^^^^^^^^^^^^^^^^ - -Check val every n train epochs. - -Example:: - - # default used by the Trainer - trainer = Trainer(check_val_every_n_epoch=1) - - # run val loop every 10 training epochs - trainer = Trainer(check_val_every_n_epoch=10) - -checkpoint_callback -^^^^^^^^^^^^^^^^^^^ -Callback for checkpointing. - -.. code-block:: python - - trainer = Trainer(checkpoint_callback=checkpoint_callback) - -Example:: - - from pytorch_lightning.callbacks import ModelCheckpoint - - # default used by the Trainer - checkpoint_callback = ModelCheckpoint( - filepath=os.getcwd(), - save_top_k=True, - verbose=True, - monitor='val_loss', - mode='min', - prefix='' - ) - -default_root_dir -^^^^^^^^^^^^^^^^^ - -Default path for logs and weights when no logger -or :class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. -On certain clusters you might want to separate where logs and checkpoints -are stored. If you don't then use this method for convenience. - -Example:: - - # default used by the Trainer - trainer = Trainer(default_root_path=os.getcwd()) - -distributed_backend -^^^^^^^^^^^^^^^^^^^ -The distributed backend to use. - -- (```dp```) is DataParallel (split batch among GPUs of same machine) -- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads) -- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs. - Useful for multi-node CPU training or single-node debugging. Note that this will **not** give - a speedup on a single node, since Torch already makes effient use of multiple CPUs on a single - machine.) -- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing - the number of negative samples - -.. code-block:: python - - # default used by the Trainer - trainer = Trainer(distributed_backend=None) - -Example:: - - # dp = DataParallel - trainer = Trainer(gpus=2, distributed_backend='dp') - - # ddp = DistributedDataParallel - trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp') - - # ddp2 = DistributedDataParallel + dp - trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2') - -.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core) - -early_stop_callback -^^^^^^^^^^^^^^^^^^^ - -Callback for early stopping. -early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`) - -- ``True``: A default callback monitoring ``'val_loss'`` is created. - Will raise an error if ``'val_loss'`` is not found. -- ``False``: Early stopping will be disabled. -- ``None``: The default callback monitoring ``'val_loss'`` is created. -- Default: ``None``. - -.. code-block:: python - - trainer = Trainer(early_stop_callback=early_stop_callback) - -Example:: - - from pytorch_lightning.callbacks import EarlyStopping - - # default used by the Trainer - early_stop_callback = EarlyStopping( - monitor='val_loss', - patience=3, - strict=False, - verbose=False, - mode='min' - ) - -.. note:: If ``'val_loss'`` is not found will work as if early stopping is disabled. - -fast_dev_run -^^^^^^^^^^^^ - -Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). - -Under the hood the pseudocode looks like this: - -.. code-block:: python - - # loading - __init__() - prepare_data - - # test training step - training_batch = next(train_dataloader) - training_step(training_batch) - - # test val step - val_batch = next(val_dataloader) - out = validation_step(val_batch) - validation_epoch_end([out]) - -Example:: - - # default used by the Trainer - trainer = Trainer(fast_dev_run=False) - - # runs 1 train, val, test batch and program ends - trainer = Trainer(fast_dev_run=True) - -gpus -^^^^ - -- Number of GPUs to train on -- or Which GPUs to train on -- can handle strings - -Example:: - - # default used by the Trainer (ie: train on CPU) - trainer = Trainer(gpus=None) - - # int: train on 2 gpus - trainer = Trainer(gpus=2) - - # list: train on GPUs 1, 4 (by bus ordering) - trainer = Trainer(gpus=[1, 4]) - trainer = Trainer(gpus='1, 4') # equivalent - - # -1: train on all gpus - trainer = Trainer(gpus=-1) - trainer = Trainer(gpus='-1') # equivalent - - # combine with num_nodes to train on multiple GPUs across nodes - # uses 8 gpus in total - trainer = Trainer(gpus=2, num_nodes=4) - -.. note:: See the `multi-gpu computing guide `_ - -gradient_clip_val -^^^^^^^^^^^^^^^^^ -Gradient clipping value - -- 0 means don't clip. - -Example:: - - # default used by the Trainer - trainer = Trainer(gradient_clip_val=0.0) - - -gradient_clip: - -.. warning:: .. deprecated:: 0.5.0 - - Use `gradient_clip_val` instead. Will remove 0.8.0. - -log_gpu_memory -^^^^^^^^^^^^^^ -Options: - -- None -- 'min_max' -- 'all' - -Example:: - - # default used by the Trainer - trainer = Trainer(log_gpu_memory=None) - - # log all the GPUs (on master node only) - trainer = Trainer(log_gpu_memory='all') - - # log only the min and max memory on the master node - trainer = Trainer(log_gpu_memory='min_max') - -.. note:: Might slow performance because it uses the output of nvidia-smi. - -log_save_interval -^^^^^^^^^^^^^^^^^ - -Writes logs to disk this often. - -Example:: - - # default used by the Trainer - trainer = Trainer(log_save_interval=100) - -logger -^^^^^^ - -`Logger `_ (or iterable collection of loggers) for experiment tracking. - -.. code-block:: python - - Trainer(logger=logger) - -Example:: - - from pytorch_lightning.loggers import TensorBoardLogger - - # default logger used by trainer - logger = TensorBoardLogger( - save_dir=os.getcwd(), - version=self.slurm_job_id, - name='lightning_logs' - ) - -max_epochs -^^^^^^^^^^ -Stop training once this number of epochs is reached - -Example:: - - # default used by the Trainer - trainer = Trainer(max_epochs=1000) - -max_nb_epochs: - -.. warning:: .. deprecated:: 0.5.0 - - Use `max_epochs` instead. Will remove 0.8.0. - -min_epochs -^^^^^^^^^^ -Force training for at least these many epochs - -Example:: - - # default used by the Trainer - trainer = Trainer(min_epochs=1) - -min_nb_epochs: - -.. warning:: deprecated:: 0.5.0 - Use `min_epochs` instead. Will remove 0.8.0. - -max_steps -^^^^^^^^^ -Stop training after this number of steps -Training will stop if max_steps or max_epochs have reached (earliest). - -.. code-block:: python - - # Default (disabled) - trainer = Trainer(max_steps=None) - -Example:: - - # Stop after 100 steps - trainer = Trainer(max_steps=100) - -min_steps -^^^^^^^^^ - -Force training for at least these number of steps. -Trainer will train model for at least min_steps or min_epochs (latest). - -.. code-block:: python - - # Default (disabled) - trainer = Trainer(min_steps=None) - -Example:: - - # Run at least for 100 steps (disable min_epochs) - trainer = Trainer(min_steps=100, min_epochs=0) - -num_nodes -^^^^^^^^^ - -Number of GPU nodes for distributed training. - -Example:: - - # default used by the Trainer - trainer = Trainer(num_nodes=1) - - # to train on 8 nodes - trainer = Trainer(num_nodes=8) - -nb_gpu_nodes: - -.. warning:: .. deprecated:: 0.5.0 - - Use `num_nodes` instead. Will remove 0.8.0. - -num_processes -^^^^^^^^^^^^^ - -Number of processes to train with. Automatically set to the number of GPUs -when using ``distrbuted_backend="ddp"``. Set to a number greater than 1 when -using ``distributed_backend="ddp_cpu"`` to mimic distributed training on a -machine without GPUs. This is useful for debugging, but **will not** provide -any speedup, since single-process Torch already makes effient use of multiple -CPUs. - -Example:: - - # Simulate DDP for debugging on your GPU-less laptop - trainer = Trainer(distributed_backend="ddp_cpu", num_processes=2) - -num_sanity_val_steps -^^^^^^^^^^^^^^^^^^^^ - -Sanity check runs n batches of val before starting the training routine. -This catches any bugs in your validation without having to wait for the first validation check. -The Trainer uses 5 steps by default. Turn it off or modify it here. - -Example:: - - # default used by the Trainer - trainer = Trainer(num_sanity_val_steps=5) - - # turn it off - trainer = Trainer(num_sanity_val_steps=0) - -nb_sanity_val_steps: - -.. warning:: .. deprecated:: 0.5.0 - - Use `num_sanity_val_steps` instead. Will remove 0.8.0. - -num_tpu_cores -^^^^^^^^^^^^^ -How many TPU cores to train on (1 or 8). - -A single TPU v2 or v3 has 8 cores. A TPU pod has -up to 2048 cores. A slice of a POD means you get as many cores -as you request. - -Your effective batch size is batch_size * total tpu cores. - -.. note:: No need to add a DistributedDataSampler, Lightning automatically does it for you. - -This parameter can be either 1 or 8. - -Example:: - - # your_trainer_file.py - - # default used by the Trainer (ie: train on CPU) - trainer = Trainer(num_tpu_cores=None) - - # int: train on a single core - trainer = Trainer(num_tpu_cores=1) - - # int: train on all cores few cores - trainer = Trainer(num_tpu_cores=8) - - # for 8+ cores must submit via xla script with - # a max of 8 cores specified. The XLA script - # will duplicate script onto each TPU in the POD - trainer = Trainer(num_tpu_cores=8) - - # -1: train on all available TPUs - trainer = Trainer(num_tpu_cores=-1) - -To train on more than 8 cores (ie: a POD), -submit this script using the xla_dist script. - -Example:: - - python -m torch_xla.distributed.xla_dist - --tpu=$TPU_POD_NAME - --conda-env=torch-xla-nightly - --env=XLA_USE_BF16=1 - -- python your_trainer_file.py - -overfit_pct -^^^^^^^^^^^ -Uses this much data of all datasets (training, validation, test). -Useful for quickly debugging or trying to overfit on purpose. - -Example:: - - # default used by the Trainer - trainer = Trainer(overfit_pct=0.0) - - # use only 1% of the train, test, val datasets - trainer = Trainer(overfit_pct=0.01) - - # equivalent: - trainer = Trainer( - train_percent_check=0.01, - val_percent_check=0.01, - test_percent_check=0.01 - ) - -See Also: - - `train_percent_check`_ - - `val_percent_check`_ - - `test_percent_check`_ - - -precision -^^^^^^^^^ -Full precision (32), half precision (16). -Can be used on CPU, GPU or TPUs. - -If used on TPU will use torch.bfloat16 but tensor printing -will still show torch.float32. - -Example:: - - # default used by the Trainer - trainer = Trainer(precision=32) - - # 16-bit precision - trainer = Trainer(precision=16) - - # one day - trainer = Trainer(precision=8|4|2) - -print_nan_grads -^^^^^^^^^^^^^^^ - -.. warning:: .. deprecated:: 0.7.2. - - Has no effect. When detected, NaN grads will be printed automatically. - Will remove 0.9.0. - - -process_position -^^^^^^^^^^^^^^^^ -Orders the progress bar. Useful when running multiple trainers on the same node. - -Example:: - - # default used by the Trainer - trainer = Trainer(process_position=0) - -Note: - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. - -profiler -^^^^^^^^ -To profile individual steps during training and assist in identifying bottlenecks. - -See the `profiler documentation `_. for more details. - -Example:: - - from pytorch_lightning.profiler import Profiler, AdvancedProfiler - - # default used by the Trainer - trainer = Trainer(profiler=None) - - # to profile standard training events - trainer = Trainer(profiler=True) - - # equivalent to profiler=True - profiler = Profiler() - trainer = Trainer(profiler=profiler) - - # advanced profiler for function-level stats - profiler = AdvancedProfiler() - trainer = Trainer(profiler=profiler) - -progress_bar_refresh_rate -^^^^^^^^^^^^^^^^^^^^^^^^^ -How often to refresh progress bar (in steps). -In notebooks, faster refresh rates (lower number) is known to crash them -because of their screen refresh rates, so raise it to 50 or more. - -Example:: - - # default used by the Trainer - trainer = Trainer(progress_bar_refresh_rate=1) - - # disable progress bar - trainer = Trainer(progress_bar_refresh_rate=0) - -Note: - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. - -reload_dataloaders_every_epoch -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Set to True to reload dataloaders every epoch. - -.. code-block:: python - - # if False (default) - train_loader = model.train_dataloader() - for epoch in epochs: - for batch in train_loader: - ... - - # if True - for epoch in epochs: - train_loader = model.train_dataloader() - for batch in train_loader: - -replace_sampler_ddp -^^^^^^^^^^^^^^^^^^^ -Enables auto adding of distributed sampler. - -Example:: - - # default used by the Trainer - trainer = Trainer(replace_sampler_ddp=True) - -By setting to False, you have to add your own distributed sampler: - -Example:: - - # default used by the Trainer - sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) - dataloader = DataLoader(dataset, batch_size=32, sampler=sampler) - -resume_from_checkpoint -^^^^^^^^^^^^^^^^^^^^^^ -To resume training from a specific checkpoint pass in the path here. - -Example:: - - # default used by the Trainer - trainer = Trainer(resume_from_checkpoint=None) - - # resume from a specific checkpoint - trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') - -row_log_interval -^^^^^^^^^^^^^^^^ - -How often to add logging rows (does not write to disk) - -Example:: - - # default used by the Trainer - trainer = Trainer(row_log_interval=10) - - -add_row_log_interval: - -.. warning:: .. deprecated:: 0.5.0 - - Use `row_log_interval` instead. Will remove 0.8.0. - -use_amp: - -.. warning:: .. deprecated:: 0.7.0 - - Use `precision` instead. Will remove 0.9.0. - -show_progress_bar -^^^^^^^^^^^^^^^^^ - -.. warning:: .. deprecated:: 0.7.2 - - Set `progress_bar_refresh_rate` to 0 instead. Will remove 0.9.0. - -test_percent_check -^^^^^^^^^^^^^^^^^^ - -How much of test dataset to check. - -Example:: - - # default used by the Trainer - trainer = Trainer(test_percent_check=1.0) - - # run through only 25% of the test set each epoch - trainer = Trainer(test_percent_check=0.25) - -val_check_interval -^^^^^^^^^^^^^^^^^^ - -How often within one training epoch to check the validation set. -Can specify as float or int. - -- use (float) to check within a training epoch -- use (int) to check every n steps (batches) - -.. code-block:: python - - # default used by the Trainer - trainer = Trainer(val_check_interval=1.0) - -Example:: - - # check validation set 4 times during a training epoch - trainer = Trainer(val_check_interval=0.25) - - # check validation set every 1000 training batches - # use this when using iterableDataset and your dataset has no length - # (ie: production cases with streaming data) - trainer = Trainer(val_check_interval=1000) - -track_grad_norm -^^^^^^^^^^^^^^^ - -- no tracking (-1) -- Otherwise tracks that norm (2 for 2-norm) - -.. code-block:: python - - # default used by the Trainer - trainer = Trainer(track_grad_norm=-1) - -Example:: - - # track the 2-norm - trainer = Trainer(track_grad_norm=2) - -train_percent_check -^^^^^^^^^^^^^^^^^^^ - -How much of training dataset to check. -Useful when debugging or testing something that happens at the end of an epoch. - -.. code-block::python - - # default used by the Trainer - trainer = Trainer(train_percent_check=1.0) - -Example:: - - # default used by the Trainer - trainer = Trainer(train_percent_check=1.0) - - # run through only 25% of the training set each epoch - trainer = Trainer(train_percent_check=0.25) - -truncated_bptt_steps -^^^^^^^^^^^^^^^^^^^^ - -Truncated back prop breaks performs backprop every k steps of -a much longer sequence. - -If this is enabled, your batches will automatically get truncated -and the trainer will apply Truncated Backprop to it. - -(`Williams et al. "An efficient gradient-based algorithm for on-line training of -recurrent network trajectories." -`_) - -Example:: - - # default used by the Trainer (ie: disabled) - trainer = Trainer(truncated_bptt_steps=None) - - # backprop every 5 steps in a batch - trainer = Trainer(truncated_bptt_steps=5) - -.. note:: Make sure your batches have a sequence dimension. - -Lightning takes care to split your batch along the time-dimension. - -.. code-block:: python - - # we use the second as the time dimension - # (batch, time, ...) - sub_batch = batch[0, 0:t, ...] - -Using this feature requires updating your LightningModule's -:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg -with the hidden - -.. code-block:: python - - # Truncated back-propagation through time - def training_step(self, batch, batch_idx, hiddens): - # hiddens are the hiddens from the previous truncated backprop step - out, hiddens = self.lstm(data, hiddens) - - return { - "loss": ..., - "hiddens": hiddens # remember to detach() this - } - -To modify how the batch is split, -override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: - -.. code-block:: python - - class LitMNIST(pl.LightningModule): - def tbptt_split_batch(self, batch, split_size): - # do your own splitting on the batch - return splits - - -val_percent_check -^^^^^^^^^^^^^^^^^ - -How much of validation dataset to check. -Useful when debugging or testing something that happens at the end of an epoch. - -Example:: - - # default used by the Trainer - trainer = Trainer(val_percent_check=1.0) - - # run through only 25% of the validation set each epoch - trainer = Trainer(val_percent_check=0.25) - -weights_save_path -^^^^^^^^^^^^^^^^^ -Directory of where to save weights if specified. - -.. code-block:: python - - # default used by the Trainer - trainer = Trainer(weights_save_path=os.getcwd()) - -Example:: - - # save to your custom path - trainer = Trainer(weights_save_path='my/path') - - # if checkpoint callback used, then overrides the weights path - # **NOTE: this saves weights to some/path NOT my/path - checkpoint_callback = ModelCheckpoint(filepath='some/path') - trainer = Trainer( - checkpoint_callback=checkpoint_callback, - weights_save_path='my/path' - ) - -weights_summary -^^^^^^^^^^^^^^^ -Prints a summary of the weights when training begins. -Options: 'full', 'top', None. - -Example:: - - # default used by the Trainer (ie: print all weights) - trainer = Trainer(weights_summary='full') - - # print only the top level modules - trainer = Trainer(weights_summary='top') - - # don't print a summary - trainer = Trainer(weights_summary=None) - -Trainer class -------------- """ from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities.seed import seed_everything -__all__ = ['Trainer'] +__all__ = ["Trainer", "seed_everything"] diff --git a/pytorch_lightning/trainer/auto_mix_precision.py b/pytorch_lightning/trainer/auto_mix_precision.py deleted file mode 100644 index 2551b8a22dd0ff..00000000000000 --- a/pytorch_lightning/trainer/auto_mix_precision.py +++ /dev/null @@ -1,56 +0,0 @@ -from abc import ABC -import torch - -from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import rank_zero_warn - -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True - - -class TrainerAMPMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - precision: int - use_native_amp: bool - - def init_amp(self, use_amp): - # TODO: remove in v 0.8.0 - if self.use_native_amp: - rank_zero_warn("`amp_level` has been deprecated since v0.7.4 " - "(native amp does not require it)" - " and this argument will be removed in v0.8.0", DeprecationWarning) - - # Backward compatibility, TODO: remove in v0.9.0 - if use_amp is not None: - rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0" - " and this argument will be removed in v0.9.0", DeprecationWarning) - self.precision = 16 if use_amp else 32 - - assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' - - if use_amp and self.use_native_amp: - log.info('Using 16bit precision.') - return - - # TODO: remove all below for v0.8.0 - if use_amp and not APEX_AVAILABLE: # pragma: no-cover - raise ModuleNotFoundError(""" - You set `use_amp=True` but do not have apex installed. - Install apex first using this guide and rerun with use_amp=True: - https://github.com/NVIDIA/apex#linux - - this run will NOT use 16 bit precision - """) - - if self.use_amp: - log.info('Using 16bit precision.') - - @property - def use_amp(self) -> bool: - return self.precision == 16 diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py deleted file mode 100644 index a760b9760209cb..00000000000000 --- a/pytorch_lightning/trainer/callback_config.py +++ /dev/null @@ -1,128 +0,0 @@ -import os -from abc import ABC, abstractmethod -from typing import Union, List - - -from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class TrainerCallbackConfigMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - callbacks: List[Callback] - default_root_dir: str - logger: Union[LightningLoggerBase, bool] - weights_save_path: str - ckpt_path: str - checkpoint_callback: ModelCheckpoint - progress_bar_refresh_rate: int - process_position: int - - @property - @abstractmethod - def slurm_job_id(self) -> int: - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def save_checkpoint(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def configure_checkpoint_callback(self): - """ - Weight path set in this priority: - Checkpoint_callback's path (if passed in). - User provided weights_saved_path - Otherwise use os.getcwd() - """ - ckpt_path = self.default_root_dir - if self.checkpoint_callback: - # init a default one - if self.logger is not None: - save_dir = (getattr(self.logger, 'save_dir', None) or - getattr(self.logger, '_save_dir', None) or - self.default_root_dir) - - # weights_save_path overrides anything - if self.weights_save_path is not None: - save_dir = self.weights_save_path - - version = self.logger.version if isinstance( - self.logger.version, str) else f'version_{self.logger.version}' - ckpt_path = os.path.join( - save_dir, - self.logger.name, - version, - "checkpoints" - ) - else: - ckpt_path = os.path.join(self.default_root_dir, "checkpoints") - - # when no val step is defined, use 'loss' otherwise 'val_loss' - train_step_only = not self.is_overriden('validation_step') - monitor_key = 'loss' if train_step_only else 'val_loss' - - if self.checkpoint_callback is True: - os.makedirs(ckpt_path, exist_ok=True) - self.checkpoint_callback = ModelCheckpoint( - filepath=ckpt_path, - monitor=monitor_key - ) - # If user specified None in filepath, override with runtime default - elif isinstance(self.checkpoint_callback, ModelCheckpoint) \ - and self.checkpoint_callback.dirpath is None: - self.checkpoint_callback.dirpath = ckpt_path - self.checkpoint_callback.filename = '{epoch}' - os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True) - elif self.checkpoint_callback is False: - self.checkpoint_callback = None - - self.ckpt_path = ckpt_path - - if self.checkpoint_callback: - # set the path for the callbacks - self.checkpoint_callback.save_function = self.save_checkpoint - - # if checkpoint callback used, then override the weights path - self.weights_save_path = self.checkpoint_callback.dirpath - - # if weights_save_path is still none here, set to current working dir - if self.weights_save_path is None: - self.weights_save_path = self.default_root_dir - - def configure_early_stopping(self, early_stop_callback): - if early_stop_callback is True or None: - self.early_stop_callback = EarlyStopping( - monitor='val_loss', - patience=3, - strict=True, - verbose=True, - mode='min' - ) - self.enable_early_stop = True - elif not early_stop_callback: - self.early_stop_callback = None - self.enable_early_stop = False - else: - self.early_stop_callback = early_stop_callback - self.enable_early_stop = True - - def configure_progress_bar(self): - progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)] - if len(progress_bars) > 1: - raise MisconfigurationException( - 'You added multiple progress bar callbacks to the Trainer, but currently only one' - ' progress bar is supported.' - ) - elif len(progress_bars) == 1: - self.progress_bar_callback = progress_bars[0] - elif self.progress_bar_refresh_rate > 0: - self.progress_bar_callback = ProgressBar( - refresh_rate=self.progress_bar_refresh_rate, - process_position=self.process_position, - ) - self.callbacks.append(self.progress_bar_callback) - else: - self.progress_bar_callback = None diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 37f56e69410393..606f6b2e4b52bf 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -1,16 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC -from typing import Callable, List +from copy import deepcopy +from inspect import signature +from typing import Any, Callable, Dict, List, Optional, Type from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class TrainerCallbackHookMixin(ABC): - def __init__(self): - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - self.callbacks: List[Callback] = [] - self.get_model: Callable = ... + # this is just a summary on variables used in this abstract class, + # the proper values/initialisation should be done in child class + callbacks: List[Callback] = [] + lightning_module: LightningModule + + def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_before_accelerator_backend_setup(self, model) + + def configure_sharded_model(self, model: LightningModule) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_configure_sharded_model(self, model) + + def setup(self, model: LightningModule, stage: Optional[str]) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.setup(self, model, stage) + + def teardown(self, stage: Optional[str] = None) -> None: + """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.teardown(self, self.lightning_module, stage) def on_init_start(self): """Called when the trainer initialization begins, model has not yet been set.""" @@ -22,82 +63,226 @@ def on_init_end(self): for callback in self.callbacks: callback.on_init_end(self) + def on_fit_start(self): + """Called when the trainer initialization begins, model has not yet been set.""" + for callback in self.callbacks: + callback.on_fit_start(self, self.lightning_module) + + def on_fit_end(self): + """Called when the trainer initialization begins, model has not yet been set.""" + for callback in self.callbacks: + callback.on_fit_end(self, self.lightning_module) + def on_sanity_check_start(self): """Called when the validation sanity check starts.""" for callback in self.callbacks: - callback.on_sanity_check_start(self, self.get_model()) + callback.on_sanity_check_start(self, self.lightning_module) def on_sanity_check_end(self): """Called when the validation sanity check ends.""" for callback in self.callbacks: - callback.on_sanity_check_end(self, self.get_model()) + callback.on_sanity_check_end(self, self.lightning_module) - def on_epoch_start(self): + def on_train_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_train_epoch_start(self, self.lightning_module) + + def on_train_epoch_end(self, outputs: List[Any]): + """Called when the epoch ends. + + Args: + outputs: List of outputs on each ``train`` epoch + """ + for callback in self.callbacks: + callback.on_train_epoch_end(self, self.lightning_module, outputs) + + def on_validation_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: - callback.on_epoch_start(self, self.get_model()) + callback.on_validation_epoch_start(self, self.lightning_module) + + def on_validation_epoch_end(self, outputs: List[Any]): + """Called when the epoch ends. + + Args: + outputs: List of outputs on each ``validation`` epoch + """ + for callback in self.callbacks: + if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): + callback.on_validation_epoch_end(self, self.lightning_module, outputs) + else: + warning_cache.warn( + "`Callback.on_validation_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_validation_epoch_end(self, self.lightning_module) + + def on_test_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_test_epoch_start(self, self.lightning_module) + + def on_test_epoch_end(self, outputs: List[Any]): + """Called when the epoch ends. + + Args: + outputs: List of outputs on each ``test`` epoch + """ + for callback in self.callbacks: + if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): + callback.on_test_epoch_end(self, self.lightning_module, outputs) + else: + warning_cache.warn( + "`Callback.on_test_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_test_epoch_end(self, self.lightning_module) + + def on_epoch_start(self): + """Called when either of train/val/test epoch begins.""" + for callback in self.callbacks: + callback.on_epoch_start(self, self.lightning_module) def on_epoch_end(self): - """Called when the epoch ends.""" + """Called when either of train/val/test epoch ends.""" for callback in self.callbacks: - callback.on_epoch_end(self, self.get_model()) + callback.on_epoch_end(self, self.lightning_module) def on_train_start(self): """Called when the train begins.""" for callback in self.callbacks: - callback.on_train_start(self, self.get_model()) + callback.on_train_start(self, self.lightning_module) def on_train_end(self): """Called when the train ends.""" for callback in self.callbacks: - callback.on_train_end(self, self.get_model()) + callback.on_train_end(self, self.lightning_module) + + def on_pretrain_routine_start(self) -> None: + """Called when the pre-train routine begins.""" + for callback in self.callbacks: + callback.on_pretrain_routine_start(self, self.lightning_module) + + def on_pretrain_routine_end(self) -> None: + """Called when the pre-train routine ends.""" + for callback in self.callbacks: + callback.on_pretrain_routine_end(self, self.lightning_module) def on_batch_start(self): """Called when the training batch begins.""" for callback in self.callbacks: - callback.on_batch_start(self, self.get_model()) + callback.on_batch_start(self, self.lightning_module) def on_batch_end(self): """Called when the training batch ends.""" for callback in self.callbacks: - callback.on_batch_end(self, self.get_model()) + callback.on_batch_end(self, self.lightning_module) + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + """Called when the training batch begins.""" + for callback in self.callbacks: + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) - def on_validation_batch_start(self): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + """Called when the training batch ends.""" + for callback in self.callbacks: + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the validation batch begins.""" for callback in self.callbacks: - callback.on_validation_batch_start(self, self.get_model()) + callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) - def on_validation_batch_end(self): + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the validation batch ends.""" for callback in self.callbacks: - callback.on_validation_batch_end(self, self.get_model()) + callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) - def on_test_batch_start(self): + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the test batch begins.""" for callback in self.callbacks: - callback.on_test_batch_start(self, self.get_model()) + callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) - def on_test_batch_end(self): + def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Called when the test batch ends.""" for callback in self.callbacks: - callback.on_test_batch_end(self, self.get_model()) + callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) def on_validation_start(self): """Called when the validation loop begins.""" for callback in self.callbacks: - callback.on_validation_start(self, self.get_model()) + callback.on_validation_start(self, self.lightning_module) def on_validation_end(self): """Called when the validation loop ends.""" for callback in self.callbacks: - callback.on_validation_end(self, self.get_model()) + callback.on_validation_end(self, self.lightning_module) def on_test_start(self): """Called when the test begins.""" for callback in self.callbacks: - callback.on_test_start(self, self.get_model()) + callback.on_test_start(self, self.lightning_module) def on_test_end(self): """Called when the test ends.""" for callback in self.callbacks: - callback.on_test_end(self, self.get_model()) + callback.on_test_end(self, self.lightning_module) + + def on_keyboard_interrupt(self): + """Called when the training is interrupted by KeyboardInterrupt.""" + for callback in self.callbacks: + callback.on_keyboard_interrupt(self, self.lightning_module) + + @staticmethod + def __is_old_signature(fn: Callable) -> bool: + parameters = list(signature(fn).parameters) + if len(parameters) == 2 and parameters[1] != "args": + return True + return False + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: + """Called when saving a model checkpoint.""" + callback_states = {} + for callback in self.callbacks: + if self.__is_old_signature(callback.on_save_checkpoint): + rank_zero_deprecation( + "`Callback.on_save_checkpoint` signature has changed in v1.3." + " A `checkpoint` parameter has been added." + " Support for the old signature will be removed in v1.5" + ) + state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled + else: + state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) + if state: + callback_states[type(callback)] = state + return callback_states + + def on_load_checkpoint(self, checkpoint): + """Called when loading a model checkpoint.""" + callback_states = checkpoint.get('callbacks') + # Todo: the `callback_states` are dropped with TPUSpawn as they + # can't be saved using `xm.save` + # https://github.com/pytorch/xla/issues/2773 + if callback_states is not None: + for callback in self.callbacks: + state = callback_states.get(type(callback)) + if state: + state = deepcopy(state) + callback.on_load_checkpoint(state) + + def on_after_backward(self): + """ + Called after loss.backward() and before optimizers do anything. + """ + for callback in self.callbacks: + callback.on_after_backward(self, self.lightning_module) + + def on_before_zero_grad(self, optimizer): + """ + Called after optimizer.step() and before optimizer.zero_grad(). + """ + for callback in self.callbacks: + callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py new file mode 100644 index 00000000000000..a7ba2b1c401232 --- /dev/null +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -0,0 +1,108 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden + + +class ConfigValidator(object): + + def __init__(self, trainer): + self.trainer = trainer + + def verify_loop_configurations(self, model: LightningModule) -> None: + r""" + Checks that the model is configured correctly before the run is started. + + Args: + model: The model to check the configuration. + + """ + if self.trainer.state == TrainerState.FITTING: + self.__verify_train_loop_configuration(model) + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TUNING: + self.__verify_train_loop_configuration(model) + elif self.trainer.state == TrainerState.VALIDATING: + self.__verify_eval_loop_configuration(model, 'val') + elif self.trainer.state == TrainerState.TESTING: + self.__verify_eval_loop_configuration(model, 'test') + elif self.trainer.state == TrainerState.PREDICTING: + self.__verify_predict_loop_configuration(model) + + def __verify_train_loop_configuration(self, model): + # ----------------------------------- + # verify model has a training step + # ----------------------------------- + has_training_step = is_overridden('training_step', model) + if not has_training_step: + raise MisconfigurationException( + 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' + ) + + # ----------------------------------- + # verify model has a train dataloader + # ----------------------------------- + has_train_dataloader = is_overridden('train_dataloader', model) + if not has_train_dataloader: + raise MisconfigurationException( + 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' + ) + + # ----------------------------------- + # verify model has optimizer + # ----------------------------------- + has_optimizers = is_overridden('configure_optimizers', model) + if not has_optimizers: + raise MisconfigurationException( + 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' + ) + + trainer = self.trainer + + trainer.overriden_optimizer_step = is_overridden('optimizer_step', model) + trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model) + automatic_optimization = trainer.train_loop.automatic_optimization + going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() + + has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad + if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization: + raise MisconfigurationException( + 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,' + ' `accumulate_grad_batches` in `Trainer` should be 1.' + ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' + ) + + def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None: + loader_name = f'{stage}_dataloader' + step_name = 'validation_step' if stage == 'val' else 'test_step' + + has_loader = is_overridden(loader_name, model) + has_step = is_overridden(step_name, model) + + if has_loader and not has_step: + rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') + if has_step and not has_loader: + rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') + + def __verify_predict_loop_configuration(self, model: LightningModule) -> None: + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/pytorch_lightning/trainer/connectors/__init__.py b/pytorch_lightning/trainer/connectors/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py new file mode 100644 index 00000000000000..30d2b48975a84c --- /dev/null +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -0,0 +1,644 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import List, Optional, Sequence, Union + +import torch + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.cpu import CPUAccelerator +from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.tpu import TPUAccelerator +from pytorch_lightning.plugins import ( + ApexMixedPrecisionPlugin, + DataParallelPlugin, + DDP2Plugin, + DDPPlugin, + DDPShardedPlugin, + DDPSpawnPlugin, + DDPSpawnShardedPlugin, + DeepSpeedPlugin, + DeepSpeedPrecisionPlugin, + DoublePrecisionPlugin, + HorovodPlugin, + NativeMixedPrecisionPlugin, + PrecisionPlugin, + ShardedNativeMixedPrecisionPlugin, + SingleDevicePlugin, + SingleTPUPlugin, + TPUHalfPrecisionPlugin, + TPUSpawnPlugin, + TrainingTypePlugin, +) +from pytorch_lightning.plugins.environments import ( + ClusterEnvironment, + LightningEnvironment, + SLURMEnvironment, + TorchElasticEnvironment, +) +from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus +from pytorch_lightning.utilities import ( + _APEX_AVAILABLE, + _HOROVOD_AVAILABLE, + _NATIVE_AMP_AVAILABLE, + _TPU_AVAILABLE, + AMPType, + device_parser, + DeviceType, + DistributedType, + rank_zero_only, +) +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd + +log = logging.getLogger(__name__) + + +class AcceleratorConnector(object): + + def __init__( + self, + num_processes, + tpu_cores, + distributed_backend, + auto_select_gpus, + gpus, + num_nodes, + sync_batchnorm, + benchmark, + replace_sampler_ddp, + deterministic, + precision, + amp_type, + amp_level, + plugins, + ): + # initialization + self._device_type = DeviceType.CPU + self._distrib_type = None + + self.num_processes = num_processes + self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) + self.distributed_backend = distributed_backend + self.auto_select_gpus = auto_select_gpus + self.gpus = gpus + self.num_nodes = num_nodes + self.sync_batchnorm = sync_batchnorm + self.benchmark = benchmark + self.replace_sampler_ddp = replace_sampler_ddp + self.deterministic = deterministic + self.precision = precision + self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None + self.amp_level = amp_level + self.is_slurm_managing_tasks = False + + self._precision_plugin: Optional[PrecisionPlugin] = None + self._training_type_plugin: Optional[TrainingTypePlugin] = None + self._cluster_environment: Optional[ClusterEnvironment] = None + + # init the default rank if exists + # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks + # this way we only show it on rank 0 + if "LOCAL_RANK" in os.environ: + rank_zero_only.rank = int(os.environ["LOCAL_RANK"]) + + # for gpus allow int, string and gpu list + if auto_select_gpus and isinstance(gpus, int): + self.gpus = pick_multiple_gpus(gpus) + + self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus) + + self.set_distributed_mode() + self.configure_slurm_ddp() + + self.handle_given_plugins(plugins) + + self.accelerator = self.select_accelerator() + + # override dist backend when using tpus + if self.on_tpu: + self.distributed_backend = "tpu" + + # init flags for SLURM+DDP to work + self.world_size = 1 + self.interactive_ddp_procs = [] + self.global_rank = 0 + + # benchmarking + # TODO: should this be moved to GPU accelerator? + torch.backends.cudnn.benchmark = self.benchmark + + # determinism for cudnn + # TODO: should this be moved to GPU accelerator? + torch.backends.cudnn.deterministic = deterministic + if deterministic: + # fixing non-deterministic part of horovod + # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 + os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) + + self.replace_sampler_ddp = replace_sampler_ddp + + def handle_given_plugins( + self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]] + ): + plugins = plugins if plugins is not None else [] + + if isinstance(plugins, str): + plugins = [plugins] + + if not isinstance(plugins, Sequence): + plugins = [plugins] + + training_type = None + precision = None + cluster_environment = None + + for plug in plugins: + if isinstance(plug, str): + # Reset the distributed type as the user has overridden training type + # via the plugins argument + self._distrib_type = None + self.set_distributed_mode(plug) + + elif isinstance(plug, TrainingTypePlugin): + if training_type is None: + training_type = plug + + else: + raise MisconfigurationException( + 'You can only specify one precision and one training type plugin.' + f' Found more than 1 training type plugin: {type(plug).__name__}' + ) + elif isinstance(plug, PrecisionPlugin): + if precision is None: + precision = plug + else: + raise MisconfigurationException( + 'You can only specify one precision and one training type plugin.' + f' Found more than 1 precision plugin: {type(plug).__name__}' + ) + + elif isinstance(plug, ClusterEnvironment): + if cluster_environment is None: + cluster_environment = plug + else: + raise MisconfigurationException( + 'You can only specify one cluster environment. Found more than 1 cluster environment plugin' + ) + else: + raise MisconfigurationException( + f'Found invalid type for plugin {plug}. Expected a precision or training type plugin.' + ) + + self._training_type_plugin = training_type + self._precision_plugin = precision + self._cluster_environment = cluster_environment or self.select_cluster_environment() + + @property + def precision_plugin(self) -> PrecisionPlugin: + if self._precision_plugin is None: + self._precision_plugin = self.select_precision_plugin() + return self._precision_plugin + + @property + def training_type_plugin(self) -> TrainingTypePlugin: + if self._training_type_plugin is None: + self._training_type_plugin = self.select_training_type_plugin() + else: + self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) + + return self._training_type_plugin + + @property + def cluster_environment(self) -> ClusterEnvironment: + return self._cluster_environment + + @property + def on_cpu(self) -> bool: + return self._device_type == DeviceType.CPU + + @property + def on_tpu(self) -> bool: + return self.tpu_cores is not None + + @property + def tpu_id(self) -> Optional[int]: + if self.on_tpu and isinstance(self.tpu_cores, list): + return self.tpu_cores[0] + + return None + + @property + def on_gpu(self) -> bool: + gpus = self.parallel_device_ids + return gpus is not None and len(gpus) > 0 and torch.cuda.is_available() + + @property + def use_dp(self) -> bool: + return self._distrib_type == DistributedType.DP + + @property + def use_ddp(self) -> bool: + return self._distrib_type in ( + DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, + DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED + ) + + @property + def use_ddp2(self) -> bool: + return self._distrib_type == DistributedType.DDP2 + + @property + def use_horovod(self) -> bool: + return self._distrib_type == DistributedType.HOROVOD + + @property + def use_deepspeed(self) -> bool: + return self._distrib_type == DistributedType.DEEPSPEED + + @property + def is_distributed(self) -> bool: + # Used for custom plugins. + # Custom plugins should implement is_distributed property. + if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: + return self.training_type_plugin.is_distributed + is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod + if self.on_tpu: + is_distributed |= self.training_type_plugin.is_distributed + return is_distributed + + @property + def num_gpus(self) -> int: + gpus = self.parallel_device_ids + if gpus is None: + return 0 + return len(gpus) + + @property + def parallel_devices(self) -> List[Union[torch.device, int]]: + if self.on_gpu: + devices = [torch.device("cuda", i) for i in self.parallel_device_ids] + elif self.on_tpu: + # explicitly don't make a tpu device here! + # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169 + devices = [i for i in self.parallel_device_ids] + else: + devices = [torch.device("cpu")] * self.num_processes + return devices + + @property + def root_gpu(self) -> Optional[int]: + return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None + + @property + def is_using_torchelastic(self) -> bool: + te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ) + return te_flags_passed + + def select_precision_plugin(self) -> PrecisionPlugin: + # set precision type + self.amp_type = AMPType.from_str(self.amp_type) + + if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): + return DeepSpeedPrecisionPlugin(self.precision) + + if self.precision == 32: + return PrecisionPlugin() + elif self.precision == 64: + return DoublePrecisionPlugin() + elif self.precision == 16: + if self.on_tpu: + return TPUHalfPrecisionPlugin() + + if self.amp_type == AMPType.NATIVE: + if self.on_cpu: + raise MisconfigurationException( + "You have asked for native AMP on CPU, but AMP is only available on GPU." + ) + elif not _NATIVE_AMP_AVAILABLE: + msg = "You have asked for native AMP but your PyTorch version does not support it." \ + " Consider upgrading with `pip install torch>=1.6`." + if _APEX_AVAILABLE: + self.amp_type = AMPType.APEX + msg += " We will attempt to use NVIDIA Apex for this session." + rank_zero_warn(msg) + else: + raise MisconfigurationException(msg) + else: + log.info("Using native 16bit precision.") + if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + return ShardedNativeMixedPrecisionPlugin() + return NativeMixedPrecisionPlugin() + + if self.amp_type == AMPType.APEX: + if not _APEX_AVAILABLE: + raise MisconfigurationException( + "You have asked for Apex AMP but you have not installed it yet." + " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" + ) + if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + raise MisconfigurationException( + "Sharded Plugin is not supported with Apex AMP," + " please using native AMP for 16-bit precision." + ) + log.info("Using APEX 16bit precision.") + return ApexMixedPrecisionPlugin(self.amp_level) + + raise NotImplementedError("We only support precisions 64, 32 and 16!") + + def select_training_type_plugin(self) -> TrainingTypePlugin: + if self.use_ddp2: + plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) + elif self.use_ddp and self.use_deepspeed: + plugin = DeepSpeedPlugin( + num_nodes=self.num_nodes, + cluster_environment=self.select_cluster_environment(), + parallel_devices=self.parallel_devices + ) + elif self.use_ddp: + use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks + use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic + use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN + use_ddp_cpu_spawn = self.use_ddp and self.on_cpu + use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic + use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks + use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED + use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN + + # TODO: decouple from TE + # ddp script mode uses the same flags as TE + if os.environ.get("PL_IN_DDP_SUBPROCESS", False): + use_torchelastic_ddp = False + + if self.on_tpu: + ddp_plugin_cls = TPUSpawnPlugin + elif use_ddp_sharded: + ddp_plugin_cls = DDPShardedPlugin + elif use_ddp_sharded_spawn: + ddp_plugin_cls = DDPSpawnShardedPlugin + elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp: + ddp_plugin_cls = DDPPlugin + elif use_ddp_spawn or use_ddp_cpu_spawn: + ddp_plugin_cls = DDPSpawnPlugin + else: + ddp_plugin_cls = DDPPlugin + + plugin = ddp_plugin_cls( + parallel_devices=self.parallel_devices, + num_nodes=self.num_nodes, + cluster_environment=self.cluster_environment, + sync_batchnorm=self.sync_batchnorm, + ) + elif self.use_dp: + plugin = DataParallelPlugin(parallel_devices=self.parallel_devices) + elif self.use_horovod: + plugin = HorovodPlugin(parallel_devices=self.parallel_devices) + elif self.on_tpu: + if isinstance(self.tpu_cores, list): + plugin = SingleTPUPlugin(self.tpu_id) + else: + plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores))) + else: + single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids) + plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu")) + return plugin + + def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin: + # necessary for when the user has passed in a plugin + if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'): + training_type.parallel_devices = self.parallel_devices + if hasattr(training_type, 'num_processes'): + training_type.num_processes = len(self.parallel_devices) + + if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None: + training_type.cluster_environment = self.select_cluster_environment() + + if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None: + training_type.num_nodes = self.num_nodes + + # Automatically set sync_batchnorm if None. + # Useful for custom plugins. + if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None: + training_type.sync_batchnorm = self.sync_batchnorm + + return training_type + + def select_accelerator(self) -> Accelerator: + if isinstance(self.distributed_backend, Accelerator): + # custom accelerator from user + if self._precision_plugin is not None or self._training_type_plugin is not None: + # plugins also specified by user + rank_zero_warn( + 'Specified `Precision` and `TrainingType` plugins will be ignored,' + ' since an `Accelerator` instance was provided.' + ) + return self.distributed_backend + + if self.on_gpu: + acc_cls = GPUAccelerator + elif self.on_tpu: + acc_cls = TPUAccelerator + else: + acc_cls = CPUAccelerator + + return acc_cls( + precision_plugin=self.precision_plugin, + training_type_plugin=self.training_type_plugin, + ) + + def select_cluster_environment(self) -> ClusterEnvironment: + if self._cluster_environment is not None: + return self._cluster_environment + if self.is_slurm_managing_tasks: + env = SLURMEnvironment() + elif self.is_using_torchelastic: + env = TorchElasticEnvironment() + else: + env = LightningEnvironment() + return env + + def set_distributed_mode(self, distributed_backend: Optional[str] = None): + + if distributed_backend is not None: + self.distributed_backend = distributed_backend + + if isinstance(self.distributed_backend, Accelerator): + return + + if self.distributed_backend is None: + if self.has_horovodrun(): + self._set_horovod_backend() + elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1): + self._distrib_type = DistributedType.DDP + elif self.num_gpus > 1: + rank_zero_warn( + 'You requested multiple GPUs but did not specify a backend, e.g.' + ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' + ) + self.distributed_backend = "ddp_spawn" + + # special case with DDP on CPUs + if self.distributed_backend == "ddp_cpu": + self._distrib_type = DistributedType.DDP + if self.num_gpus > 0: + rank_zero_warn( + 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' + ) + self.parallel_device_ids = None + if self.num_processes is None: + # define the max CPU available + self.num_processes = os.cpu_count() + # special case with TPUs + elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: + self._device_type = DeviceType.TPU + elif self.distributed_backend and self._distrib_type is None: + self._distrib_type = DistributedType(self.distributed_backend) + + # unless you request explicitly for CPU and some GPU are available use them + _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend + if self.num_gpus > 0 and not _on_cpu: + self._device_type = DeviceType.GPU + + _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + # DP and DDP2 cannot run without GPU + if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu: + rank_zero_warn( + 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' + ) + # todo: in some cases it yield in comarison None and int + if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): + self._distrib_type = DistributedType.DDP + else: + rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.') + self._distrib_type = None + + # finished configuring self._distrib_type, check ipython environment + self.check_interactive_compatibility() + + # for DDP overwrite nb processes by requested GPUs + if ( + self._device_type == DeviceType.GPU + and self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + ): + self.num_processes = self.num_gpus + + if (self._device_type == DeviceType.GPU and self._distrib_type == DistributedType.DDP2): + self.num_processes = self.num_nodes + + # Horovod is an extra case... + if self.distributed_backend == "horovod": + self._set_horovod_backend() + + using_valid_distributed = self.use_ddp or self.use_ddp2 + if self.num_nodes > 1 and not using_valid_distributed: + # throw error to force user to choose a supported distributed type such as ddp or ddp2 + raise MisconfigurationException( + 'Your chosen distributed type does not support num_nodes > 1. ' + 'Please set accelerator=ddp or accelerator=ddp2.' + ) + + rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}') + num_cores = self.tpu_cores if self.tpu_cores is not None else 0 + rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') + + if torch.cuda.is_available() and self._device_type != DeviceType.GPU: + rank_zero_warn( + "GPU available but not used. Set the gpus flag in your trainer" + " `Trainer(gpus=1)` or script `--gpus=1`." + ) + + def _set_horovod_backend(self): + self.check_horovod() + self._distrib_type = DistributedType.HOROVOD + + # Initialize Horovod to get rank / size info + hvd.init() + if self.on_gpu: + # Horovod assigns one local GPU per process + self.parallel_device_ids = list(range(hvd.local_size())) + else: + self.num_processes = hvd.local_size() + + def check_interactive_compatibility(self): + """ + Raises a `MisconfigurationException` if the accelerator and/or plugin + is not compatible with an interactive environment + """ + from pytorch_lightning.utilities import _IS_INTERACTIVE + if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible(): + raise MisconfigurationException( + f"Selected distributed backend {self._distrib_type} is not compatible with an interactive" + " environment. Run your code as a script, or choose one of the compatible backends:" + f" {', '.join(DistributedType.interactive_compatible_types())}" + ) + + def check_horovod(self): + """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod.""" + if not _HOROVOD_AVAILABLE: + raise MisconfigurationException( + 'Requested `distributed_backend="horovod"`, but Horovod is not installed.' + "Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]" + ) + + if self.num_gpus > 1 or self.num_nodes > 1: + raise MisconfigurationException( + "Horovod does not support setting num_nodes / num_gpus explicitly. Use " + "horovodrun / mpirun to configure the number of processes." + ) + + @staticmethod + def has_horovodrun() -> bool: + """Returns True if running with `horovodrun` using Gloo or OpenMPI.""" + return "OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ + + def configure_slurm_ddp(self): + # extract SLURM flag vars + # whenever we have the correct number of tasks, we let slurm manage processes + # otherwise we launch the required number of processes + if self.use_ddp or self.use_ddp2: + num_requested_gpus = self.num_gpus * self.num_nodes + num_slurm_tasks = 0 + try: + num_slurm_tasks = int(os.environ["SLURM_NTASKS"]) + self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus + + # enable slurm cpu + if num_requested_gpus == 0: + self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes + + # in interactive mode we don't manage tasks + job_name = os.environ["SLURM_JOB_NAME"] + if job_name == "bash": + self.is_slurm_managing_tasks = False + + except Exception: + # likely not on slurm, so set the slurm managed flag to false + self.is_slurm_managing_tasks = False + + # used for tests only, set this flag to simulate slurm managing a task + try: + should_fake = int(os.environ["FAKE_SLURM_MANAGING_TASKS"]) + if should_fake: + self.is_slurm_managing_tasks = True + except Exception: + pass + + # notify user the that slurm is managing tasks + if self.is_slurm_managing_tasks: + rank_zero_info("Multi-processing is handled by Slurm.") diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py new file mode 100644 index 00000000000000..8a5289e608c945 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -0,0 +1,165 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import List, Union + +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class CallbackConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init( + self, + callbacks, + checkpoint_callback, + progress_bar_refresh_rate, + process_position, + default_root_dir, + weights_save_path, + resume_from_checkpoint, + stochastic_weight_avg, + ): + self.trainer.resume_from_checkpoint = resume_from_checkpoint + + # init folder paths for checkpoint + weights save callbacks + self.trainer._default_root_dir = default_root_dir or os.getcwd() + self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir + self.trainer._stochastic_weight_avg = stochastic_weight_avg + + # init callbacks + if isinstance(callbacks, Callback): + callbacks = [callbacks] + self.trainer.callbacks = callbacks or [] + + # configure checkpoint callback + # pass through the required args to figure out defaults + self.configure_checkpoint_callbacks(checkpoint_callback) + + # configure swa callback + self._configure_swa_callbacks() + + # init progress bar + self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position) + + # push all checkpoint callbacks to the end + # it is important that these are the last callbacks to run + self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks) + + def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): + if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False: + raise MisconfigurationException( + "Trainer was configured with checkpoint_callback=False but found ModelCheckpoint" + " in callbacks list." + ) + + if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: + self.trainer.callbacks.append(ModelCheckpoint()) + + def _configure_swa_callbacks(self): + if not self.trainer._stochastic_weight_avg: + return + + from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging + existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)] + if not existing_swa: + self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks + + def configure_progress_bar(self, refresh_rate=None, process_position=0): + if os.getenv('COLAB_GPU') and refresh_rate is None: + # smaller refresh rate on colab causes crashes, choose a higher value + refresh_rate = 20 + refresh_rate = 1 if refresh_rate is None else refresh_rate + + progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] + if len(progress_bars) > 1: + raise MisconfigurationException( + 'You added multiple progress bar callbacks to the Trainer, but currently only one' + ' progress bar is supported.' + ) + elif len(progress_bars) == 1: + progress_bar_callback = progress_bars[0] + elif refresh_rate > 0: + progress_bar_callback = ProgressBar( + refresh_rate=refresh_rate, + process_position=process_position, + ) + self.trainer.callbacks.append(progress_bar_callback) + else: + progress_bar_callback = None + + return progress_bar_callback + + def _trainer_has_checkpoint_callbacks(self): + return len(self.trainer.checkpoint_callbacks) > 0 + + def attach_model_logging_functions(self, model): + for callback in self.trainer.callbacks: + callback.log = model.log + callback.log_dict = model.log_dict + + @staticmethod + def _attach_model_callbacks(model: LightningModule, trainer) -> None: + """ + Attaches the callbacks defined in the model. + If a callback returned by the model's configure_callback method has the same type as one or several + callbacks already present in the trainer callbacks list, it will replace them. + In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks + will be pushed to the end of the list, ensuring they run last. + + Args: + model: A model which may or may not define new callbacks in + :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`. + trainer: The trainer on which the callbacks get attached/merged. + """ + model_callbacks = model.configure_callbacks() + if not model_callbacks: + return + model_callback_types = set(type(c) for c in model_callbacks) + trainer_callback_types = set(type(c) for c in trainer.callbacks) + override_types = model_callback_types.intersection(trainer_callback_types) + if override_types: + rank_zero_info( + "The following callbacks returned in `LightningModule.configure_callbacks` will override" + " existing callbacks passed to Trainer:" + f" {', '.join(sorted(t.__name__ for t in override_types))}" + ) + # remove all callbacks with a type that occurs in model callbacks + all_callbacks = [c for c in trainer.callbacks if type(c) not in override_types] + all_callbacks.extend(model_callbacks) + all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks) + # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state + trainer.callbacks = all_callbacks + + @staticmethod + def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: + """ + Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of + checkpoint callbacks is preserved, as well as the order of all other callbacks. + + Args: + callbacks: A list of callbacks. + + Return: + A new list in which the last elements are ModelCheckpoints if there were any present in the + input. + """ + checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)] + not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)] + return not_checkpoints + checkpoints diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py new file mode 100644 index 00000000000000..8b602fa6caa695 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -0,0 +1,397 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from pathlib import Path +from typing import Optional, Union + +import torch + +import pytorch_lightning +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import ( + _APEX_AVAILABLE, + _OMEGACONF_AVAILABLE, + AMPType, + DeviceType, + rank_zero_info, + rank_zero_warn, +) +from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS + +if _APEX_AVAILABLE: + from apex import amp + +if _OMEGACONF_AVAILABLE: + from omegaconf import Container + + +class CheckpointConnector: + + def __init__(self, trainer): + self.trainer = trainer + + # used to validate checkpointing logic + self.has_trained = False + + def restore_weights(self) -> None: + """ + Attempt to restore a checkpoint (e.g. weights) in this priority: + 1. from HPC weights + 2. from `resume_from_checkpoint` file + 3. don't restore + """ + # clear cache before restore + if self.trainer._device_type == DeviceType.GPU: + torch.cuda.empty_cache() + + # 1. Attempt to restore states from HPC checkpoint + dir_path_hpc = str(self.trainer.weights_save_path) + max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") + if max_suffix is not None: + checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' + self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU) + rank_zero_info(f'restored hpc model from: {checkpoint_path}') + + # 2. Attempt to restore states from `resume_from_checkpoint` file + elif self.trainer.resume_from_checkpoint is not None: + self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU) + + # wait for all to catch up + self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights') + + # clear cache after restore + if self.trainer._device_type == DeviceType.GPU: + torch.cuda.empty_cache() + + def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: + """ + Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. + All restored states are listed in return value description of `dump_checkpoint`. + """ + # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint. + fs = get_filesystem(checkpoint_path) + if not fs.exists(checkpoint_path): + rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch") + return False + + # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # acquire the model + model = self.trainer.lightning_module + + # restore model and datamodule state + self.restore_model_state(model, checkpoint) + + if on_gpu: + model.cuda(self.trainer.root_gpu) + + # restore training state + self.restore_training_state(checkpoint) + + rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}") + return True + + def restore_model_state(self, model: LightningModule, checkpoint) -> None: + """ + Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object + """ + + # restore datamodule states + if self.trainer.datamodule is not None: + self.trainer.datamodule.on_load_checkpoint(checkpoint) + + # hook: give user access to checkpoint if needed. + model.on_load_checkpoint(checkpoint) + + # restore model state_dict + model.load_state_dict(checkpoint['state_dict']) + + def restore_training_state(self, checkpoint): + """ + Restore trainer state. + Model will get its change to update + :param checkpoint: + :return: + """ + # validation + if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint: + raise KeyError( + 'Trying to restore training state but checkpoint contains only the model.' + ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.' + ) + + if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]): + raise ValueError( + "The checkpoint you're attempting to load follows an" + " outdated schema. You can upgrade to the current schema by running" + " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" + " where `model.ckpt` is your checkpoint file." + ) + + # restore amp scaling + if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint: + self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) + elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint: + amp.load_state_dict(checkpoint['amp_scaling_state']) + + # restore callback states + self.trainer.on_load_checkpoint(checkpoint) + + self.trainer.global_step = checkpoint['global_step'] + self.trainer.current_epoch = checkpoint['epoch'] + + # crash if max_epochs is lower then the current epoch from the checkpoint + if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs: + m = f""" + you restored a checkpoint with current_epoch={self.trainer.current_epoch} + but the Trainer(max_epochs={self.trainer.max_epochs}) + """ + raise MisconfigurationException(m) + + # Division deals with global step stepping once per accumulated batch + # Inequality deals with different global step for odd vs even num_training_batches + n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches + expected_steps = self.trainer.num_training_batches / n_accum + if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1: + rank_zero_warn( + "You're resuming from a checkpoint that ended mid-epoch." + " Training will start from the beginning of the next epoch." + " This can cause unreliable results if further training is done," + " consider using an end of epoch checkpoint." + ) + + # restore the optimizers + optimizer_states = checkpoint['optimizer_states'] + for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): + optimizer.load_state_dict(opt_state) + + # move optimizer to GPU 1 weight at a time + # avoids OOM + if self.trainer.root_gpu is not None: + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(self.trainer.root_gpu) + + # restore the lr schedulers + lr_schedulers = checkpoint['lr_schedulers'] + for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): + scheduler['scheduler'].load_state_dict(lrs_state) + + # ---------------------------------- + # PRIVATE OPS + # ---------------------------------- + def hpc_save(self, folderpath: str, logger): + # make sure the checkpoint folder exists + folderpath = str(folderpath) # because the tests pass a path object + fs = get_filesystem(folderpath) + fs.makedirs(folderpath, exist_ok=True) + + # save logger to make sure we get all the metrics + logger.save() + + max_suffix = self.max_ckpt_in_folder(folderpath) + ckpt_number = (max_suffix if max_suffix is not None else 0) + 1 + + fs.makedirs(folderpath, exist_ok=True) + filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') + + # give model a chance to do something on hpc_save + model = self.trainer.lightning_module + checkpoint = self.dump_checkpoint() + + model.on_hpc_save(checkpoint) + + checkpoint = self.trainer.accelerator.on_save(checkpoint) + + # do the actual save + # TODO: fix for anything with multiprocess DP, DDP, DDP2 + try: + atomic_save(checkpoint, filepath) + except AttributeError as err: + if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + rank_zero_warn( + 'warning, `hyper_parameters` dropped from checkpoint.' + f' An attribute is not picklable {err}' + ) + atomic_save(checkpoint, filepath) + + return filepath + + def dump_checkpoint(self, weights_only: bool = False) -> dict: + """Creating a model checkpoint dictionary object from various component states. + + Args: + weights_only: saving model weights only + + Return: + structured dictionary: { + 'epoch': training epoch + 'global_step': training global step + 'pytorch-lightning_version': PyTorch Lightning's version + 'callbacks': "callback specific state"[] # if not weights_only + 'optimizer_states': "PT optim's state_dict"[] # if not weights_only + 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only + 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp + 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp + 'state_dict': Model's state_dict (e.g. network weights) + CHECKPOINT_HYPER_PARAMS_NAME: + CHECKPOINT_HYPER_PARAMS_KEY: + CHECKPOINT_HYPER_PARAMS_TYPE: + something_cool_i_want_to_save: anything you define through model.on_save_checkpoint + LightningDataModule.__class__.__name__: pl DataModule's state + } + """ + + # dump epoch/global_step/pytorch-lightning_version + current_epoch = self.trainer.current_epoch + global_step = self.trainer.global_step + has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step + + global_step += 1 + if not has_reached_max_steps: + current_epoch += 1 + + model = self.trainer.lightning_module + + checkpoint = { + 'epoch': current_epoch, + 'global_step': global_step, + 'pytorch-lightning_version': pytorch_lightning.__version__, + 'state_dict': model.state_dict(), + } + + if not weights_only: + # dump callbacks + checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint) + + optimizer_states = [] + for i, optimizer in enumerate(self.trainer.optimizers): + # Rely on accelerator to dump optimizer state + optimizer_state = self.trainer.accelerator.optimizer_state(optimizer) + optimizer_states.append(optimizer_state) + + checkpoint['optimizer_states'] = optimizer_states + + # dump lr schedulers + lr_schedulers = [] + for scheduler in self.trainer.lr_schedulers: + lr_schedulers.append(scheduler['scheduler'].state_dict()) + checkpoint['lr_schedulers'] = lr_schedulers + + # dump amp scaling + if ( + self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU + and self.trainer.scaler is not None + ): + checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() + elif self.trainer.amp_backend == AMPType.APEX: + checkpoint['amp_scaling_state'] = amp.state_dict() + + # dump hyper-parameters + if model.hparams: + if hasattr(model, '_hparams_name'): + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name + # dump arguments + if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + else: + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + + # give the model a chance to dump a few things + model.on_save_checkpoint(checkpoint) + if self.trainer.datamodule is not None: + self.trainer.datamodule.on_save_checkpoint(checkpoint) + + return checkpoint + + def hpc_load(self, checkpoint_path: str, on_gpu: bool): + """ + Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. + All restored states are listed in return value description of `dump_checkpoint`. + """ + + # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # acquire the model + model = self.trainer.lightning_module + + # restore model and datamodule state + self.restore_model_state(model, checkpoint) + + if self.trainer.root_gpu is not None: + model.cuda(self.trainer.root_gpu) + + # restore training state + self.restore_training_state(checkpoint) + + # call hpc specific hook + model.on_hpc_load(checkpoint) + + def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: + """List up files in `dir_path` with `name_key`, then yield maximum suffix number. + + Args: + dir_path: path of directory which may contain files whose name include `name_key` + name_key: file name prefix + + Returns: + None if no-corresponding-file else maximum suffix number + """ + + # check directory existence + fs = get_filesystem(dir_path) + if not fs.exists(dir_path): + return None + + # check corresponding file existence + files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)] + files = [x for x in files if name_key in x] + if len(files) == 0: + return None + + # extract suffix number + ckpt_vs = [] + for name in files: + name = name.split(name_key)[-1] + name = re.sub('[^0-9]', '', name) + ckpt_vs.append(int(name)) + + return max(ckpt_vs) + + def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str: + """Get path of maximum-epoch checkpoint in the folder.""" + + max_suffix = self.max_ckpt_in_folder(folder_path) + ckpt_number = max_suffix if max_suffix is not None else 0 + return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt' + + def save_checkpoint(self, filepath, weights_only: bool = False) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + filepath: write-target file's path + weights_only: saving model weights only + """ + _checkpoint = self.dump_checkpoint(weights_only) + self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py new file mode 100644 index 00000000000000..5d2f141dc64a83 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -0,0 +1,177 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from torch.utils.data import DataLoader + +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden + + +class DataConnector(object): + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init( + self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool + ) -> None: + self.trainer.datamodule = None + self.trainer.prepare_data_per_node = prepare_data_per_node + + if not isinstance(check_val_every_n_epoch, int): + raise MisconfigurationException( + f"check_val_every_n_epoch should be an integer. Found {check_val_every_n_epoch}" + ) + + self.trainer.check_val_every_n_epoch = check_val_every_n_epoch + self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch + self.trainer._is_data_prepared = False + + def get_profiled_train_dataloader(self, train_dataloader): + profiled_dl = self.trainer.profiler.profile_iterable( + enumerate(self._with_is_last(train_dataloader)), "get_train_batch" + ) + return profiled_dl + + def _with_is_last(self, iterable): + """Pass through values from the given iterable with an added boolean indicating if this is the last item. + See `https://stackoverflow.com/a/1630350 `_""" + it = iter(iterable) + last = next(it) + for val in it: + # yield last and has next + yield last, False + last = val + # yield last, no longer has next + yield last, True + + def prepare_data(self, model): + # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 + # or in the case where each node needs to do its own manipulation in which case just local_rank=0 + if self.can_prepare_data(): + if self.trainer.datamodule is not None: + self.trainer.datamodule.prepare_data() + model.prepare_data() + self.trainer._is_data_prepared = True + + def can_prepare_data(self): + should_call_dm_prepare_data = True + if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule): + should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data + + if self.trainer.prepare_data_per_node: + return self.trainer.local_rank == 0 and should_call_dm_prepare_data + else: + return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data + + def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + + self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) + + # set up the passed in dataloaders (if needed) + self.attach_dataloaders(model, train_dataloader, val_dataloaders) + self.attach_datamodule(model, datamodule) + self._validate_data_hooks(model) + + def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' + ) + + def _validate_data_hooks(self, model): + # Raise Misconfiguration exception since these hooks are not supported in DP mode + # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): + raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') + + def attach_dataloaders( + self, + model, + train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ): + # when dataloader is passed via fit, patch the train_dataloader + # functions to overwrite with these implementations + if train_dataloader is not None: + model.train_dataloader = _PatchDataLoader(train_dataloader) + + if val_dataloaders is not None: + model.val_dataloader = _PatchDataLoader(val_dataloaders) + + if test_dataloaders is not None: + model.test_dataloader = _PatchDataLoader(test_dataloaders) + + if predict_dataloaders is not None: + model.predict_dataloader = _PatchDataLoader(predict_dataloaders) + + def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None: + # We use datamodule if it's been provided, otherwise we check model for it + datamodule = datamodule or getattr(model, 'datamodule', None) + + # If we have a datamodule, attach necessary hooks + dataloaders + if datamodule: + + # Override loader hooks + dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader') + for method in dl_methods: + if is_overridden(method, datamodule): + setattr(model, method, getattr(datamodule, method)) + + # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if is_overridden(hook, datamodule): + setattr(model, hook, getattr(datamodule, hook)) + + self.trainer.datamodule = datamodule + datamodule.trainer = self.trainer + + # experimental feature for Flash + if hasattr(datamodule, "data_pipeline"): + model.data_pipeline = datamodule.data_pipeline + + +class _PatchDataLoader(object): + r""" + Callable object for patching dataloaders passed into trainer.fit(). + Use this class to override model.*_dataloader() and be pickle-compatible. + + Args: + dataloader: Dataloader object to return when called. + + """ + + def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): + self.dataloader = dataloader + + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + self.patch_loader_code = str(self.__call__.__code__) + + def __call__(self) -> Union[List[DataLoader], DataLoader]: + return self.dataloader diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py new file mode 100644 index 00000000000000..28c99f8f4de6d1 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -0,0 +1,97 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class DebuggingConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_init_start( + self, + limit_train_batches, + limit_val_batches, + limit_test_batches, + limit_predict_batches, + val_check_interval, + overfit_batches, + fast_dev_run, + ): + if not isinstance(fast_dev_run, (bool, int)): + raise MisconfigurationException( + f'fast_dev_run={fast_dev_run} is not a valid configuration.' + ' It should be either a bool or an int >= 0' + ) + + if isinstance(fast_dev_run, int) and (fast_dev_run < 0): + raise MisconfigurationException( + f'fast_dev_run={fast_dev_run} is not a' + ' valid configuration. It should be >= 0.' + ) + + self.trainer.fast_dev_run = fast_dev_run + fast_dev_run = int(fast_dev_run) + + # set fast_dev_run=True when it is 1, used while logging + if fast_dev_run == 1: + self.trainer.fast_dev_run = True + + if fast_dev_run: + limit_train_batches = fast_dev_run + limit_val_batches = fast_dev_run + limit_test_batches = fast_dev_run + limit_predict_batches = fast_dev_run + self.trainer.max_steps = fast_dev_run + self.trainer.num_sanity_val_steps = 0 + self.trainer.max_epochs = 1 + val_check_interval = 1.0 + self.trainer.check_val_every_n_epoch = 1 + self.trainer.logger = DummyLogger() + + rank_zero_info( + 'Running in fast_dev_run mode: will run a full train,' + f' val and test loop using {fast_dev_run} batch(es).' + ) + + self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') + self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches') + self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches') + self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, 'limit_predict_batches') + self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval') + self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches') + self.determine_data_use_amount(self.trainer.overfit_batches) + + def determine_data_use_amount(self, overfit_batches: float) -> None: + """Use less data for debugging purposes""" + if overfit_batches > 0: + self.trainer.limit_train_batches = overfit_batches + self.trainer.limit_val_batches = overfit_batches + self.trainer.limit_test_batches = overfit_batches + + +def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: + if 0 <= batches <= 1: + return batches + elif batches > 1 and batches % 1.0 == 0: + return int(batches) + else: + raise MisconfigurationException( + f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.' + ) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py new file mode 100644 index 00000000000000..1f1c41c6eb2f0e --- /dev/null +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps +from typing import Callable + +from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables + + +def _defaults_from_env_vars(fn: Callable) -> Callable: + """ + Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which + input arguments should be moved automatically to the correct device. + """ + + @wraps(fn) + def insert_env_defaults(self, *args, **kwargs): + cls = self.__class__ # get the class + if args: # inace any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] + # convert args to kwargs + kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + env_variables = vars(parse_env_variables(cls)) + # update the kwargs by env variables + kwargs = dict(list(env_variables.items()) + list(kwargs.items())) + + # all args were already moved to kwargs + return fn(self, **kwargs) + + return insert_env_defaults diff --git a/pytorch_lightning/trainer/connectors/logger_connector/__init__.py b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py new file mode 100644 index 00000000000000..f14e20f2325333 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector # noqa: F401 diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py new file mode 100644 index 00000000000000..87b730403b5513 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -0,0 +1,231 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class CallbackHookNameValidator: + + @staticmethod + def check_logging_in_callbacks( + current_hook_fx_name: str = None, on_step: bool = None, on_epoch: bool = None + ) -> None: + if current_hook_fx_name is None: + return + + internal_func = getattr(CallbackHookNameValidator, f"_{current_hook_fx_name}_log", None) + + if internal_func is None: + return + + current_callback_hook_auth_args = internal_func() + + if current_callback_hook_auth_args is not None: + m = "{} function supports only {} in {}. Provided {}" + if on_step not in current_callback_hook_auth_args["on_step"]: + msg = m.format(current_hook_fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step) + raise MisconfigurationException(msg) + + if on_epoch not in current_callback_hook_auth_args["on_epoch"]: + msg = m.format(current_hook_fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch) + raise MisconfigurationException(msg) + else: + raise MisconfigurationException( + f"{current_hook_fx_name} function doesn't support logging using self.log() yet." + ) + + @staticmethod + def _on_before_accelerator_backend_setup_log(): + """Called before accelerator is being setup""" + return None + + @staticmethod + def _setup_log(): + """Called when fit or test begins""" + return None + + @staticmethod + def _on_configure_sharded_model_log(): + """Called before configure sharded model""" + return None + + @staticmethod + def _teardown_log(): + """Called at the end of fit and test""" + return None + + @staticmethod + def _on_init_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_init_end_log(): + """Called when the trainer initialization ends, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_start_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_fit_end_log(): + """Called when the trainer initialization begins, model has not yet been set.""" + return None + + @staticmethod + def _on_sanity_check_start_log(): + """Called when the validation sanity check starts.""" + return None + + @staticmethod + def _on_sanity_check_end_log(): + """Called when the validation sanity check ends.""" + return None + + @staticmethod + def _on_train_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_start_log(): + """Called when the epoch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_epoch_end_log(): + """Called when the epoch ends.""" + return {"on_step": [False], "on_epoch": [False, True]} + + @staticmethod + def _on_train_start_log(): + """Called when the train begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_pretrain_routine_start_log(): + """Called when the train begins.""" + return None + + @staticmethod + def _on_pretrain_routine_end_log(): + """Called when the train ends.""" + return None + + @staticmethod + def _on_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_start_log(): + """Called when the training batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_train_batch_end_log(): + """Called when the training batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_start_log(): + """Called when the validation batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_batch_end_log(): + """Called when the validation batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_start_log(): + """Called when the test batch begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_batch_end_log(): + """Called when the test batch ends.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_start_log(): + """Called when the validation loop begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_validation_end_log(): + """Called when the validation loop ends.""" + return None + + @staticmethod + def _on_test_start_log(): + """Called when the test begins.""" + return {"on_step": [False, True], "on_epoch": [False, True]} + + @staticmethod + def _on_test_end_log(): + """Called when the test ends.""" + return None + + @staticmethod + def _on_keyboard_interrupt_log(): + """Called when the training is interrupted by KeyboardInterrupt.""" + return None + + @staticmethod + def _on_save_checkpoint_log(): + """Called when saving a model checkpoint.""" + return None + + @staticmethod + def _on_load_checkpoint_log(): + """Called when loading a model checkpoint.""" + return None diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py new file mode 100644 index 00000000000000..e2ce66c86ecffd --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -0,0 +1,525 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Tuple +from weakref import proxy + +import torch + +import pytorch_lightning as pl +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import DistributedType, LightningEnum +from pytorch_lightning.utilities.warnings import WarningCache + +log = logging.getLogger(__name__) + + +class MetricWarningCache(WarningCache): + + def __init__(self): + super().__init__() + self.warned_metrics = [] + + +warning_cache = MetricWarningCache() + + +class ResultStoreType(LightningEnum): + INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" + OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop" + + +class HookResultStore: + """ + This class is defined for internal usage. + It holds all metrics logged using the self.log function + in the scope of ModelHooks or Callback functions. + + We need to differentiate 3 different scenarios: + - (1): We are outside of a batch loop + * It means no dataloader_idx, no optimizer idx, etc.. + - (2): We are inside the training batch loop + * We have an optimizer idx and split idx to track + - (3): We are inside the evaluation loop + * We have a dataloader_idx to track + + The data store `Result` objects for those 3 scenarios in `self._internals`. + + (1): self._internals = {dataloader_idx: [Result(), ..., Result()]} + * dataloader_idx not being defined, it is set to 0 b default + (2): self._internals = {dataloader_idx: {optimizer_idx: {batch_idx: [Result(), ..., Result()]}}} + (3): Same as (1) for simplicity + + Those data structures enables us to reduce properly Result object when batch loop is finished. + """ + + def __init__(self, fx_name: str, all_gather_fn: Callable, should_warn: bool) -> None: + self._fx_name = fx_name + self._all_gather_fn = all_gather_fn + self._should_warn = should_warn + self._internals = {} + self._internals_reduced = {} + self._internal_type = None + self.has_reduced = False + self._latest_ref = {} + + @property + def has_several_dataloaders(self) -> bool: + return self.num_dataloaders > 1 + + @property + def num_dataloaders(self) -> int: + inter = self._internals_reduced if self.has_reduced else self._internals + return len(inter) + + def check_dataloader_idx(self, result: Result) -> bool: + random_key = list(result.keys())[-1] + return result["meta"][random_key]["dataloader_idx"] is not None + + def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict: + results = {} + for opt_idx in latest_result_opt: + latest_result = latest_result_opt[opt_idx] + add_dataloader_idx = self.check_dataloader_idx(latest_result) + func = getattr(latest_result, func_name) + results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) + return results + + def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + """ + This function used cache_ref and cache_result to optimize loading metrics + + Context: As we update the logger_connector metrics on every `self.log` call, + and it can be pretty time consuming, especially when logging outside batch loop. + + HookResultStore keeps track of its latest added result object, + and cache its pbar and log metrics if already called on, + """ + return [ + self.get_latest_from_func_name(self._latest_ref[dl_idx], func_name, *args, **kwargs) + for dl_idx in range(self.num_dataloaders) + ] + + def get_batch_pbar_metrics(self, *args, **kwargs): + return self.run_latest_batch_metrics_with_func_name("get_batch_pbar_metrics", *args, **kwargs) + + def get_batch_log_metrics(self, *args, **kwargs): + return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs) + + def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: + if not isinstance(opt_metric, Result): + raise Exception("The provided opt_metric should be a Result Object. Something is wrong") + + func = getattr(opt_metric, func_name) + metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs) + if self._should_warn: + for non_metric_key in opt_metric.get_non_metrics_keys(): + if non_metric_key in metrics_to_log and non_metric_key not in warning_cache.warned_metrics: + metric = self._all_gather_fn(metrics_to_log[non_metric_key]) + if any(metric[0] != m for m in metric[1:]): + warning_cache.warn( + f"The value associated to the key {non_metric_key}: {metric.cpu().tolist()} " + "doesn't appear to be the same accross all processes. " + "HINT: One could either do: `self.log(..., sync_dist=True, sync_fn=torch.mean)`" + " to force mean reduction across processes which can be inaccurate or implement" + " a `torchmetrics.Metric`" + ) + warning_cache.warned_metrics.append(non_metric_key) + + results.append(metrics_to_log) + + def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + results = [] + for dl_idx in range(self.num_dataloaders): + opt_metrics = self._internals_reduced[dl_idx] + if isinstance(opt_metrics, defaultdict): + for opt_metric in opt_metrics.values(): + self.run_epoch_func(results, opt_metric, func_name, *args, **kwargs) + else: + self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) + return results + + def get_epoch_pbar_metrics(self, *_, **__) -> List[Dict]: + return self.get_epoch_from_func_name("get_epoch_pbar_metrics") + + def get_epoch_log_metrics(self, *_, **__) -> List[Dict]: + return self.get_epoch_from_func_name("get_epoch_log_metrics") + + def get_forked_metrics(self, *_, **__) -> List[Dict]: + return self.get_epoch_from_func_name("get_forked_metrics") + + def append(self, result: Result, info: Dict) -> None: + dataloader_idx = info["dataloader_idx"] + self._internal_type = info["type"] + opt_idx = info["opt_idx"] + + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + if dataloader_idx not in self._internals: + self._internals_reduced[dataloader_idx] = defaultdict(dict) + self._latest_ref[dataloader_idx] = {} + self._internals.setdefault(dataloader_idx, {}) + + batch_idx = info["batch_idx"] + self._internals[dataloader_idx].setdefault(opt_idx, {}) + self._internals[dataloader_idx][opt_idx].setdefault(batch_idx, []) + self._internals[dataloader_idx][opt_idx][batch_idx].append(result) + else: + self._internals.setdefault(dataloader_idx, []) + self._internals[dataloader_idx].append(result) + self._latest_ref.setdefault(dataloader_idx, {}) + + self._latest_ref[dataloader_idx].setdefault(opt_idx, {}) + self._latest_ref[dataloader_idx][opt_idx] = result + + def auto_reduce_results_on_epoch_end(self) -> None: + """ + This function is called to reduce `self._internals` Result object. + The reduced Result object will be saved into `self._internals_reduced` + The `self._internals` stored Result objects will be deleted to save memory. + """ + if self.has_reduced: + return + for dl_idx in range(self.num_dataloaders): + epoch_metrics = self._internals[dl_idx] + + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + for opt_idx in list(epoch_metrics): + # TODO: Figure out to reduce memory + # TODO: How to start training in middle of epoch + outputs = epoch_metrics[opt_idx] + # reduce across time first + time_reduced_outputs = [] + for tbptt_outputs in outputs.values(): + tbptt_outputs = type(tbptt_outputs[0]).reduce_across_time(tbptt_outputs) + if len(tbptt_outputs) > 1: + time_reduced_outputs.append(tbptt_outputs) + + if len(time_reduced_outputs) == 0: + continue + + # reduce across training steps + outputs = type(time_reduced_outputs[0]).reduce_on_epoch_end(time_reduced_outputs) + + # with manual opt need 1 + metrics because meta is always there + if outputs.minimize is not None: + outputs.minimize = outputs.minimize.mean() + + self._internals_reduced[dl_idx][opt_idx] = outputs + + # free memory + del self._internals[dl_idx][opt_idx] + else: + reduced_epoch_metrics = epoch_metrics[0] + if len(epoch_metrics) != 1: + reduced_epoch_metrics = type(reduced_epoch_metrics).reduce_on_epoch_end(epoch_metrics) + + self._internals_reduced[dl_idx] = reduced_epoch_metrics + + # free memory + del self._internals[dl_idx] + + self.has_reduced = True + + def __getitem__(self, key: str) -> Any: + return self._internals.get(key, None) + + def __repr__(self): + return self._internals.__repr__() + + +class EpochResultStore: + """ + This class is defined for internal usage. + It holds all metrics logged using the self.log function using `HookResultStore` object. + The internal datastructure is as follow: + self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()} + Pseudo Code Example: + ``` + model._current_fx_name = 'something' + model._results = Result() + model.log('a', ...) + epoch_result_store.cache_result() + ``` + """ + + def __init__(self, trainer: 'pl.Trainer') -> None: + self.trainer = proxy(trainer) + + # Add warning only for distributed (expect rpc as main worker is running the code). + _should_warn = trainer.accelerator_connector.is_distributed + _should_warn &= not trainer.training_type_plugin.rpc_enabled + self._should_warn = _should_warn + + self.reset() + + def __getitem__(self, key: str) -> Any: + return self._internals.get(key, None) + + @property + def info(self): + """ + This function provides necessary parameters to properly configure HookResultStore obj + """ + model_ref = self.trainer.lightning_module + return { + "batch_idx": self.trainer.batch_idx, + "fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name, + "dataloader_idx": model_ref._current_dataloader_idx or 0, + "opt_idx": self._opt_idx or 0, + "split_idx": self._split_idx or 0, + "type": ( + ResultStoreType.INSIDE_BATCH_TRAIN_LOOP if self._opt_idx is not None and self._split_idx is not None + else ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP + ) + } + + def reset_model(self): + """ + This function is used to reset model state at the end of the capture + """ + model_ref = self.trainer.lightning_module + model_ref._results = Result() + model_ref._current_hook_fx_name = None + model_ref._current_fx_name = '' + + def cache_result(self) -> None: + """ + This function is called after every hook + and store the result object + """ + with self.trainer.profiler.profile("cache_result"): + model_ref = self.trainer.lightning_module + + # extract hook results + hook_result = model_ref._results + + if len(hook_result) == 1: + model_ref._current_hook_fx_name = None + model_ref._current_fx_name = '' + return + + info = self.info + fx_name = info["fx_name"] + + all_gather_fn = self.trainer.lightning_module.all_gather + self._internals.setdefault(fx_name, HookResultStore(fx_name, all_gather_fn, self._should_warn)) + + # attach capture batch_size + Result.attach_batch_size(self._batch_size, hook_result) + + hook_result = hook_result.detach() + if self.trainer.move_metrics_to_cpu: + hook_result = hook_result.cpu() + elif self.trainer._distrib_type == DistributedType.DP: + hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu)) + + self._internals[fx_name].append(hook_result, info) + + # update logged_metrics, progress_bar_metrics, callback_metrics + if "epoch_end" in fx_name: + self.update_logger_connector() + + self.reset_model() + + def update_logger_connector(self) -> Tuple[Dict, Dict]: + """ + This function is called every time we capture a hook + It automatically updates the logger_connector followings: + - progress_bar_metrics with pbar_metrics + - logged_metrics with log_metrics + - callback_metrics with progress_bar_metrics + logged_metrics + """ + + logger_connector = self.trainer.logger_connector + + callback_metrics = {} + batch_pbar_metrics = {} + batch_log_metrics = {} + + if not self._has_batch_loop_finished: + # get pbar + batch_pbar_metrics = self.get_latest_batch_pbar_metrics() + logger_connector.add_progress_bar_metrics(batch_pbar_metrics) + batch_log_metrics = self.get_latest_batch_log_metrics() + + if self.trainer.training: + logger_connector._logged_metrics.update(batch_log_metrics) + callback_metrics.update(batch_pbar_metrics) + callback_metrics.update(batch_log_metrics) + else: + # get pbar + epoch_pbar_metrics = self.get_epoch_pbar_metrics() + logger_connector.add_progress_bar_metrics(epoch_pbar_metrics) + + # get logged_metrics + epoch_log_metrics = self.get_epoch_log_metrics() + logger_connector._logged_metrics.update(epoch_log_metrics) + logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch}) + + # get forked_metrics + forked_metrics = self.get_forked_metrics() + + callback_metrics.update(epoch_pbar_metrics) + callback_metrics.update(epoch_log_metrics) + callback_metrics.update(forked_metrics) + + # TODO(carmocca): when we implement flushing the logger connector metrics after + # the trainer.state changes, this should check trainer.evaluating instead + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): + logger_connector.evaluation_callback_metrics.update(callback_metrics) + + # update callback_metrics + logger_connector._callback_metrics.update(callback_metrics) + + batch_pbar_metrics.pop("debug_epoch", None) + return batch_pbar_metrics, batch_log_metrics + + def run_batch_from_func_name(self, func_name) -> Dict: + results = [getattr(hook_result, func_name) for hook_result in self._internals.values()] + results = [func(include_forked_originals=False) for func in results] + return {k: v for d in sum(results, []) for k, v in d.items()} # List[List[dict]] -> dict + + def get_latest_batch_log_metrics(self) -> Dict: + batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics") + batch_log_metrics.update(self.legacy_batch_log_metrics) + return batch_log_metrics + + def get_latest_batch_pbar_metrics(self) -> Dict: + batch_pbar_metrics = self.run_batch_from_func_name("get_batch_pbar_metrics") + batch_pbar_metrics.update(self.legacy_batch_pbar_metrics) + return batch_pbar_metrics + + @property + def has_reduced(self) -> bool: + hook_results = self._internals.values() + return len(hook_results) == sum(h.has_reduced for h in hook_results) + + def auto_reduce_results_on_epoch_end(self) -> None: + if not self.has_reduced: + for hook_result in self._internals.values(): + hook_result.auto_reduce_results_on_epoch_end() + + @property + def has_batch_loop_finished(self) -> bool: + return self._has_batch_loop_finished + + @has_batch_loop_finished.setter + def has_batch_loop_finished(self, has_batch_loop_finished): + if has_batch_loop_finished: + # If batch loop has finished, reduce metrics + self.auto_reduce_results_on_epoch_end() + + # batch_size should be none as we finished batch loop + self._batch_size = None + + self._has_batch_loop_finished = has_batch_loop_finished + self.update_logger_connector() + + def run_epoch_by_func_name(self, func_name) -> Dict: + if not self.has_reduced: + self.auto_reduce_results_on_epoch_end() + results = [getattr(hook_result, func_name) for hook_result in self._internals.values()] + results = [func() for func in results] + return {k: v for d in sum(results, []) for k, v in d.items()} # List[List[dict]] -> dict + + def get_epoch_pbar_metrics(self) -> Dict: + return self.run_epoch_by_func_name("get_epoch_pbar_metrics") + + def get_epoch_log_metrics(self) -> Dict: + return self.run_epoch_by_func_name("get_epoch_log_metrics") + + def get_forked_metrics(self) -> Dict: + return self.run_epoch_by_func_name("get_forked_metrics") + + def reset(self): + self._internals = {} + self._dataloader_idx: Optional[int] = None + self._split_idx: Optional[int] = None + self._opt_idx: Optional[int] = None + self._batch_size: Optional[int] = None + self._has_batch_loop_finished = False + self.legacy_batch_log_metrics = {} + self.legacy_batch_pbar_metrics = {} + + def __call__( + self, + fx_name: str, + dl_idx: Optional[int] = None, + opt_idx: Optional[int] = None, + batch_idx: Optional[int] = None, + split_idx: Optional[int] = None, + reduced: bool = False, + ): + """ + This function is an helper to access stored data + + It access data from the HookResultStore. Please, + check its data structure for better understanding + + Data can be accessed with the following chains: + + IF REDUCED: + * IF accessing a fx_name defined in batch training loop: + fx_name -> dl_idx -> opt_idx -> batch_idx -> split_idx + * ELSE fx_name -> dl_idx -> batch_idx + ELSE: + * IF accessing a fx_name defined in batch training loop: + fx_name -> dl_idx -> opt_idx + * ELSE fx_name -> dl_idx + + Note: + As soon as a param is None, it breaks the chain and returns associated stored data. + + Example:: + + result: Result = self(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True) + result['train_loss_epoch'] # aggregated train_loss over one epoch. + + Args: + + fx_name: Hook name from ModelHooks or Callback. Example: ``"training_step"`` + + dl_idx: Dataloader index in short. From ``0`` to ``num_dataloaders - 1`` + + opt_idx: Optimizer index in short. From ``0`` to ``num_optimizers - 1`` + + batch_idx: Batch index seen during batch training or evaluation. + Works only with ``reduced=False`` + + split_idx: Index of split idx in training loop when ttbt is used. + + reduced: Data are being aggregated on on_epoch_end. + Indicates if we want to access the aggregated Result or not. + """ + hook_result = self[fx_name] + internal_type = hook_result._internal_type + result = hook_result._internals_reduced if reduced else hook_result._internals + + if dl_idx is not None: + result = result[dl_idx] + if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + if opt_idx is not None: + result = result[opt_idx] + if not reduced and batch_idx is not None: + result = result[batch_idx] + if split_idx is not None: + result = result[split_idx] + elif not reduced and batch_idx is not None: + result = result[batch_idx] + return result + + def __repr__(self): + return f"{self.__class__.__name__}(internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py new file mode 100644 index 00000000000000..6752411fcbfcd4 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -0,0 +1,458 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy +from pprint import pprint +from typing import Dict, Iterable, Optional, Union + +import torch + +from pytorch_lightning.core import memory +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore +from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder +from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.utilities import DeviceType, flatten_dict +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden + + +class LoggerConnector: + + def __init__(self, trainer, log_gpu_memory: Optional[str] = None): + self.trainer = trainer + self.log_gpu_memory = log_gpu_memory + self._callback_metrics = MetricsHolder() + self._evaluation_callback_metrics = MetricsHolder(to_float=True) + self._logged_metrics = MetricsHolder() + self._progress_bar_metrics = MetricsHolder(to_float=True) + self.eval_loop_results = [] + self._cached_results = {stage: EpochResultStore(trainer) for stage in RunningStage} + self._cached_results[None] = EpochResultStore(trainer) + self._callback_hook_validator = CallbackHookNameValidator() + + @property + def callback_metrics(self) -> Dict: + return self.get_metrics("callback_metrics") + + @callback_metrics.setter + def callback_metrics(self, callback_metrics: Dict) -> None: + self.set_metrics("callback_metrics", callback_metrics) + + @property + def evaluation_callback_metrics(self) -> Dict: + return self.get_metrics("evaluation_callback_metrics") + + @evaluation_callback_metrics.setter + def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None: + self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics) + + @property + def logged_metrics(self) -> Dict: + return self.get_metrics("logged_metrics") + + @logged_metrics.setter + def logged_metrics(self, logged_metrics: Dict) -> None: + self.set_metrics("logged_metrics", logged_metrics) + + @property + def progress_bar_metrics(self) -> Dict: + return self.get_metrics("progress_bar_metrics") + + @progress_bar_metrics.setter + def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: + self.set_metrics("progress_bar_metrics", progress_bar_metrics) + + @property + def cached_results(self) -> Union[EpochResultStore, None]: + return self._cached_results.get(self.trainer._running_stage) + + def get_metrics(self, key: str) -> Dict: + metrics_holder: MetricsHolder = getattr(self, f"_{key}") + model = self.trainer.lightning_module + metrics_holder.convert(model.device if model is not None else None) + return metrics_holder.metrics + + def set_metrics(self, key: str, val: Dict) -> None: + metrics_holder: MetricsHolder = getattr(self, f"_{key}") + metrics_holder.reset(val) + + def reset(self) -> None: + self.cached_results.reset() + + def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: + self._callback_hook_validator.check_logging_in_callbacks( + current_hook_fx_name=hook_fx_name, on_step=on_step, on_epoch=on_epoch + ) + + def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders): + model = self.trainer.lightning_module + # set dataloader_idx only if multiple ones + model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None + # track batch_size + self.cached_results._batch_size = Result.extract_batch_size(batch) + + def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None: + self.cached_results._split_idx = split_idx + self.cached_results._opt_idx = opt_idx + self.cached_results._batch_size = Result.extract_batch_size(split_batch) + + def on_train_batch_end(self) -> None: + self.cached_results._split_idx = None + self.cached_results._opt_idx = None + self.cached_results._batch_size = None + + def cache_logged_metrics(self): + self._cached_results[self.trainer._running_stage].cache_result() + + def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): + # logging + self.configure_logger(logger) + self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps + self.trainer.log_every_n_steps = log_every_n_steps + self.trainer.move_metrics_to_cpu = move_metrics_to_cpu + self.trainer.split_idx = None + + @property + def should_flush_logs(self): + should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 + return should_flush or self.trainer.should_stop + + @property + def should_update_logs(self): + should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + return should_log_every_n_steps or self.trainer.should_stop + + def configure_logger(self, logger): + if logger is True: + version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) + + # default logger + self.trainer.logger = TensorBoardLogger( + save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs' + ) + elif logger is False: + self.trainer.logger = None + else: + if isinstance(logger, Iterable): + self.trainer.logger = LoggerCollection(logger) + else: + self.trainer.logger = logger + + def cache_training_step_metrics(self, opt_closure_result): + """ + This function is responsible to update + logger_connector internals metrics holder based for depreceated logging + """ + using_results_obj = isinstance(opt_closure_result.training_step_output, Result) + + # temporary dict to collect metrics + logged_metrics_tmp = {} + pbar_metrics_tmp = {} + callback_metrics_tmp = {} + + if using_results_obj: + batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics( + include_forked_originals=False + ) + logged_metrics_tmp.update(batch_log_metrics) + + batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics( + include_forked_originals=False + ) + pbar_metrics_tmp.update(batch_pbar_metrics) + + forked_metrics = opt_closure_result.training_step_output.get_forked_metrics() + callback_metrics_tmp.update(forked_metrics) + callback_metrics_tmp.update(logged_metrics_tmp) + + else: + batch_log_metrics = opt_closure_result.training_step_output.log_metrics + logged_metrics_tmp.update(batch_log_metrics) + + batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end + pbar_metrics_tmp.update(batch_pbar_metrics) + + # track progress bar metrics + if len(pbar_metrics_tmp) > 0: + self.add_progress_bar_metrics(pbar_metrics_tmp) + + self._callback_metrics.update(callback_metrics_tmp) + + # save legacy log metrics + self._logged_metrics.update(logged_metrics_tmp) + self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp) + + def log_metrics(self, metrics, grad_norm_dic, step=None): + """Logs the metric dict passed in. + If `step` parameter is None and `step` key is presented is metrics, + uses metrics["step"] as a step + + Args: + metrics (dict): Metric values + grad_norm_dic (dict): Gradient norms + step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` + """ + # add gpu memory + if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: + mem_map = memory.get_memory_profile(self.log_gpu_memory) + metrics.update(mem_map) + + # add norms + metrics.update(grad_norm_dic) + + # turn all tensors to scalars + scalar_metrics = self.trainer.metrics_to_scalars(metrics) + + if "step" in scalar_metrics and step is None: + step = scalar_metrics.pop("step") + + elif step is None: + # added metrics by Lightning for convenience + scalar_metrics['epoch'] = self.trainer.current_epoch + step = self.trainer.global_step + + # log actual metrics + if self.trainer.logger is not None: + if self.trainer.is_global_zero: + self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step) + self.trainer.logger.save() + + # track the logged metrics + self.logged_metrics.update(scalar_metrics) + self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics) + + def add_progress_bar_metrics(self, metrics): + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + + self._progress_bar_metrics.metrics[k] = v + + self.trainer.dev_debugger.track_pbar_metrics_history(metrics) + + def evaluation_epoch_end(self): + # reset dataloader idx + model_ref = self.trainer.lightning_module + model_ref._current_dataloader_idx = None + + # setting `has_batch_loop_finished` to True + # will perform Results reduction accross entire epoch. + self.cached_results.has_batch_loop_finished = True + + def add_to_eval_loop_results(self, dl_idx, has_been_initialized): + callback_metrics = deepcopy(self.evaluation_callback_metrics) + for key in list(callback_metrics.keys()): + if "dataloader_idx" in key: + if f"dataloader_idx_{dl_idx}" not in key: + # remove dl_idx from self.callback_metrics not belonging to this dataset. + del callback_metrics[key] + if has_been_initialized: + self.eval_loop_results[dl_idx].update(callback_metrics) + else: + self.eval_loop_results.append(callback_metrics) + + def prepare_eval_loop_results(self): + num_dataloaders = self.trainer.evaluation_loop.num_dataloaders + has_been_initialized = len(self.eval_loop_results) == num_dataloaders + for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): + self.add_to_eval_loop_results(dl_idx, has_been_initialized) + + def get_evaluate_epoch_results(self): + if not self.trainer.sanity_checking: + # log all the metrics as a single dict + metrics_to_log = self.cached_results.get_epoch_log_metrics() + if len(metrics_to_log) > 0: + self.log_metrics(metrics_to_log, {}) + + self.prepare_eval_loop_results() + + # log results of evaluation + if ( + self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero + and self.trainer.verbose_evaluate + ): + print('-' * 80) + for result_idx, results in enumerate(self.eval_loop_results): + print(f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS') + pprint({ + k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v + for k, v in results.items() + }) + print('-' * 80) + + results = self.eval_loop_results + + # clear mem + self.eval_loop_results = [] + return results + + def _track_callback_metrics(self, eval_results): + if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)): + return + + flat = {} + if isinstance(eval_results, list): + for eval_result in eval_results: + # with a scalar return, auto set it to "val_loss" for callbacks + if isinstance(eval_result, torch.Tensor): + flat = {'val_loss': eval_result} + elif isinstance(eval_result, dict): + flat = flatten_dict(eval_result) + + self.trainer.logger_connector.callback_metrics.update(flat) + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): + self.trainer.logger_connector.evaluation_callback_metrics.update(flat) + else: + # with a scalar return, auto set it to "val_loss" for callbacks + if isinstance(eval_results, torch.Tensor): + flat = {'val_loss': eval_results} + else: + flat = flatten_dict(eval_results) + + self.trainer.logger_connector.callback_metrics.update(flat) + if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING): + self.trainer.logger_connector.evaluation_callback_metrics.update(flat) + + def on_train_epoch_end(self): + # inform cached logger connector epoch finished + self.cached_results.has_batch_loop_finished = True + + def log_train_epoch_end_metrics(self, epoch_output, num_optimizers): + # epoch output is a list. Each item in that list has all the outputs per optimizer + # epoch_output[optimizer_idx][training_step_idx][tbptt_index] + # remember that not using truncated backprop is equivalent with truncated back prop of len(1) + + model = self.trainer.lightning_module + + # lightning module hook + self.training_epoch_end(model, epoch_output, num_optimizers) + + # log/aggregate metrics automatically + epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output) + + # it will perform reduction over epoch and return log metrics + cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics() + cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics() + + # update + epoch_log_metrics.update(cached_epoch_log_metrics) + epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics) + + # -------------------------- + # track results + # -------------------------- + # add the metrics to the loggers and callbacks + if epoch_log_metrics and len(epoch_log_metrics) > 0: + self.log_metrics(epoch_log_metrics, {}) + self._callback_metrics.update(epoch_log_metrics) + + # add metrics to progress_bar and callbacks + if len(epoch_progress_bar_metrics) > 0: + self.add_progress_bar_metrics(epoch_progress_bar_metrics) + self._callback_metrics.update(epoch_progress_bar_metrics) + + # reset epoch loop result for next epoch + self.cached_results.reset() + + def training_epoch_end(self, model, epoch_output, num_optimizers): + if not is_overridden('training_epoch_end', model=model): + return + + # run training_epoch_end + # refresh the result for custom logging at the epoch level + model._current_fx_name = 'training_epoch_end' + epoch_output = self.__prepare_epoch_end_inputs(epoch_output) + + if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization: + epoch_output = epoch_output[0] + + # lightningmodule hook + epoch_output = model.training_epoch_end(epoch_output) + + if epoch_output is not None: + raise MisconfigurationException( + 'training_epoch_end expects a return of None. ' + 'HINT: remove the return statement in training_epoch_end' + ) + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + + def __auto_reduce_results_on_epoch_end(self, epoch_output): + epoch_log_metrics = {} + epoch_progress_bar_metrics = {} + for opt_outputs in epoch_output: + # reduce across time first + time_reduced_outputs = [] + for tbptt_outs in opt_outputs: + tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) + if len(tbptt_outs) > 1: + time_reduced_outputs.append(tbptt_outs) + + if len(time_reduced_outputs) == 0: + continue + + # reduce across training steps + opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) + + # with manual opt need 1 + metrics because meta is always there + if opt_outputs.minimize is not None: + opt_outputs.minimize = opt_outputs.minimize.mean() + epoch_log_metrics.update(opt_outputs.epoch_log_metrics) + epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics) + + return epoch_log_metrics, epoch_progress_bar_metrics + + def __prepare_epoch_end_inputs(self, epoch_output): + """ + Pulls out only the "extra" information for epoch end + + Return: + a single list, each element per optimizer then batch then time + """ + gathered_epoch_outputs = [] + for opt_outputs in epoch_output: + # gather across time first + time_gathered_outputs = [] + for tbptt_outs in opt_outputs: + result = [] + for x in tbptt_outs: + out = x.extra + out['loss'] = x.minimize + result.append(out) + + # when time = 0, pass in the literal dict instead of array + if len(result) == 1: + result = result[0] + time_gathered_outputs.append(result) + + gathered_epoch_outputs.append(time_gathered_outputs) + + return gathered_epoch_outputs + + def log_train_step_metrics(self, batch_output): + if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization: + return + _, batch_log_metrics = self.cached_results.update_logger_connector() + # when metrics should be logged + if self.should_update_logs or self.trainer.fast_dev_run is True: + # logs user requested information to logger + grad_norm_dic = batch_output.grad_norm_dic + if grad_norm_dic is None: + grad_norm_dic = {} + if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0: + self.log_metrics(batch_log_metrics, grad_norm_dic) + self._callback_metrics.update(batch_log_metrics) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py new file mode 100644 index 00000000000000..1efbcc638674fc --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +from typing import Any, Dict, Optional, Union + +import torch +from torchmetrics import Metric + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any] + + +class MetricsHolder: + """ + This class acts as a dictionary holder. + It holds metrics and implements conversion functions. + Those functions will be triggered within LoggerConnector + when the property is being requested from the user. + """ + + def __init__(self, to_float: bool = False) -> None: + self.metrics: Dict[str, _METRIC_TYPE] = {} + self._to_float = to_float + + def update(self, metrics: dict) -> None: + self.metrics.update(metrics) + + def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE: + return self.metrics.pop(key, default) + + def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None: + self.metrics = metrics + + def convert(self, device: Optional[torch.device]) -> None: + for key, value in self.metrics.items(): + if self._to_float: + if isinstance(value, torch.Tensor) and value.numel() != 1: + raise MisconfigurationException( + f"The metric `{key}` does not contain a single element" + f" thus it cannot be converted to float. Found `{value}`" + ) + converted = self._convert_to_float(value) + else: + converted = self._convert_to_tensor(value, device) + self.metrics[key] = converted + + @staticmethod + def _convert_to_float(current: _METRIC_TYPE) -> float: + if isinstance(current, Metric): + current = current.compute().detach() + + if isinstance(current, torch.Tensor): + current = float(current.item()) + + elif isinstance(current, int): + current = float(current) + + return current + + @staticmethod + def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor: + if isinstance(current, Metric): + current = current.compute().detach() + + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=device, dtype=torch.float) + + if isinstance(current, torch.Tensor) and current.device.type == "xla": + current = current.cpu() + + return current diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py new file mode 100644 index 00000000000000..cdaab6248f0066 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Root module for all distributed operations in Lightning. +Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. + +""" +from weakref import proxy + + +class ModelConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def copy_trainer_model_properties(self, model): + ref_model = self.trainer.lightning_module or model + + automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization + self.trainer.train_loop.automatic_optimization = automatic_optimization + + for m in [model, ref_model]: + m.trainer = proxy(self.trainer) + m._device_type = str(self.trainer._device_type) + m._distrib_type = str(self.trainer._distrib_type) + m.use_amp = self.trainer.amp_backend is not None + m.precision = self.trainer.precision diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py new file mode 100644 index 00000000000000..a50603bb58dbf8 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class OptimizerConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init(self): + self.trainer.lr_schedulers = [] + self.trainer.optimizers = [] + self.trainer.optimizer_frequencies = [] + + def update_learning_rates(self, interval: str, monitor_metrics=None): + """Update learning rates. + + Args: + interval: either 'epoch' or 'step'. + monitor_metrics: dict of possible values to monitor + """ + if not self.trainer.lr_schedulers: + return + + for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers): + current_idx = self.trainer.batch_idx if interval == 'step' else self.trainer.current_epoch + current_idx += 1 # account for both batch and epoch starts from 0 + # Take step if call to update_learning_rates matches the interval key and + # the current step modulo the schedulers frequency is zero + if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0: + # If instance of ReduceLROnPlateau, we need a monitor + monitor_key, monitor_val = None, None + if lr_scheduler['reduce_on_plateau']: + monitor_key = lr_scheduler['monitor'] + monitor_val = ( + monitor_metrics.get(monitor_key) if monitor_metrics is not None else + self.trainer.logger_connector.callback_metrics.get(monitor_key) + ) + if monitor_val is None: + if lr_scheduler.get('strict', True): + avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys()) + raise MisconfigurationException( + f'ReduceLROnPlateau conditioned on metric {monitor_key}' + f' which is not available. Available metrics are: {avail_metrics}.' + ' Condition can be set using `monitor` key in lr scheduler dict' + ) + rank_zero_warn( + f'ReduceLROnPlateau conditioned on metric {monitor_key}' + ' which is not available but strict is set to `False`.' + ' Skipping learning rate update.', + RuntimeWarning, + ) + continue + # update LR + old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + + if lr_scheduler['reduce_on_plateau']: + lr_scheduler['scheduler'].step(monitor_val) + else: + lr_scheduler['scheduler'].step() + + new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + + if self.trainer.dev_debugger.enabled: + self.trainer.dev_debugger.track_lr_schedulers_update( + self.trainer.batch_idx, + interval, + scheduler_idx, + old_lr, + new_lr, + monitor_key=monitor_key, + monitor_val=monitor_val + ) diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py new file mode 100644 index 00000000000000..fa1002d70a7ced --- /dev/null +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +from typing import Union +from weakref import proxy + +from pytorch_lightning.profiler import ( + AdvancedProfiler, + BaseProfiler, + PassThroughProfiler, + PyTorchProfiler, + SimpleProfiler, +) +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +PROFILERS = { + "simple": SimpleProfiler, + "advanced": AdvancedProfiler, + "pytorch": PyTorchProfiler, +} + + +class ProfilerConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init(self, profiler: Union[BaseProfiler, str]): + + if profiler and not isinstance(profiler, (str, BaseProfiler)): + raise MisconfigurationException( + "Only None, str and subclasses of `BaseProfiler`" + " are valid values for `Trainer`'s `profiler` parameter." + f" Received {profiler} which is of type {type(profiler)}." + ) + if isinstance(profiler, str): + if profiler.lower() in PROFILERS: + profiler_class = PROFILERS[profiler.lower()] + profiler = profiler_class() + else: + raise ValueError( + "When passing string value for the `profiler` parameter of" + " `Trainer`, it can only be 'simple' or 'advanced'" + ) + self.trainer.profiler = profiler or PassThroughProfiler() + + def setup(self) -> None: + trainer = self.trainer + local_rank = trainer.local_rank if trainer.world_size > 1 else None + trainer.profiler._lightning_module = proxy(trainer.lightning_module) + trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py new file mode 100644 index 00000000000000..54529e3346f0f9 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -0,0 +1,61 @@ +import logging +import os +import signal +from subprocess import call + +log = logging.getLogger(__name__) + + +class SLURMConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def register_slurm_signal_handlers(self): + # see if we're using slurm (not interactive) + on_slurm = False + try: + job_name = os.environ['SLURM_JOB_NAME'] + if job_name != 'bash': + on_slurm = True + # todo: specify the possible exception + except Exception: + pass + + if on_slurm: + log.info('Set SLURM handle signals.') + signal.signal(signal.SIGUSR1, self.sig_handler) + signal.signal(signal.SIGTERM, self.term_handler) + + def sig_handler(self, signum, frame): # pragma: no-cover + if self.trainer.is_global_zero: + # save weights + log.info('handling SIGUSR1') + self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger) + + # find job id + job_id = os.environ['SLURM_JOB_ID'] + cmd = ['scontrol', 'requeue', job_id] + + # requeue job + log.info(f'requeing job {job_id}...') + try: + result = call(cmd) + except FileNotFoundError: + # This can occur if a subprocess call to `scontrol` is run outside a shell context + # Re-attempt call (now with shell context). If any error is raised, propagate to user. + # When running a shell command, it should be passed as a single string. + joint_cmd = [str(x) for x in cmd] + result = call(' '.join(joint_cmd), shell=True) + + # print result text + if result == 0: + log.info(f'requeued exp {job_id}') + else: + log.warning('requeue failed...') + + # close experiment to avoid issues + self.trainer.logger.close() + + def term_handler(self, signum, frame): # pragma: no-cover + log.info("bypassing sigterm") diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py new file mode 100644 index 00000000000000..dd7aad8cd6d883 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.callbacks import GradientAccumulationScheduler +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class TrainingTricksConnector: + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init( + self, + gradient_clip_val, + track_grad_norm, + accumulate_grad_batches, + truncated_bptt_steps, + terminate_on_nan, + ): + + self.trainer.terminate_on_nan = terminate_on_nan + + # gradient clipping + self.trainer.gradient_clip_val = gradient_clip_val + + # gradient norm tracking + if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf': + raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).") + self.trainer.track_grad_norm = float(track_grad_norm) + + # accumulated grads + self.trainer.accumulate_grad_batches = accumulate_grad_batches + self.configure_accumulated_gradients(accumulate_grad_batches) + + self.trainer.truncated_bptt_steps = truncated_bptt_steps + + def configure_accumulated_gradients(self, accumulate_grad_batches): + if isinstance(accumulate_grad_batches, dict): + self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches) + elif isinstance(accumulate_grad_batches, int): + schedule = {0: accumulate_grad_batches} + self.trainer.accumulation_scheduler = GradientAccumulationScheduler(schedule) + else: + raise TypeError("Gradient accumulation supports only int and dict types") diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 52e53acd5a2b49..59ec40c3df2e87 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,151 +1,189 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import multiprocessing import platform -from abc import ABC, abstractmethod -from typing import Union, List, Tuple, Callable +from abc import ABC +from copy import deepcopy +from typing import Iterable, List, Tuple, Union -import torch.distributed as torch_distrib -from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core import LightningModule +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len +from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException - -try: - from torch.utils.data import IterableDataset - ITERABLE_DATASET_EXISTS = True -except ImportError: - ITERABLE_DATASET_EXISTS = False - -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True - -try: - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.xla_multiprocessing as xmp -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - - -def _has_len(dataloader: DataLoader) -> bool: - """ Checks if a given Dataloader has __len__ method implemented i.e. if - it is a finite dataloader or infinite dataloader """ - try: - # try getting the length - if len(dataloader) == 0: - raise ValueError('`Dataloader` returned 0 length.' - ' Please make sure that your Dataloader at least returns 1 batch') - return True - except TypeError: - return False +from pytorch_lightning.utilities.model_helpers import is_overridden class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - proc_rank: int - use_ddp: bool - use_ddp2: bool - use_horovod: bool - shown_warnings: ... val_check_interval: float - use_tpu: bool tpu_local_core_rank: int train_dataloader: DataLoader num_training_batches: Union[int, float] - val_check_batch: ... + val_check_batch: float val_dataloaders: List[DataLoader] - num_val_batches: Union[int, float] + num_val_batches: List[Union[int, float]] test_dataloaders: List[DataLoader] - num_test_batches: Union[int, float] - train_percent_check: float - val_percent_check: float - test_percent_check: float - replace_sampler_ddp: bool - - @abstractmethod - def is_overriden(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def _percent_range_check(self, name: str) -> None: - value = getattr(self, name) - msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.' - if name == 'val_check_interval': - msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.' - - if not 0. <= value <= 1.: - raise ValueError(msg) + num_test_batches: List[Union[int, float]] + limit_train_batches: Union[int, float] + overfit_batches: Union[int, float] + distributed_sampler_kwargs: dict + accelerator: Accelerator + accelerator_connector: AcceleratorConnector + dev_debugger: InternalDebugger def _worker_check(self, dataloader: DataLoader, name: str) -> None: on_windows = platform.system() == 'Windows' - if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2 and not on_windows: - rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.' - ' Consider increasing the value of the `num_workers` argument`' - ' in the `DataLoader` init to improve performance.') + # ddp_spawn + num_workers > 0 don't mix! tell the user + is_dataloader = isinstance(dataloader, DataLoader) + using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn" + if is_dataloader and not on_windows: + if dataloader.num_workers > 0 and using_spawn: + rank_zero_warn( + 'Dataloader(num_workers>0) and ddp_spawn do not mix well!' + ' Your performance might suffer dramatically.' + ' Please consider setting accelerator=ddp to use num_workers > 0' + ' (this is a bottleneck of Python .spawn() and PyTorch' + ) + + elif dataloader.num_workers == 0 and using_spawn: + rank_zero_warn( + 'You are using `accelerator=ddp_spawn` with num_workers=0.' + ' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`' + ) + + elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn: + num_cpus = multiprocessing.cpu_count() + rank_zero_warn( + f'The dataloader, {name}, does not have many workers which may be a bottleneck.' + ' Consider increasing the value of the `num_workers` argument`' + f' (try {num_cpus} which is the number of cpus on this machine)' + f' in the `DataLoader` init to improve performance.' + ) - def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: + def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: # don't do anything if it's not a dataloader - # don't manipulate iterable datasets is_dataloader = isinstance(dataloader, DataLoader) - - is_iterable_ds = False - if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'): - is_iterable_ds = isinstance(dataloader.dataset, IterableDataset) + # don't manipulate iterable datasets + is_iterable_ds = has_iterable_dataset(dataloader) if not is_dataloader or is_iterable_ds: return dataloader - need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu) - if self.replace_sampler_ddp and need_dist_sampler: - skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] + need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance( + dataloader.sampler, DistributedSampler + ) + if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler: + if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): + raise MisconfigurationException( + 'You seem to have configured a sampler in your DataLoader. This will be replaced ' + ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using' + ' distributed training. Either remove the sampler from your DataLoader or set' + ' `replace_sampler_ddp`=False if you want to use your custom sampler.' + ) - dl_args = { - k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys - } + # replace with distributed sampler + sampler = self._get_distributed_sampler(dataloader, shuffle) + dataloader = self.replace_sampler(dataloader, sampler) - if self.use_tpu: - sampler = DistributedSampler( - dataloader.dataset, - num_replicas=xm.xrt_world_size(), - rank=xm.get_ordinal(), - ) - elif self.use_horovod: - sampler = DistributedSampler(dataloader.dataset, - num_replicas=hvd.size(), - rank=hvd.rank()) - else: - world_size = { - 'ddp': self.num_nodes * self.num_processes, - 'ddp2': self.num_nodes, - 'ddp_cpu': self.num_processes * self.num_nodes - } - sampler = DistributedSampler( - dataloader.dataset, - num_replicas=world_size.get(self.distributed_backend, 0), - rank=self.proc_rank, - ) + return dataloader + @staticmethod + def _resolve_batch_sampler(dl_args, dataloader, sampler): + batch_sampler = getattr(dataloader, "batch_sampler") + if batch_sampler is not None and type(batch_sampler) is not BatchSampler: + batch_sampler = type(batch_sampler)( + sampler, + batch_size=batch_sampler.batch_size, + drop_last=batch_sampler.drop_last, + ) + dl_args['batch_sampler'] = batch_sampler + dl_args['batch_size'] = 1 + dl_args['shuffle'] = False + dl_args['sampler'] = None + dl_args['drop_last'] = False + else: dl_args['sampler'] = sampler - dataloader = type(dataloader)(**dl_args) - + dl_args['shuffle'] = False + dl_args['batch_sampler'] = None + + return dl_args + + def replace_sampler(self, dataloader, sampler): + skip_keys = ('sampler', 'batch_sampler', 'dataset_kind') + skip_signature_keys = ('args', 'kwargs', 'self') + + attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} + + params = set(inspect.signature(dataloader.__init__).parameters) + contains_dataset = True + + if type(dataloader) is not DataLoader: + contains_dataset = "dataset" in params + params.update(inspect.signature(DataLoader.__init__).parameters) + + dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys} + + dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler) + + multiprocessing_context = dataloader.multiprocessing_context + dl_args['multiprocessing_context'] = multiprocessing_context + + missing_kwargs = params.difference(skip_signature_keys).difference(dl_args) + if missing_kwargs: + """ + Example: + class CustomDataLoader(DataLoader): + def __init__(self, num_features, dataset, *args, **kwargs): + self.num_features = num_features + super().__init__(dataset, *args, **kwargs) + """ + dataloader_cls_name = dataloader.__class__.__name__ + raise MisconfigurationException( + f"Trying to inject DistributedSampler within {dataloader_cls_name} class." + "This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. " + f"Missing attributes are {missing_kwargs}. " + f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or ", + "manually add DistributedSampler as " + f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...)).", + ) + + if not contains_dataset: + dl_args.pop('dataset') + + dataloader = type(dataloader)(**dl_args) + dataloader.multiprocessing_context = multiprocessing_context return dataloader + def _get_distributed_sampler(self, dataloader, shuffle): + kwargs = self.distributed_sampler_kwargs + kwargs['shuffle'] = shuffle and not self.overfit_batches + sampler = DistributedSampler(dataloader.dataset, **kwargs) + return sampler + def reset_train_dataloader(self, model: LightningModule) -> None: """Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.). @@ -153,22 +191,44 @@ def reset_train_dataloader(self, model: LightningModule) -> None: Args: model: The current `LightningModule` """ - self.train_dataloader = self.request_dataloader(model.train_dataloader) + self.train_dataloader = self.request_dataloader(model, "train") + + if self.overfit_batches > 0: + if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): + rank_zero_warn( + 'You requested to overfit but enabled training dataloader shuffling.' + ' We are turning it off for you.' + ) + self.train_dataloader = self.replace_sampler( + self.train_dataloader, SequentialSampler(self.train_dataloader.dataset) + ) - self.num_training_batches = 0 + # debugging + self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader]) # automatically add samplers - self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) + self.train_dataloader = apply_to_collection( + self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True + ) - self._worker_check(self.train_dataloader, 'train dataloader') - self._percent_range_check('train_percent_check') + # check the workers recursively + apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader') - if not _has_len(self.train_dataloader): - self.num_training_batches = float('inf') - else: - # try getting the length - self.num_training_batches = len(self.train_dataloader) - self.num_training_batches = int(self.num_training_batches * self.train_percent_check) + # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches + self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode) + + self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') + + if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: + self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) + elif self.num_training_batches != float('inf'): + self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) + elif self.limit_train_batches != 1.0: + raise MisconfigurationException( + 'When using an IterableDataset for `limit_train_batches`,' + ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies' + ' `num_training_batches` to use.' + ) # determine when to check validation # if int passed in, val checks that often @@ -179,24 +239,27 @@ def reset_train_dataloader(self, model: LightningModule) -> None: raise ValueError( f'`val_check_interval` ({self.val_check_interval}) must be less than or equal ' f'to the number of the training batches ({self.num_training_batches}). ' - 'If you want to disable validation set `val_percent_check` to 0.0 instead.') + 'If you want to disable validation set `limit_val_batches` to 0.0 instead.' + ) else: - if not _has_len(self.train_dataloader): + if not has_len(self.train_dataloader): if self.val_check_interval == 1.0: self.val_check_batch = float('inf') else: raise MisconfigurationException( - 'When using an infinite DataLoader (e.g. with an IterableDataset' - ' or when DataLoader does not implement `__len__`) for `train_dataloader`,' + 'When using an IterableDataset for `train_dataloader`,' ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies' - ' checking validation every k training batches.') + ' checking validation every k training batches.' + ) else: - self._percent_range_check('val_check_interval') - self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int, List[DataLoader]]: + def _reset_eval_dataloader( + self, + model: LightningModule, + mode: str, + ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. Args: @@ -206,44 +269,84 @@ def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int Returns: Tuple (num_batches, dataloaders) """ - dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader')) + # always get the loaders first so we can count how many there are + loader_name = f'{mode}_dataloader' + dataloaders = self.request_dataloader(model, mode) if not isinstance(dataloaders, list): dataloaders = [dataloaders] - # shuffling in val and test set is bad practice - for loader in dataloaders: - if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): - raise MisconfigurationException( - f'Your {mode}_dataloader has shuffle=True, it is best practice to turn' - ' this off for validation and test dataloaders.') + # when overfitting use the training loader as val and test + # duplicate it the numb of times needed to match the train loaders + if self.overfit_batches > 0: + num_loaders = len(dataloaders) + train_dataloader = self.request_dataloader(model, 'train') + dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] + + self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) + + for loader_i in range(len(dataloaders)): + loader = dataloaders[loader_i] + + # shuffling in val and test set is bad practice + modes = ('val', 'test', 'predict') + if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): + + # when overfitting, the dataloader should not have sampler + if self.overfit_batches > 0 and mode != 'predict': + rank_zero_warn( + 'You requested to overfit but enabled val/test dataloader shuffling.' + ' We are turning it off for you.' + ) + dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset)) + + else: + rank_zero_warn( + f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn' + ' this off for val/test/predict dataloaders.' + ) + + if any([dl is None for dl in dataloaders]): + rank_zero_warn("One of given dataloaders is None and it will be skipped.") # add samplers - dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None] + dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None] - num_batches = 0 + loader_num_batches = [] # determine number of batches # datasets could be none, 1 or 2+ if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): + num_batches = len(dataloader) if has_len(dataloader) else float('inf') self._worker_check(dataloader, f'{mode} dataloader {i}') - if not _has_len(dataloader): - num_batches = float('inf') - percent_check = getattr(self, f'{mode}_percent_check') + # percent or num_steps + limit_eval_batches = getattr(self, f'limit_{mode}_batches') - if num_batches != float('inf'): - self._percent_range_check(f'{mode}_percent_check') + # limit num batches either as a percent or num steps + if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0: + num_batches = min(num_batches, int(limit_eval_batches)) + elif num_batches != float('inf'): + num_batches = int(num_batches * limit_eval_batches) + elif limit_eval_batches != 1.0: + raise MisconfigurationException( + 'When using an IterableDataset for `limit_{mode}_batches`,' + f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies' + f' `num_{mode}_batches` to use.' + ) - num_batches = sum(len(dataloader) for dataloader in dataloaders) - num_batches = int(num_batches * percent_check) - elif percent_check not in (0.0, 1.0): - raise MisconfigurationException( - 'When using an infinite DataLoader (e.g. with an IterableDataset' - f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,' - f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.') - return num_batches, dataloaders + if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): + min_pct = 1.0 / len(dataloader) + raise MisconfigurationException( + f'you requested to check {limit_eval_batches} of the {mode} dataloader but' + f' {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches.' + f' Try at least limit_{mode}_batches={min_pct}' + ) + + loader_num_batches.append(num_batches) + + return loader_num_batches, dataloaders def reset_val_dataloader(self, model: LightningModule) -> None: """Resets the validation dataloader and determines the number of batches. @@ -251,21 +354,33 @@ def reset_val_dataloader(self, model: LightningModule) -> None: Args: model: The current `LightningModule` """ - if self.is_overriden('validation_step'): - self.num_val_batches, self.val_dataloaders = \ - self._reset_eval_dataloader(model, 'val') + has_loader = is_overridden('val_dataloader', model) + has_step = is_overridden('validation_step', model) + if has_loader and has_step: + self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val') def reset_test_dataloader(self, model) -> None: - """Resets the validation dataloader and determines the number of batches. + """Resets the test dataloader and determines the number of batches. + + Args: + model: The current `LightningModule` + """ + has_loader = is_overridden('test_dataloader', model) + has_step = is_overridden('test_step', model) + if has_loader and has_step: + self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test') + + def reset_predict_dataloader(self, model) -> None: + """Resets the predict dataloader and determines the number of batches. Args: model: The current `LightningModule` """ - if self.is_overriden('test_step'): - self.num_test_batches, self.test_dataloaders =\ - self._reset_eval_dataloader(model, 'test') + has_loader = is_overridden('predict_dataloader', model) + if has_loader: + self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') - def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: + def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: @@ -274,36 +389,20 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = dataloader_fx() - - # get the function we'll use to get data - if self.use_ddp or self.use_ddp2: - # all processes wait until data download has happened - torch_distrib.barrier() - - # data download/load on TPU - elif self.use_tpu and XLA_AVAILABLE: - # all processes wait until data download has happened - torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders') - - elif self.use_horovod: - # all processes wait until data download has happened - hvd.join() - + if model.trainer is not None: + model.trainer.call_hook(f"on_{stage}_dataloader") + dataloader: DataLoader = getattr(model, f'{stage}_dataloader')() + dataloader = self._flatten_dl_only(dataloader) + self.accelerator.barrier('get_dataloaders') return dataloader - def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float, - test_percent_check: float, overfit_pct: float) -> None: - """Use less data for debugging purposes - """ - self.train_percent_check = train_percent_check - self.val_percent_check = val_percent_check - self.test_percent_check = test_percent_check - if overfit_pct > 0: - if overfit_pct > 1: - raise ValueError( - f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.') + def _flatten_dl_only(self, dataloaders): + # handles user error when they return: + # return dl1, dl2 vs return (dl1, dl2) + if isinstance(dataloaders, tuple): + all_dls = [isinstance(x, Iterable) for x in dataloaders] + all_dls = all(all_dls) + if all_dls: + dataloaders = list(dataloaders) - self.train_percent_check = overfit_pct - self.val_percent_check = overfit_pct - self.test_percent_check = overfit_pct + return dataloaders diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 2705c4f1604644..32dbc8c4088a35 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -1,137 +1,154 @@ -"""Mirroring deprecated API""" +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_deprecation + + +class DeprecatedDistDeviceAttributes: + + num_gpus: int + accelerator_connector: AcceleratorConnector -from abc import ABC - -from pytorch_lightning.utilities import rank_zero_warn + @property + def on_cpu(self) -> bool: + rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._device_type == DeviceType.CPU + @on_cpu.setter + def on_cpu(self, val: bool) -> None: + rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._device_type = DeviceType.CPU -class TrainerDeprecatedAPITillVer0_8(ABC): + @property + def on_tpu(self) -> bool: + rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._device_type == DeviceType.TPU - def __init__(self): - super().__init__() # mixin calls super too + @on_tpu.setter + def on_tpu(self, val: bool) -> None: + rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._device_type = DeviceType.TPU @property - def nb_gpu_nodes(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.num_nodes + def use_tpu(self) -> bool: + rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") + return self.on_tpu - @property - def num_gpu_nodes(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `num_gpu_nodes` has renamed to `num_nodes` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.num_nodes - - @num_gpu_nodes.setter - def num_gpu_nodes(self, num_nodes): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `num_gpu_nodes` has renamed to `num_nodes` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.num_nodes = num_nodes + @use_tpu.setter + def use_tpu(self, val: bool) -> None: + rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") + self.on_tpu = val @property - def gradient_clip(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.gradient_clip_val - - @gradient_clip.setter - def gradient_clip(self, gradient_clip): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.gradient_clip_val = gradient_clip + def on_gpu(self) -> bool: + rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._device_type == DeviceType.GPU + + @on_gpu.setter + def on_gpu(self, val: bool) -> None: + rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._device_type = DeviceType.GPU @property - def max_nb_epochs(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `max_nb_epochs` has renamed to `max_epochs` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.max_epochs - - @max_nb_epochs.setter - def max_nb_epochs(self, max_epochs): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `max_nb_epochs` has renamed to `max_epochs` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.max_epochs = max_epochs + def use_dp(self) -> bool: + rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._distrib_type == DistributedType.DP + + @use_dp.setter + def use_dp(self, val: bool) -> None: + rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._distrib_type = DistributedType.DP @property - def min_nb_epochs(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `min_nb_epochs` has renamed to `min_epochs` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.min_epochs - - @min_nb_epochs.setter - def min_nb_epochs(self, min_epochs): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `min_nb_epochs` has renamed to `min_epochs` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.min_epochs = min_epochs + def use_ddp(self) -> bool: + rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + + @use_ddp.setter + def use_ddp(self, val: bool) -> None: + rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._distrib_type = DistributedType.DDP @property - def nb_sanity_val_steps(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `nb_sanity_val_steps` has renamed to " - "`num_sanity_val_steps` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.num_sanity_val_steps - - @nb_sanity_val_steps.setter - def nb_sanity_val_steps(self, nb): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `nb_sanity_val_steps` has renamed to " - "`num_sanity_val_steps` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.num_sanity_val_steps = nb + def use_ddp2(self) -> bool: + rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._distrib_type == DistributedType.DDP2 + + @use_ddp2.setter + def use_ddp2(self, val: bool) -> None: + rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._distrib_type = DistributedType.DDP2 @property - def default_save_path(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `default_save_path` has renamed to `default_root_dir` since v0.5.x" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.default_root_dir - - @default_save_path.setter - def default_save_path(self, path): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("Attribute `default_save_path` has renamed to `default_root_dir` since v0.5.x" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.default_root_dir = path + def use_horovod(self) -> bool: + rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") + return self.accelerator_connector._distrib_type == DistributedType.HOROVOD + + @use_horovod.setter + def use_horovod(self, val: bool) -> None: + rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._distrib_type = DistributedType.HOROVOD @property - def tng_tqdm_dic(self): - """Back compatibility, will be removed in v0.8.0""" - rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.progress_bar_dict + def use_single_gpu(self) -> bool: + rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") + # todo, limiting to exclude DDP2 is not clear but it comes from connectors... + return ( + self.accelerator_connector._device_type and self.accelerator_connector._device_type == DeviceType.GPU + and self.num_gpus == 1 and self.accelerator_connector._distrib_type not in (DistributedType.DDP2, ) + ) + + @use_single_gpu.setter + def use_single_gpu(self, val: bool) -> None: + rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") + if val: + self.accelerator_connector._device_type = DeviceType.GPU -class TrainerDeprecatedAPITillVer0_9(ABC): +class DeprecatedTrainerAttributes: - def __init__(self): - super().__init__() # mixin calls super too + accelerator: Accelerator + lightning_module: LightningModule + sanity_checking: bool @property - def show_progress_bar(self): - """Back compatibility, will be removed in v0.9.0""" - rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2" - " and this method will be removed in v0.9.0", DeprecationWarning) - return self.progress_bar_refresh_rate >= 1 - - @show_progress_bar.setter - def show_progress_bar(self, tf): - """Back compatibility, will be removed in v0.9.0""" - rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2" - " and this method will be removed in v0.9.0", DeprecationWarning) + def accelerator_backend(self) -> Accelerator: + rank_zero_deprecation( + "The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`" + " since 1.2 and will be removed in v1.4." + ) + return self.accelerator + + def get_model(self) -> LightningModule: + rank_zero_deprecation( + "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" + " and will be removed in v1.4." + ) + return self.lightning_module @property - def training_tqdm_dict(self): - """Back compatibility, will be removed in v0.9.0""" - rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3" - " and this method will be removed in v0.9.0", DeprecationWarning) - return self.progress_bar_dict + def running_sanity_check(self) -> bool: + rank_zero_deprecation( + "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5." + ) + return self.sanity_checking diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py deleted file mode 100644 index 8651dd5c1b5a02..00000000000000 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ /dev/null @@ -1,449 +0,0 @@ -""" -Lightning supports model training on a cluster managed by SLURM in the following cases: - -1. Training on a single cpu or single GPU. -2. Train on multiple GPUs on the same node using DataParallel or DistributedDataParallel -3. Training across multiple GPUs on multiple different nodes via DistributedDataParallel. - -.. note:: A node means a machine with multiple GPUs - -Running grid search on a cluster --------------------------------- - -To use lightning to run a hyperparameter search (grid-search or random-search) on a cluster do 4 things: - -(1). Define the parameters for the grid search - -.. code-block:: python - - from test_tube import HyperOptArgumentParser - - # subclass of argparse - parser = HyperOptArgumentParser(strategy='random_search') - parser.add_argument('--learning_rate', default=0.002, type=float, help='the learning rate') - - # let's enable optimizing over the number of layers in the network - parser.opt_list('--nb_layers', default=2, type=int, tunable=True, options=[2, 4, 8]) - - hparams = parser.parse_args() - -.. note:: You must set `Tunable=True` for that argument to be considered in the permutation set. - Otherwise test-tube will use the default value. This flag is useful when you don't want - to search over an argument and want to use the default instead. - -(2). Define the cluster options in the - `SlurmCluster object `_ (over 5 nodes and 8 gpus) - -.. code-block:: python - - from test_tube.hpc import SlurmCluster - - # hyperparameters is a test-tube hyper params object - # see https://williamfalcon.github.io/test-tube/hyperparameter_optimization/HyperOptArgumentParser/ - hyperparams = args.parse() - - # init cluster - cluster = SlurmCluster( - hyperparam_optimizer=hyperparams, - log_path='/path/to/log/results/to', - python_cmd='python3' - ) - - # let the cluster know where to email for a change in job status (ie: complete, fail, etc...) - cluster.notify_job_status(email='some@email.com', on_done=True, on_fail=True) - - # set the job options. In this instance, we'll run 20 different models - # each with its own set of hyperparameters giving each one 1 GPU (ie: taking up 20 GPUs) - cluster.per_experiment_nb_gpus = 8 - cluster.per_experiment_nb_nodes = 5 - - # we'll request 10GB of memory per node - cluster.memory_mb_per_node = 10000 - - # set a walltime of 10 minues - cluster.job_time = '10:00' - - -(3). Make a main function with your model and trainer. Each job will call this function with a particular -hparams configuration.:: - - from pytorch_lightning import Trainer - - def train_fx(trial_hparams, cluster_manager, _): - # hparams has a specific set of hyperparams - - my_model = MyLightningModel() - - # give the trainer the cluster object - trainer = Trainer() - trainer.fit(my_model) - - ` - -(4). Start the grid/random search:: - - # run the models on the cluster - cluster.optimize_parallel_cluster_gpu( - train_fx, - nb_trials=20, - job_name='my_grid_search_exp_name', - job_display_name='my_exp') - -.. note:: `nb_trials` specifies how many of the possible permutations to use. If using `grid_search` it will use - the depth first ordering. If using `random_search` it will use the first k shuffled options. FYI, random search - has been shown to be just as good as any Bayesian optimization method when using a reasonable number of samples (60), - see this `paper `_ for more information. - -Walltime auto-resubmit ----------------------- - -Lightning automatically resubmits jobs when they reach the walltime. Make sure to set the SIGUSR1 signal in -your SLURM script.:: - - # 90 seconds before training ends - #SBATCH --signal=SIGUSR1@90 - -When lightning receives the SIGUSR1 signal it will: -1. save a checkpoint with 'hpc_ckpt' in the name. -2. resubmit the job using the SLURM_JOB_ID - -When the script starts again, Lightning will: -1. search for a 'hpc_ckpt' checkpoint. -2. restore the model, optimizers, schedulers, epoch, etc... - -""" - -import os -import re -from abc import ABC, abstractmethod -from typing import Union - -import torch -from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn - -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - - -class TrainerDDPMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - on_gpu: bool - num_gpu_nodes: int - logger: Union[LightningLoggerBase, bool] - checkpoint_callback: Union[ModelCheckpoint, bool] - data_parallel_device_ids: ... - distributed_backend: str - amp_level: str - use_tpu: bool - default_root_dir: str - use_native_amp: bool - progress_bar_callback: ... - - @property - @abstractmethod - def num_gpus(self) -> int: - """Warning: this is just empty shell for code implemented in other class.""" - - @property - @abstractmethod - def use_amp(self) -> bool: - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def copy_trainer_model_properties(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def run_pretrain_routine(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def init_optimizers(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def init_tpu(self): - # turn off all the GPU stuff - self.distributed_backend = None - - # enable tpu - self.use_tpu = True - - def set_distributed_mode(self, distributed_backend): - self.use_dp = False - self.use_ddp = False - self.use_ddp2 = False - self.use_horovod = False - self.single_gpu = False - - if distributed_backend is None: - if self.has_horovodrun(): - self._set_horovod_backend() - elif self.num_gpus == 0: - if self.num_nodes > 1 or self.num_processes > 1: - self.use_ddp = True # ddp_cpu - elif self.num_gpus == 1: - self.single_gpu = True - elif self.num_gpus > 1: - rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.' - ' Trainer(distributed_backend=dp) (or ddp, ddp2).' - ' Setting distributed_backend=dp for you.') - self.use_dp = True - elif distributed_backend == "dp": - # do nothing if num_gpus == 0 - if self.num_gpus == 1: - self.single_gpu = True - self.use_dp = True - elif self.num_gpus > 1: - self.use_dp = True - elif distributed_backend == "ddp": - if self.num_gpus == 0: - if self.num_nodes > 1 or self.num_processes > 1: - self.use_ddp = True # ddp_cpu - elif self.num_gpus == 1: - self.single_gpu = True - self.use_ddp = True - elif self.num_gpus > 1: - self.use_ddp = True - self.num_processes = self.num_gpus - elif distributed_backend == "ddp2": - # do nothing if num_gpus == 0 - if self.num_gpus >= 1: - self.use_ddp2 = True - elif distributed_backend == "ddp_cpu": - if self.num_gpus > 0: - rank_zero_warn('You requested one or more GPUs, but set the backend to `ddp_cpu`.' - ' Training will not use GPUs.') - self.use_ddp = True - self.data_parallel_device_ids = None - self.on_gpu = False - elif distributed_backend == 'horovod': - self._set_horovod_backend() - - # throw error to force user ddp or ddp2 choice - if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp): - raise MisconfigurationException( - 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. ' - 'To silence this warning set distributed_backend=ddp or distributed_backend=ddp2' - ) - - log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}') - - def configure_slurm_ddp(self, num_gpu_nodes): - self.is_slurm_managing_tasks = False - - # extract SLURM flag vars - # whenever we have the correct number of tasks, we let slurm manage processes - # otherwise we launch the required number of processes - if self.use_ddp: - self.num_requested_gpus = self.num_gpus * num_gpu_nodes - self.num_slurm_tasks = 0 - try: - self.num_slurm_tasks = int(os.environ['SLURM_NTASKS']) - self.is_slurm_managing_tasks = self.num_slurm_tasks == self.num_requested_gpus - - # in interactive mode we don't manage tasks - job_name = os.environ['SLURM_JOB_NAME'] - if job_name == 'bash': - self.is_slurm_managing_tasks = False - - except Exception: - # likely not on slurm, so set the slurm managed flag to false - self.is_slurm_managing_tasks = False - - # used for tests only, set this flag to simulate slurm managing a task - try: - should_fake = int(os.environ['FAKE_SLURM_MANAGING_TASKS']) - if should_fake: - self.is_slurm_managing_tasks = True - except Exception as e: - pass - - # notify user the that slurm is managing tasks - if self.is_slurm_managing_tasks: - log.info('Multi-processing is handled by Slurm.') - - def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): - if data_parallel_device_ids is None: - return - - # set the correct cuda visible devices (using pci order) - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - - # when slurm is managing the task it sets the visible devices - if not is_slurm_managing_tasks: - if isinstance(data_parallel_device_ids, int): - id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids))) - os.environ["CUDA_VISIBLE_DEVICES"] = id_str - else: - gpu_str = ','.join([str(x) for x in data_parallel_device_ids]) - os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str - - log.debug(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]') - - def ddp_train(self, process_idx, model): - """ - Entry point into a DP thread - :param gpu_idx: - :param model: - :param cluster_obj: - :return: - """ - # node rank using relative slurm id if under slurm management - # otherwise use given node rank or default to node rank 0 - try: - node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK'] - self.node_rank = int(node_id) - except KeyError: - log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.") - self.node_rank = 0 - - # show progressbar only on progress_rank 0 - if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None: - self.progress_bar_callback.disable() - - # determine which process we are and world size - if self.use_ddp: - self.proc_rank = self.node_rank * self.num_processes + process_idx - self.world_size = self.num_nodes * self.num_processes - - elif self.use_ddp2: - self.proc_rank = self.node_rank - self.world_size = self.num_nodes - - # set warning rank - rank_zero_only.rank = self.proc_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self - model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - - # MODEL - # copy model to each gpu - if self.on_gpu: - self.root_gpu = process_idx - torch.cuda.set_device(self.root_gpu) - model.cuda(self.root_gpu) - - # set model properties before going into wrapper - self.copy_trainer_model_properties(model) - - # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove in v0.8.0 - if self.use_amp and not self.use_native_amp: - model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) - self.optimizers = optimizers - - # DDP2 uses all GPUs on the machine - if self.distributed_backend == 'ddp': - device_ids = [self.root_gpu] - elif self.use_ddp2: - device_ids = self.data_parallel_device_ids - else: # includes ddp_cpu - device_ids = None - - # allow user to configure ddp - model = model.configure_ddp(model, device_ids) - - # continue training routine - self.run_pretrain_routine(model) - - # when ddp ends, we save the model - self.save_spawn_weights(model) - - def save_spawn_weights(self, model): - """ - Dump a temporary checkpoint after ddp ends to get weights out of the process - :param model: - :return: - """ - if self.proc_rank == 0: - path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') - self.save_checkpoint(path) - - def load_spawn_weights(self, original_model): - """ - Load the temp weights saved in the process - To recover the trained model from the ddp process we load the saved weights - :param model: - :return: - """ - - loaded_model = original_model - - if self.proc_rank == 0: - # load weights saved in ddp - path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') - loaded_model = original_model.__class__.load_from_checkpoint(path) - - # copy loaded weights to old model - original_model.load_state_dict(loaded_model.state_dict()) - - # remove ddp weights - os.remove(path) - - return loaded_model - - def resolve_root_node_address(self, root_node): - if '[' in root_node: - name = root_node.split('[')[0] - number = root_node.split(',')[0] - if '-' in number: - number = number.split('-')[0] - - number = re.sub('[^0-9]', '', number) - root_node = name + number - - return root_node - - def _set_horovod_backend(self): - self.check_horovod() - self.use_horovod = True - - # Initialize Horovod to get rank / size info - hvd.init() - if self.on_gpu: - # Horovod assigns one local GPU per process - self.root_gpu = hvd.local_rank() - - def check_horovod(self): - """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod.""" - if not HOROVOD_AVAILABLE: - raise MisconfigurationException( - 'Requested `distributed_backend="horovod"`, but Horovod is not installed.' - 'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]' - ) - - if self.num_gpus > 1 or self.num_nodes > 1: - raise MisconfigurationException( - 'Horovod does not support setting num_nodes / num_gpus explicitly. Use ' - 'horovodrun / mpirun to configure the number of processes.' - ) - - @staticmethod - def has_horovodrun(): - """Returns True if running with `horovodrun` using Gloo or OpenMPI.""" - return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py deleted file mode 100644 index bcd0c0724ee7c2..00000000000000 --- a/pytorch_lightning/trainer/distrib_parts.py +++ /dev/null @@ -1,778 +0,0 @@ -""" -Lightning makes multi-gpu training and 16 bit training trivial. - -.. note:: None of the flags below require changing anything about your lightningModel definition. - -Choosing a backend -================== - -Lightning supports two backends. DataParallel and DistributedDataParallel. - Both can be used for single-node multi-GPU training. - For multi-node training you must use DistributedDataParallel. - -DataParallel (dp) ------------------ - -Splits a batch across multiple GPUs on the same node. Cannot be used for multi-node training. - -DistributedDataParallel (ddp) ------------------------------ - -Trains a copy of the model on each GPU and only syncs gradients. If used with DistributedSampler, each GPU trains -on a subset of the full dataset. - -DistributedDataParallel-2 (ddp2) --------------------------------- - -Works like DDP, except each node trains a single copy of the model using ALL GPUs on that node. - Very useful when dealing with negative samples, etc... - -You can toggle between each mode by setting this flag. - -.. code-block:: python - - # DEFAULT (when using single GPU or no GPUs) - trainer = Trainer(distributed_backend=None) - - # Change to DataParallel (gpus > 1) - trainer = Trainer(distributed_backend='dp') - - # change to distributed data parallel (gpus > 1) - trainer = Trainer(distributed_backend='ddp') - - # change to distributed data parallel (gpus > 1) - trainer = Trainer(distributed_backend='ddp2') - -If you request multiple nodes, the back-end will auto-switch to ddp. - We recommend you use DistributedDataparallel even for single-node multi-GPU training. - It is MUCH faster than DP but *may* have configuration issues depending on your cluster. - -For a deeper understanding of what lightning is doing, feel free to read this - `guide `_. - -Distributed and 16-bit precision --------------------------------- - -Due to an issue with apex and DistributedDataParallel (PyTorch and NVIDIA issue), Lightning does - not allow 16-bit and DP training. We tried to get this to work, but it's an issue on their end. - -Below are the possible configurations we support. - -+-------+---------+----+-----+---------+------------------------------------------------------------+ -| 1 GPU | 1+ GPUs | DP | DDP | 16-bit | command | -+=======+=========+====+=====+=========+============================================================+ -| Y | | | | | `Trainer(gpus=1)` | -+-------+---------+----+-----+---------+------------------------------------------------------------+ -| Y | | | | Y | `Trainer(gpus=1, use_amp=True)` | -+-------+---------+----+-----+---------+------------------------------------------------------------+ -| | Y | Y | | | `Trainer(gpus=k, distributed_backend='dp')` | -+-------+---------+----+-----+---------+------------------------------------------------------------+ -| | Y | | Y | | `Trainer(gpus=k, distributed_backend='ddp')` | -+-------+---------+----+-----+---------+------------------------------------------------------------+ -| | Y | | Y | Y | `Trainer(gpus=k, distributed_backend='ddp', use_amp=True)` | -+-------+---------+----+-----+---------+------------------------------------------------------------+ - -You also have the option of specifying which GPUs to use by passing a list: - -.. code-block:: python - - # DEFAULT (int) specifies how many GPUs to use. - Trainer(gpus=k) - - # Above is equivalent to - Trainer(gpus=list(range(k))) - - # You specify which GPUs (don't use if running on cluster) - Trainer(gpus=[0, 1]) - - # can also be a string - Trainer(gpus='0, 1') - - # can also be -1 or '-1', this uses all available GPUs - # this is equivalent to list(range(torch.cuda.available_devices())) - Trainer(gpus=-1) - - -CUDA flags ----------- - -CUDA flags make certain GPUs visible to your script. - Lightning sets these for you automatically, there's NO NEED to do this yourself. - -.. code-block:: python - - # lightning will set according to what you give the trainer - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - -However, when using a cluster, Lightning will NOT set these flags (and you should not either). - SLURM will set these for you. - -16-bit mixed precision ----------------------- - -16 bit precision can cut your memory footprint by half. If using volta architecture GPUs - it can give a dramatic training speed-up as well. - First, install apex (if install fails, look `here `__):: - - $ git clone https://github.com/NVIDIA/apex - $ cd apex - - # ------------------------ - # OPTIONAL: on your cluster you might need to load cuda 10 or 9 - # depending on how you installed PyTorch - - # see available modules - module avail - - # load correct cuda before install - module load cuda-10.0 - # ------------------------ - - # make sure you've loaded a cuda version > 4.0 and < 7.0 - module load gcc-6.1.0 - - $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ - - -then set this use_amp to True.:: - - # DEFAULT - trainer = Trainer(amp_level='O2', use_amp=False) - - -Single-gpu ----------- - -Make sure you're on a GPU machine.:: - - # DEFAULT - trainer = Trainer(gpus=1) - -Multi-gpu ---------- - -Make sure you're on a GPU machine. You can set as many GPUs as you want. - In this setting, the model will run on all 8 GPUs at once using DataParallel under the hood. - -.. code-block:: python - - # to use DataParallel - trainer = Trainer(gpus=8, distributed_backend='dp') - - # RECOMMENDED use DistributedDataParallel - trainer = Trainer(gpus=8, distributed_backend='ddp') - -Custom device selection ------------------------ - -The number of GPUs can also be selected with a list of indices or a string containing -a comma separated list of GPU ids. -The table below lists examples of possible input formats and how they are interpreted by Lightning. -Note in particular the difference between `gpus=0`, `gpus=[0]` and `gpus="0"`. - -+---------------+-----------+---------------------+---------------------------------+ -| `gpus` | Type | Parsed | Meaning | -+===============+===========+=====================+=================================+ -| None | NoneType | None | CPU | -+---------------+-----------+---------------------+---------------------------------+ -| 0 | int | None | CPU | -+---------------+-----------+---------------------+---------------------------------+ -| 3 | int | [0, 1, 2] | first 3 GPUs | -+---------------+-----------+---------------------+---------------------------------+ -| -1 | int | [0, 1, 2, ...] | all available GPUs | -+---------------+-----------+---------------------+---------------------------------+ -| [0] | list | [0] | GPU 0 | -+---------------+-----------+---------------------+---------------------------------+ -| [1, 3] | list | [1, 3] | GPUs 1 and 3 | -+---------------+-----------+---------------------+---------------------------------+ -| "0" | str | [0] | GPU 0 | -+---------------+-----------+---------------------+---------------------------------+ -| "3" | str | [3] | GPU 3 | -+---------------+-----------+---------------------+---------------------------------+ -| "1, 3" | str | [1, 3] | GPUs 1 and 3 | -+---------------+-----------+---------------------+---------------------------------+ -| "-1" | str | [0, 1, 2, ...] | all available GPUs | -+---------------+-----------+---------------------+---------------------------------+ - - -Multi-node ----------- - -Multi-node training is easily done by specifying these flags. - -.. code-block:: python - - # train on 12*8 GPUs - trainer = Trainer(gpus=8, num_nodes=12, distributed_backend='ddp') - - -You must configure your job submission script correctly for the trainer to work. - Here is an example script for the above trainer configuration. - -.. code-block:: bash - - #!/bin/bash -l - - # SLURM SUBMIT SCRIPT - #SBATCH --nodes=12 - #SBATCH --gres=gpu:8 - #SBATCH --ntasks-per-node=8 - #SBATCH --mem=0 - #SBATCH --time=0-02:00:00 - - # activate conda env - conda activate my_env - - # ------------------------- - # OPTIONAL - # ------------------------- - # debugging flags (optional) - # export NCCL_DEBUG=INFO - # export PYTHONFAULTHANDLER=1 - - # PyTorch comes with prebuilt NCCL support... but if you have issues with it - # you might need to load the latest version from your modules - # module load NCCL/2.4.7-1-cuda.10.0 - - # on your cluster you might need these: - # set the network interface - # export NCCL_SOCKET_IFNAME=^docker0,lo - # ------------------------- - - # random port between 12k and 20k - export MASTER_PORT=$((12000 + RANDOM % 20000)) - - # run script from above - python my_main_file.py - -.. note:: When running in DDP mode, any errors in your code will show up as an NCCL issue. - Set the `NCCL_DEBUG=INFO` flag to see the ACTUAL error. - -Normally now you would need to add a distributed sampler to your dataset, however -Lightning automates this for you. But if you still need to set a sampler Lightning will -not interfere nor automate it. - -Here's an example of how to add your own sampler (again no need with Lightning). - -.. code-block:: python - - # ie: this: - dataset = myDataset() - dataloader = Dataloader(dataset) - - # becomes: - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - -Auto-slurm-job-submission -------------------------- - -Instead of manually building SLURM scripts, you can use the -`SlurmCluster object `_ -to do this for you. The SlurmCluster can also run a grid search if you pass -in a `HyperOptArgumentParser -`_. - -Here is an example where you run a grid search of 9 combinations of hyperparams. -The full examples are -`here `__. - -.. code-block:: python - - # grid search 3 values of learning rate and 3 values of number of layers for your net - # this generates 9 experiments (lr=1e-3, layers=16), (lr=1e-3, layers=32), - # (lr=1e-3, layers=64), ... (lr=1e-1, layers=64) - parser = HyperOptArgumentParser(strategy='grid_search', add_help=False) - parser.opt_list('--learning_rate', default=0.001, type=float, - options=[1e-3, 1e-2, 1e-1], tunable=True) - parser.opt_list('--layers', default=1, type=float, options=[16, 32, 64], tunable=True) - hyperparams = parser.parse_args() - - # Slurm cluster submits 9 jobs, each with a set of hyperparams - cluster = SlurmCluster( - hyperparam_optimizer=hyperparams, - log_path='/some/path/to/save', - ) - - # OPTIONAL FLAGS WHICH MAY BE CLUSTER DEPENDENT - # which interface your nodes use for communication - cluster.add_command('export NCCL_SOCKET_IFNAME=^docker0,lo') - - # see output of the NCCL connection process - # NCCL is how the nodes talk to each other - cluster.add_command('export NCCL_DEBUG=INFO') - - # setting a master port here is a good idea. - cluster.add_command('export MASTER_PORT=%r' % PORT) - - # ************** DON'T FORGET THIS *************** - # MUST load the latest NCCL version - cluster.load_modules(['NCCL/2.4.7-1-cuda.10.0']) - - # configure cluster - cluster.per_experiment_nb_nodes = 12 - cluster.per_experiment_nb_gpus = 8 - - cluster.add_slurm_cmd(cmd='ntasks-per-node', value=8, comment='1 task per gpu') - - # submit a script with 9 combinations of hyper params - # (lr=1e-3, layers=16), (lr=1e-3, layers=32), (lr=1e-3, layers=64), ... (lr=1e-1, layers=64) - cluster.optimize_parallel_cluster_gpu( - main, - nb_trials=9, # how many permutations of the grid search to run - job_name='name_for_squeue' - ) - - -The other option is that you generate scripts on your own via a bash command or use another library... - -Self-balancing architecture ---------------------------- - -Here lightning distributes parts of your module across available GPUs to optimize for speed and memory. - -""" - -from contextlib import ExitStack -import os -from abc import ABC, abstractmethod -import time -import random -import torch -from typing import Union - -from pytorch_lightning import _logger as log -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.overrides.data_parallel import ( - LightningDistributedDataParallel, - LightningDataParallel, -) -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.distributed import rank_zero_only - -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True - -try: - import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - - -class TrainerDPMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - on_gpu: bool - use_dp: bool - use_ddp2: bool - use_ddp: bool - testing: bool - single_gpu: bool - root_gpu: ... - amp_level: str - precision: ... - current_tpu_idx: ... - proc_rank: int - tpu_local_core_rank: int - tpu_global_core_rank: int - use_tpu: bool - use_native_amp: bool - data_parallel_device_ids: ... - logger: Union[LightningLoggerBase, bool] - progress_bar_callback: ... - - @property - @abstractmethod - def use_amp(self) -> bool: - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def run_pretrain_routine(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def init_optimizers(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def copy_trainer_model_properties(self, model): - if isinstance(model, LightningDataParallel): - ref_model = model.module - elif isinstance(model, LightningDistributedDataParallel): - ref_model = model.module - else: - ref_model = model - - for m in [model, ref_model]: - m.trainer = self - m.on_gpu = self.on_gpu - m.use_dp = self.use_dp - m.use_ddp2 = self.use_ddp2 - m.use_ddp = self.use_ddp - m.use_amp = self.use_amp - m.testing = self.testing - m.single_gpu = self.single_gpu - m.use_tpu = self.use_tpu - m.tpu_local_core_rank = self.tpu_local_core_rank - m.tpu_global_core_rank = self.tpu_global_core_rank - - def transfer_batch_to_tpu(self, batch): - return self.__transfer_data_to_device(batch, device='tpu') - - def transfer_batch_to_gpu(self, batch, gpu_id): - return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id) - - def __transfer_data_to_device(self, batch, device, gpu_id=None): - if device == 'tpu' and XLA_AVAILABLE: - # base case: object can be directly moved using `to` - if callable(getattr(batch, 'to', None)): - return batch.to(xm.xla_device()) - - if device == 'gpu': - # base case: object can be directly moved using `cuda` or `to` - if callable(getattr(batch, 'cuda', None)): - return batch.cuda(gpu_id) - - if callable(getattr(batch, 'to', None)): - return batch.to(torch.device('cuda', gpu_id)) - - # when list - if isinstance(batch, list): - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) - return batch - - # when tuple - if isinstance(batch, tuple): - # when namedtuple - if hasattr(batch, '_fields'): - elem_type = type(batch) - return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch)) - else: - batch = list(batch) - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) - return tuple(batch) - - # when dict - if isinstance(batch, dict): - for k, v in batch.items(): - batch[k] = self.__transfer_data_to_device(v, device, gpu_id) - - return batch - - # nothing matches, return the value as is without transform - return batch - - def single_gpu_train(self, model): - model.cuda(self.root_gpu) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - - # TODO: update for 0.8.0 - if self.use_amp and not self.use_native_amp: - # An example - model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) - self.optimizers = optimizers - - self.run_pretrain_routine(model) - - def tpu_train(self, tpu_core_idx, model): - # put model on tpu - model.to(xm.xla_device()) - - # get the appropriate tpu ranks - self.tpu_local_core_rank = xm.get_local_ordinal() - self.tpu_global_core_rank = xm.get_ordinal() - - # avoid duplicating progress bar - if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None: - self.progress_bar_callback.disable() - - # track current tpu - self.current_tpu_idx = tpu_core_idx - self.proc_rank = self.tpu_local_core_rank - rank_zero_only.rank = self.proc_rank - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - - # init 16 bit for TPU - if self.precision == 16: - os.environ['XLA_USE_BF16'] = str(1) - - log.info(f'INIT TPU local core: {self.tpu_local_core_rank},' - f' global rank: {self.tpu_global_core_rank}') - - # continue training routine - self.run_pretrain_routine(model) - - self.save_spawn_weights(model) - - def dp_train(self, model): - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - - model.cuda(self.root_gpu) - - # hack forward to do autocast for the user - model_autocast_original_forward = model.forward - if self.use_amp and self.use_native_amp: - # wrap the user's forward in autocast and give it back at the end - model.forward = torch.cuda.amp.autocast()(model.forward) - - # TODO: remove in v0.8.0 - # check for this bug (amp + dp + !01 doesn't work) - # https://github.com/NVIDIA/apex/issues/227 - if self.use_dp and self.use_amp and not self.use_native_amp: - if self.amp_level == 'O2': - raise MisconfigurationException( - f'Amp level {self.amp_level} with DataParallel is not supported.' - f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.' - f' We recommend you switch to ddp if you want to use amp') - else: - model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) - - # create list of device ids - device_ids = self.data_parallel_device_ids - if isinstance(device_ids, int): - device_ids = list(range(device_ids)) - - # set dp device - torch.cuda.set_device(self.root_gpu) - - model = LightningDataParallel(model, device_ids=device_ids) - - self.run_pretrain_routine(model) - - model.forward = model_autocast_original_forward - - def horovod_train(self, model): - if torch.cuda.is_available() and self.on_gpu: - # Horovod: pin GPU to local rank - assert self.root_gpu == hvd.local_rank() - torch.cuda.set_device(self.root_gpu) - model.cuda(self.root_gpu) - - # avoid duplicating progress bar - if hvd.rank() != 0 and self.progress_bar_callback is not None: - self.progress_bar_callback.disable() - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - - # Horovod: scale the learning rate by the number of workers to account for - # increased total batch size - for optimizer in self.optimizers: - for param_group in optimizer.param_groups: - param_group['lr'] *= hvd.size() - - if self.use_amp: - # An example - model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) - self.optimizers = optimizers - - # Horovod: broadcast parameters & optimizer state to ensure consistent initialization - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - for optimizer in self.optimizers: - hvd.broadcast_optimizer_state(optimizer, root_rank=0) - - def filter_named_parameters(model, optimizer): - opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])]) - return [(name, p) for name, p in model.named_parameters() if p in opt_params] - - # Horovod: wrap optimizers to perform gradient aggregation via allreduce - self.optimizers = [ - hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer)) - for optimizer in self.optimizers - ] - - # Update logger rank info from Horovod to avoid race conditions from different ranks - # creating directories / writing files in the same locations. - self.proc_rank = hvd.rank() - rank_zero_only.rank = self.proc_rank - - with ExitStack() as stack: - for optimizer in self.optimizers: - # Synchronization will be performed explicitly following backward() - stack.enter_context(optimizer.skip_synchronize()) - - self.run_pretrain_routine(model) - - -def normalize_parse_gpu_string_input(s): - if isinstance(s, str): - if s == '-1': - return -1 - else: - return [int(x.strip()) for x in s.split(',') if len(x) > 0] - else: - return s - - -def get_all_available_gpus(): - """ - :return: a list of all available gpus - """ - return list(range(torch.cuda.device_count())) - - -def check_gpus_data_type(gpus): - """ - :param gpus: gpus parameter as passed to the Trainer - Function checks that it is one of: None, Int, String or List - Throws otherwise - :return: return unmodified gpus variable - """ - - if gpus is not None and (not isinstance(gpus, (int, str, list)) or isinstance(gpus, bool)): - raise MisconfigurationException("GPUs must be int, string or list of ints or None.") - - -def normalize_parse_gpu_input_to_list(gpus): - assert gpus is not None - if isinstance(gpus, list): - return gpus - - # must be an int - if not gpus: # gpus==0 - return None - if gpus == -1: - return get_all_available_gpus() - - return list(range(gpus)) - - -def sanitize_gpu_ids(gpus): - """ - :param gpus: list of ints corresponding to GPU indices - Checks that each of the GPUs in the list is actually available. - Throws if any of the GPUs is not available. - :return: unmodified gpus variable - """ - all_available_gpus = get_all_available_gpus() - for gpu in gpus: - if gpu not in all_available_gpus: - raise MisconfigurationException(f""" - You requested GPUs: {gpus} - But your machine only has: {all_available_gpus} - """) - return gpus - - -def parse_gpu_ids(gpus): - """ - :param gpus: Int, string or list - An int -1 or string '-1' indicate that all available GPUs should be used. - A list of ints or a string containing list of comma separated integers - indicates specific GPUs to use - An int 0 means that no GPUs should be used - Any int N > 0 indicates that GPUs [0..N) should be used. - :return: List of gpus to be used - - If no GPUs are available but the value of gpus variable indicates request for GPUs - then a misconfiguration exception is raised. - """ - - # nothing was passed into the GPUs argument - if callable(gpus): - return None - - # Check that gpus param is None, Int, String or List - check_gpus_data_type(gpus) - - # Handle the case when no gpus are requested - if gpus is None or isinstance(gpus, int) and gpus == 0: - return None - - # We know user requested GPUs therefore if some of the - # requested GPUs are not available an exception is thrown. - - gpus = normalize_parse_gpu_string_input(gpus) - gpus = normalize_parse_gpu_input_to_list(gpus) - gpus = sanitize_gpu_ids(gpus) - - if not gpus: - raise MisconfigurationException("GPUs requested but none are available.") - return gpus - - -def determine_root_gpu_device(gpus): - """ - :param gpus: non empty list of ints representing which gpus to use - :return: designated root GPU device - """ - if gpus is None: - return None - - assert isinstance(gpus, list), "gpus should be a list" - assert len(gpus) > 0, "gpus should be a non empty list" - - # set root gpu - root_gpu = gpus[0] - - return root_gpu - - -def retry_jittered_backoff(f, num_retries=5): - # Based on: - # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ - cap = 1.0 # max sleep time is 1s - base = 0.01 # initial sleep time is 10ms - sleep = base # initial sleep time is 10ms - - for i in range(num_retries): - try: - return f() - except RuntimeError as e: - if i == num_retries - 1: - raise e - else: - continue - time.sleep(sleep) - sleep = min(cap, random.uniform(base, sleep * 3)) - - -def pick_single_gpu(exclude_gpus=[]): - for i in range(torch.cuda.device_count()): - if i in exclude_gpus: - continue - # Try to allocate on device: - device = torch.device(f"cuda:{i}") - try: - torch.ones(1).to(device) - except RuntimeError: - continue - return i - raise RuntimeError("No GPUs available.") - - -def pick_multiple_gpus(n): - picked = [] - for _ in range(n): - picked.append(pick_single_gpu(exclude_gpus=picked)) - - return picked diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 0320bf35419eab..8b7543e6bf50f6 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -1,449 +1,371 @@ -""" -Validation loop -=============== +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -The lightning validation loop handles everything except the actual computations of your model. -To decide what will happen in your validation loop, define the `validation_step` function. -Below are all the things lightning automates for you in the validation loop. - -.. note:: Lightning will run 5 steps of validation in the beginning of training as a sanity - check so you don't have to wait until a full epoch to catch possible validation issues. - -Check validation every n epochs -------------------------------- - -If you have a small dataset you might want to check validation every n epochs - -.. code-block:: python - - # DEFAULT - trainer = Trainer(check_val_every_n_epoch=1) - -Set how much of the validation set to check -------------------------------------------- - -If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag +import torch -val_percent_check will be overwritten by overfit_pct if `overfit_pct > 0` +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.supporters import PredictionCollection +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import WarningCache -.. code-block:: python - # DEFAULT - trainer = Trainer(val_percent_check=1.0) +class EvaluationLoop(object): - # check 10% only - trainer = Trainer(val_percent_check=0.1) + def __init__(self, trainer): + self.trainer = trainer + self.outputs = [] + self.step_metrics = [] + self.predictions = None + self.max_batches = None + self.warning_cache = WarningCache() + self.num_dataloaders = None -Set how much of the test set to check -------------------------------------- + def on_trainer_init(self): + self.trainer.num_sanity_val_batches = [] + self.trainer.num_test_batches = [] + self.trainer.num_val_batches = [] + self.trainer.test_dataloaders = None + self.trainer.val_dataloaders = None -If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag + # .validate() and .test() set this when they load a checkpoint + self.trainer.validated_ckpt_path = None + self.trainer.tested_ckpt_path = None -test_percent_check will be overwritten by overfit_pct if `overfit_pct > 0` + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True -.. code-block:: python + def get_evaluation_dataloaders(self): + model = self.trainer.lightning_module - # DEFAULT - trainer = Trainer(test_percent_check=1.0) + # select dataloaders + if self.trainer.testing: + self.trainer.reset_test_dataloader(model) - # check 10% only - trainer = Trainer(test_percent_check=0.1) + dataloaders = self.trainer.test_dataloaders + max_batches = self.trainer.num_test_batches + else: + # val + if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_val_dataloader(model) + if self.trainer.sanity_checking: + self.trainer.num_sanity_val_batches = [ + min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches + ] + max_batches = self.trainer.num_sanity_val_batches + else: + max_batches = self.trainer.num_val_batches + dataloaders = self.trainer.val_dataloaders + return dataloaders, max_batches + + def should_skip_evaluation(self, max_batches): + return sum(max_batches) == 0 + + def on_evaluation_start(self, *args, **kwargs): + if self.trainer.testing: + self.trainer.call_hook('on_test_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_start', *args, **kwargs) -Set validation check frequency within 1 training epoch ------------------------------------------------------- + def on_evaluation_model_eval(self, *_, **__): + model_ref = self.trainer.lightning_module + if self.trainer.testing: + model_ref.on_test_model_eval() + else: + model_ref.on_validation_model_eval() -For large datasets it's often desirable to check validation multiple times within a training loop. - Pass in a float to check that often within 1 training epoch. - Pass in an int k to check every k training batches. Must use an int if using an IterableDataset. + def on_evaluation_model_train(self, *_, **__): + model_ref = self.trainer.lightning_module + if self.trainer.testing: + model_ref.on_test_model_train() + else: + model_ref.on_validation_model_train() -.. code-block:: python + def on_evaluation_end(self, *args, **kwargs): + if self.trainer.testing: + self.trainer.call_hook('on_test_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_end', *args, **kwargs) - # DEFAULT - trainer = Trainer(val_check_interval=0.95) + if self.trainer.state != TrainerState.FITTING: + # summarize profile results + self.trainer.profiler.describe() - # check every .25 of an epoch - trainer = Trainer(val_check_interval=0.25) + def reload_evaluation_dataloaders(self): + model = self.trainer.lightning_module + if self.trainer.testing: + self.trainer.reset_test_dataloader(model) + else: + self.trainer.reset_val_dataloader(model) - # check every 100 train batches (ie: for IterableDatasets or fixed frequency) - trainer = Trainer(val_check_interval=100) + def setup(self, model, max_batches, dataloaders): + # bookkeeping + self.outputs = [] + self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) -Set the number of validation sanity steps ------------------------------------------ + self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) + self._predictions = [[] for _ in range(self.num_dataloaders)] -Lightning runs a few steps of validation in the beginning of training. - This avoids crashing in the validation loop sometime deep into a lengthy training loop. + def on_evaluation_epoch_start(self, *args, **kwargs): + self.trainer.call_hook('on_epoch_start', *args, **kwargs) -.. code-block:: python + if self.trainer.testing: + self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - # DEFAULT - trainer = Trainer(num_sanity_val_steps=5) + def _build_args(self, batch, batch_idx, dataloader_idx): + # make dataloader_idx arg in validation_step optional + args = [batch, batch_idx] + multiple_val_loaders = ( + not self.trainer.testing and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1 + ) + multiple_test_loaders = (self.trainer.testing and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1) -You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check. + if multiple_test_loaders or multiple_val_loaders: + args.append(dataloader_idx) -# Testing loop + return args -To ensure you don't accidentally use test data to guide training decisions Lightning - makes running the test set deliberate. + def _get_num_dataloaders(self, dataloaders): + # case where user does: + # return dl1, dl2 + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length -**test** + def evaluation_step(self, batch, batch_idx, dataloader_idx): + # configure args + args = self._build_args(batch, batch_idx, dataloader_idx) -You have two options to run the test set. -First case is where you test right after a full training routine. + model_ref = self.trainer.lightning_module + model_ref._results = Result() -.. code-block:: python + if self.trainer.testing: + model_ref._current_fx_name = "test_step" + with self.trainer.profiler.profile("test_step"): + output = self.trainer.accelerator.test_step(args) + else: + model_ref._current_fx_name = "validation_step" + with self.trainer.profiler.profile("validation_step"): + output = self.trainer.accelerator.validation_step(args) - # run full training - trainer.fit(model) + # capture any logged information + self.trainer.logger_connector.cache_logged_metrics() + # track batch size for weighted average + is_result_obj = isinstance(output, Result) + if is_result_obj: + output.track_batch_size(batch) - # run test set - trainer.test() + return output + def evaluation_step_end(self, *args, **kwargs): + if self.trainer.testing: + output = self.trainer.call_hook('test_step_end', *args, **kwargs) + else: + output = self.trainer.call_hook('validation_step_end', *args, **kwargs) + return output -Second case is where you load a model and run the test set + def evaluation_epoch_end(self): + # unset dataloder_idx in model + self.trainer.logger_connector.evaluation_epoch_end() -.. code-block:: python + # call the model epoch end + deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders) - model = MyLightningModule.load_from_metrics( - weights_path='/path/to/pytorch_checkpoint.ckpt', - tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv', - on_gpu=True, - map_location=None - ) + # enable returning anything + for i, r in enumerate(deprecated_results): + if not isinstance(r, (dict, Result, torch.Tensor)): + deprecated_results[i] = [] - # init trainer with whatever options - trainer = Trainer(...) + return deprecated_results - # test (pass in the model) - trainer.test(model) + def log_epoch_metrics_on_evaluation_end(self): + # get the final loop results + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() + return eval_loop_results -In this second case, the options you pass to trainer will be used when running - the test set (ie: 16-bit, dp, ddp, etc...) + def __run_eval_epoch_end(self, num_dataloaders): + model = self.trainer.lightning_module -""" + # with a single dataloader don't pass an array + outputs = self.outputs -from abc import ABC, abstractmethod -from pprint import pprint -from typing import Callable + eval_results = outputs + if num_dataloaders == 1: + eval_results = outputs[0] -import torch -from torch.utils.data import DataLoader - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities import rank_zero_warn - -try: - import torch_xla.distributed.parallel_loader as xla_pl - import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - - -class TrainerEvaluationLoopMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - on_gpu: bool - use_ddp: bool - use_dp: bool - use_ddp2: bool - use_horovod: bool - single_gpu: bool - data_parallel_device_ids: ... - model: LightningModule - num_test_batches: int - num_val_batches: int - fast_dev_run: ... - process_output: ... - progress_bar_dict: ... - proc_rank: int - current_epoch: int - callback_metrics: ... - test_dataloaders: DataLoader - val_dataloaders: DataLoader - use_tpu: bool - reload_dataloaders_every_epoch: ... - - # Callback system - on_validation_batch_start: Callable - on_validation_batch_end: Callable - on_test_batch_start: Callable - on_test_batch_end: Callable - on_validation_start: Callable - on_validation_end: Callable - on_test_start: Callable - on_test_end: Callable - - @abstractmethod - def copy_trainer_model_properties(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def get_model(self): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def is_overriden(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def transfer_batch_to_tpu(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def transfer_batch_to_gpu(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def add_progress_bar_metrics(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def log_metrics(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def reset_test_dataloader(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def reset_val_dataloader(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False): - """Run evaluation code. - - Args: - model: PT model - dataloaders: list of PT dataloaders - max_batches: Scalar - test_mode: - """ - # enable eval mode - model.zero_grad() - model.eval() - - # copy properties for forward overrides - self.copy_trainer_model_properties(model) - - # disable gradients to save memory - torch.set_grad_enabled(False) + user_reduced = False - # bookkeeping - outputs = [] + if self.trainer.testing: + if is_overridden('test_epoch_end', model=model): + model._current_fx_name = 'test_epoch_end' + eval_results = model.test_epoch_end(eval_results) + user_reduced = True - # run validation - for dataloader_idx, dataloader in enumerate(dataloaders): - dl_outputs = [] + else: + if is_overridden('validation_epoch_end', model=model): + model._current_fx_name = 'validation_epoch_end' + eval_results = model.validation_epoch_end(eval_results) + user_reduced = True + + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + # depre warning + if eval_results is not None and user_reduced: + step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end' + self.warning_cache.warn( + f'The {step} should not return anything as of 9.1.' + ' To log, use self.log(...) or self.write(...) directly in the LightningModule' + ) + + if not isinstance(eval_results, list): + eval_results = [eval_results] + + self.trainer.logger_connector._track_callback_metrics(eval_results) - # on TPU we have to wrap it under the ParallelLoader - if self.use_tpu: - device = xm.xla_device() - dataloader = xla_pl.ParallelLoader(dataloader, [device]) - dataloader = dataloader.per_device_loader(device) + return eval_results - for batch_idx, batch in enumerate(dataloader): - if batch is None: - continue + def __gather_epoch_end_eval_results(self, outputs): + eval_results = [] + for epoch_output in outputs: + result = epoch_output[0].__class__.gather(epoch_output) + eval_results.append(result) - # stop short when on fast_dev_run (sets max_batch=1) - if batch_idx >= max_batches: - break + # with 1 dataloader don't pass in a list + if len(eval_results) == 1: + eval_results = eval_results[0] + return eval_results - # callbacks - if test_mode: - self.on_test_batch_start() - else: - self.on_validation_batch_start() - - # ----------------- - # RUN EVALUATION STEP - # ----------------- - if self.use_amp and self.use_native_amp: - with torch.cuda.amp.autocast(): - output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) - else: - output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) - - # on dp / ddp2 might still want to do something with the batch parts - if test_mode: - if self.is_overriden('test_step_end'): - model_ref = self.get_model() - with self.profiler.profile('test_step_end'): - output = model_ref.test_step_end(output) - self.on_test_batch_end() - else: - if self.is_overriden('validation_step_end'): - model_ref = self.get_model() - with self.profiler.profile('validation_step_end'): - output = model_ref.validation_step_end(output) - self.on_validation_batch_end() + def __auto_reduce_result_objs(self, outputs): + # outputs has a list of results per dataloader + eval_results = [] + for dl_output in outputs: + result = dl_output[0] + result = result.__class__.reduce_on_epoch_end(dl_output) + eval_results.append(result) - # track outputs for collation - dl_outputs.append(output) + return eval_results - outputs.append(dl_outputs) + def on_predict_epoch_end(self): + self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.lightning_module) - eval_results = {} + results = self._predictions - # with a single dataloader don't pass an array - if len(dataloaders) == 1: - outputs = outputs[0] + def _convert_to_numpy(v): + return v.cpu().numpy() - # give model a chance to do something with the outputs (and method defined) - if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)): - model = model.module + results = apply_to_collection(results, torch.Tensor, _convert_to_numpy) - if test_mode: - if self.is_overriden('test_end', model=model): - # TODO: remove in v1.0.0 - eval_results = model.test_end(outputs) - rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed v1.0.' - ' Use `test_epoch_end` instead.', DeprecationWarning) + return results, None - elif self.is_overriden('test_epoch_end', model=model): - eval_results = model.test_epoch_end(outputs) + def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): + # set dataloader_idx to model and track batch_size + self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders) + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) else: - if self.is_overriden('validation_end', model=model): - # TODO: remove in v1.0.0 - eval_results = model.validation_end(outputs) - rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed v1.0.' - ' Use `validation_epoch_end` instead.', DeprecationWarning) + self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) - elif self.is_overriden('validation_epoch_end', model=model): - eval_results = model.validation_epoch_end(outputs) - - # enable train mode again - model.train() + def on_evaluation_batch_end(self, output, batch, batch_idx, dataloader_idx): + if self.trainer.testing: + self.trainer.call_hook('on_test_batch_end', output, batch, batch_idx, dataloader_idx) + else: + self.trainer.call_hook('on_validation_batch_end', output, batch, batch_idx, dataloader_idx) - # enable gradients to save memory - torch.set_grad_enabled(True) + # store predicitons if do_write_predictions and track eval loss history + self.store_predictions(output, batch_idx, dataloader_idx) - return eval_results + def store_predictions(self, output, batch_idx, dataloader_idx): + # Add step predictions to prediction collection to write later + if output is not None: + do_write_predictions = isinstance(output, Result) and self.trainer.testing + if do_write_predictions: + self.predictions.add(output.pop('predictions', None)) - def run_evaluation(self, test_mode: bool = False): - # when testing make sure user defined a test step - if test_mode and not self.is_overriden('test_step'): - raise MisconfigurationException( - "You called `.test()` without defining model's `.test_step()`." - " Please define and try again") + # track debug metrics + self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - # hook - model = self.get_model() - model.on_pre_performance_check() + def on_evaluation_epoch_end(self, *args, **kwargs): + # call the callback hook + self.call_on_evaluation_epoch_end_hook() - # select dataloaders - if test_mode: - if self.test_dataloaders is None: - self.reset_test_dataloader(model) + self.trainer.call_hook('on_epoch_end') - dataloaders = self.test_dataloaders - max_batches = self.num_test_batches - else: - # val - if self.val_dataloaders is None: - self.reset_val_dataloader(model) + def call_on_evaluation_epoch_end_hook(self): + outputs = self.outputs - dataloaders = self.val_dataloaders - max_batches = self.num_val_batches + # free memory + self.outputs = [] - # cap max batches to 1 when using fast_dev_run - if self.fast_dev_run: - max_batches = 1 + model_ref = self.trainer.lightning_module + hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" - # Validation/Test begin callbacks - if test_mode: - self.on_test_start() - else: - self.on_validation_start() + self.trainer._reset_result_and_set_hook_fx_name(hook_name) - # run evaluation - eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) - _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results) + with self.trainer.profiler.profile(hook_name): - # add metrics to prog bar - self.add_progress_bar_metrics(prog_bar_metrics) + if hasattr(self.trainer, hook_name): + on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name) + on_evaluation_epoch_end_hook(outputs) - # log results of test - if test_mode and self.proc_rank == 0: - print('-' * 80) - print('TEST RESULTS') - pprint(callback_metrics) - print('-' * 80) + if is_overridden(hook_name, model_ref): + model_hook_fx = getattr(model_ref, hook_name) + if is_param_in_hook_signature(model_hook_fx, "outputs"): + model_hook_fx(outputs) + else: + self.warning_cache.warn( + f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + model_hook_fx() - # log metrics - self.log_metrics(log_metrics, {}) + self.trainer._cache_logged_metrics() - # track metrics for callbacks - self.callback_metrics.update(callback_metrics) + def log_evaluation_step_metrics(self, output, batch_idx): + if self.trainer.sanity_checking: + return - # hook - model.on_post_performance_check() + step_log_metrics = {} + step_pbar_metrics = {} - # eventual dataset reloading - if test_mode: - if self.reload_dataloaders_every_epoch: - self.reset_test_dataloader(model) - else: - # val - if self.reload_dataloaders_every_epoch: - self.reset_val_dataloader(model) + self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx) - # Validation/Test end callbacks - if test_mode: - self.on_test_end() - else: - self.on_validation_end() + def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx): + cached_results = self.trainer.logger_connector.cached_results + cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector() - def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False): - # make dataloader_idx arg in validation_step optional - args = [batch, batch_idx] + step_log_metrics.update(cached_batch_log_metrics) + step_pbar_metrics.update(cached_batch_pbar_metrics) - if (test_mode and len(self.test_dataloaders) > 1) \ - or (not test_mode and len(self.val_dataloaders) > 1): - args.append(dataloader_idx) + if len(step_log_metrics) > 0: + # make the metrics appear as a different line in the same graph + metrics_by_epoch = {} + for k, v in step_log_metrics.items(): + metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v - # handle DP, DDP forward - if self.use_ddp or self.use_dp or self.use_ddp2: - output = model(*args) - return output - - # Horovod - if self.use_horovod and self.on_gpu: - batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) - args[0] = batch - - # single GPU data transfer - if self.single_gpu: - # for single GPU put inputs on gpu manually - root_gpu = 0 - if isinstance(self.data_parallel_device_ids, list): - root_gpu = self.data_parallel_device_ids[0] - batch = self.transfer_batch_to_gpu(batch, root_gpu) - args[0] = batch - - # TPU data transfer - if self.use_tpu: - batch = self.transfer_batch_to_tpu(batch) - args[0] = batch - - # CPU, TPU or gpu step - if test_mode: - output = model.test_step(*args) - else: - output = model.validation_step(*args) + self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx) - return output + if len(step_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics) diff --git a/pytorch_lightning/trainer/ignored_warnings.py b/pytorch_lightning/trainer/ignored_warnings.py index 9260720ec350fb..894416d607a3e0 100644 --- a/pytorch_lightning/trainer/ignored_warnings.py +++ b/pytorch_lightning/trainer/ignored_warnings.py @@ -1,14 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import warnings def ignore_scalar_return_in_dp(): # Users get confused by this warning so we silence it - m_1 = """ - Was asked to gather along dimension 0, but all - input tensors were scalars; will instead unsqueeze - and return a vector. - """ - warnings.filterwarnings('ignore', message=m_1) + warnings.filterwarnings( + 'ignore', + message='Was asked to gather along dimension 0, but all input tensors were scalars;' + ' will instead unsqueeze and return a vector.' + ) ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 978ac5df78d816..8aaac0a6591520 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -1,10 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect from abc import ABC -from typing import Union, Iterable +from collections import Mapping import torch -from pytorch_lightning.core import memory -from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection +from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -12,78 +28,19 @@ class TrainerLoggingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - current_epoch: int - on_gpu: bool - log_gpu_memory: ... - logger: Union[LightningLoggerBase, bool] - progress_bar_metrics: ... - global_step: int - proc_rank: int - use_dp: bool - use_ddp2: bool - default_root_dir: str - slurm_job_id: int + _distrib_type: DistributedType num_gpus: int - def configure_logger(self, logger): - if logger is True: - # default logger - self.logger = TensorBoardLogger( - save_dir=self.default_root_dir, - version=self.slurm_job_id, - name='lightning_logs' - ) - elif logger is False: - self.logger = None - else: - if isinstance(logger, Iterable): - self.logger = LoggerCollection(logger) - else: - self.logger = logger - - def log_metrics(self, metrics, grad_norm_dic, step=None): - """Logs the metric dict passed in. - If `step` parameter is None and `step` key is presented is metrics, - uses metrics["step"] as a step - - Args: - metrics (dict): Metric values - grad_norm_dic (dict): Gradient norms - step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` - """ - # add gpu memory - if self.on_gpu and self.log_gpu_memory: - mem_map = memory.get_memory_profile(self.log_gpu_memory) - metrics.update(mem_map) - - # add norms - metrics.update(grad_norm_dic) - - # turn all tensors to scalars - scalar_metrics = self.metrics_to_scalars(metrics) - - if "step" in scalar_metrics and step is None: - step = scalar_metrics.pop("step") - else: - # added metrics by Lightning for convenience - scalar_metrics['epoch'] = self.current_epoch - step = step if step is not None else self.global_step - # log actual metrics - if self.proc_rank == 0 and self.logger is not None: - self.logger.agg_and_log_metrics(scalar_metrics, step=step) - self.logger.save() - - def add_progress_bar_metrics(self, metrics): - for k, v in metrics.items(): - if isinstance(v, torch.Tensor): - v = v.item() - - self.progress_bar_metrics[k] = v - def metrics_to_scalars(self, metrics): new_metrics = {} + # TODO: this is duplicated in MetricsHolder. should be unified for k, v in metrics.items(): if isinstance(v, torch.Tensor): + if v.numel() != 1: + raise MisconfigurationException( + f"The metric `{k}` does not contain a single element" + f" thus it cannot be converted to float. Found `{v}`" + ) v = v.item() if isinstance(v, dict): @@ -93,23 +50,33 @@ def metrics_to_scalars(self, metrics): return new_metrics - def process_output(self, output, train=False): + def process_dict_result(self, output, train=False): """Reduces output according to the training mode. Separates loss from logging and progress bar metrics """ - # --------------- - # EXTRACT CALLBACK KEYS - # --------------- - # all keys not progress_bar or log are candidates for callbacks - callback_metrics = {} - for k, v in output.items(): - if k not in ['progress_bar', 'log', 'hiddens']: - callback_metrics[k] = v + # -------------------- + # WARN DEPRECATED KEYS + # -------------------- + # TODO: 1.0.0 remove + if isinstance(output, dict): + for k, v in output.items(): + if k in ['log', 'progress_bar']: + m = inspect.cleandoc( + f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n" + " Please use self.log(...) inside the lightningModule instead.\n" + " # log on a step or aggregate epoch metric to the logger and/or progress bar" + " (inside LightningModule)\n" + " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)" + ) + rank_zero_warn(m) - if train and (self.use_dp or self.use_ddp2): - num_gpus = self.num_gpus - callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) + # -------------------------- + # handle single scalar only + # -------------------------- + # single scalar returned from a xx_step + if isinstance(output, torch.Tensor): + return output, {}, {}, None # --------------- # EXTRACT PROGRESS BAR KEYS @@ -118,11 +85,12 @@ def process_output(self, output, train=False): progress_output = output['progress_bar'] # reduce progress metrics for progress bar when using dp - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus progress_output = self.reduce_distributed_output(progress_output, num_gpus) progress_bar_metrics = progress_output + # todo: specify the possible exception except Exception: progress_bar_metrics = {} @@ -134,11 +102,12 @@ def process_output(self, output, train=False): log_output = output['log'] # reduce progress metrics for progress bar when using dp - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus log_output = self.reduce_distributed_output(log_output, num_gpus) log_metrics = log_output + # todo: specify the possible exception except Exception: log_metrics = {} @@ -151,32 +120,32 @@ def process_output(self, output, train=False): if train: try: loss = output['loss'] - except Exception: + # todo: specify the possible exception + except Exception as exp: if isinstance(output, torch.Tensor): loss = output else: raise RuntimeError( 'No `loss` value in the dictionary returned from `model.training_step()`.' - ) + ) from exp # when using dp need to reduce the loss - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): loss = self.reduce_distributed_output(loss, self.num_gpus) # --------------- # EXTRACT HIDDEN # --------------- - hiddens = output.get('hiddens') - - # use every metric passed in as a candidate for callback - callback_metrics.update(progress_bar_metrics) - callback_metrics.update(log_metrics) + hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None + if hiddens is not None: + hiddens = hiddens.detach() # detach all metrics for callbacks to prevent memory leaks # no .item() because it will slow things down - callback_metrics = recursive_detach(callback_metrics) + progress_bar_metrics = recursive_detach(progress_bar_metrics) + log_metrics = recursive_detach(log_metrics) - return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens + return loss, progress_bar_metrics, log_metrics, hiddens def reduce_distributed_output(self, output, num_gpus): if num_gpus <= 1: @@ -192,6 +161,10 @@ def reduce_distributed_output(self, output, num_gpus): if isinstance(output[k], dict): output[k] = self.reduce_distributed_output(output[k], num_gpus) + # compute the average of scalars + elif isinstance(output[k], list): + output[k] = sum(output[k]) / len(output[k]) + # do nothing when there's a scalar elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0: pass diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py deleted file mode 100755 index 5e8c7a862a65e8..00000000000000 --- a/pytorch_lightning/trainer/lr_finder.py +++ /dev/null @@ -1,459 +0,0 @@ -""" -Trainer Learning Rate Finder -""" -from abc import ABC, abstractmethod -from typing import Optional - -import numpy as np -import torch -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -import os - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.callbacks import Callback -from pytorch_lightning import _logger as log -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class TrainerLRFinderMixin(ABC): - @abstractmethod - def save_checkpoint(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def restore(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def _run_lr_finder_internally(self, model: LightningModule): - """ Call lr finder internally during Trainer.fit() """ - lr_finder = self.lr_find(model) - lr = lr_finder.suggestion() - # TODO: log lr.results to self.logger - if isinstance(self.auto_lr_find, str): - # Try to find requested field, may be nested - if _nested_hasattr(model.hparams, self.auto_lr_find): - _nested_setattr(model.hparams, self.auto_lr_find, lr) - else: - raise MisconfigurationException( - f'`auto_lr_find` was set to {self.auto_lr_find}, however' - ' could not find this as a field in `model.hparams`.') - else: - if hasattr(model.hparams, 'lr'): - model.hparams.lr = lr - elif hasattr(model.hparams, 'learning_rate'): - model.hparams.learning_rate = lr - else: - raise MisconfigurationException( - 'When auto_lr_find is set to True, expects that hparams' - ' either has field `lr` or `learning_rate` that can overridden') - log.info(f'Learning rate set to {lr}') - - def lr_find(self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[DataLoader] = None, - min_lr: float = 1e-8, - max_lr: float = 1, - num_training: int = 100, - mode: str = 'exponential', - num_accumulation_steps: int = 1): - r""" - lr_find enables the user to do a range test of good initial learning rates, - to reduce the amount of guesswork in picking a good starting learning rate. - - Args: - model: Model to do range testing for - - train_dataloader: A PyTorch - DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. - - min_lr: minimum learning rate to investigate - - max_lr: maximum learning rate to investigate - - num_training: number of learning rates to test - - mode: search strategy, either 'linear' or 'exponential'. If set to - 'linear' the learning rate will be searched by linearly increasing - after each batch. If set to 'exponential', will increase learning - rate exponentially. - - num_accumulation_steps: number of batches to calculate loss over. - - Example:: - - # Setup model and trainer - model = MyModelClass(hparams) - trainer = pl.Trainer() - - # Run lr finder - lr_finder = trainer.lr_find(model, ...) - - # Inspect results - fig = lr_finder.plot(); fig.show() - suggested_lr = lr_finder.suggestion() - - # Overwrite lr and create new model - hparams.lr = suggested_lr - model = MyModelClass(hparams) - - # Ready to train with new learning rate - trainer.fit(model) - - """ - save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt') - - self._lr_finder_dump_params(model) - - # Prevent going into infinite loop - self.auto_lr_find = False - - # Initialize lr finder object (stores results) - lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) - - # Use special lr logger callback - self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)] - - # No logging - self.logger = None - - # Max step set to number of iterations - self.max_steps = num_training - - # Disable standard progress bar for fit - if self.progress_bar_callback: - self.progress_bar_callback.disable() - - # Accumulation of gradients - self.accumulate_grad_batches = num_accumulation_steps - - # Disable standard checkpoint & early stopping - self.checkpoint_callback = False - self.early_stop_callback = None - self.enable_early_stop = False - - # Required for saving the model - self.optimizers, self.schedulers = [], [], - self.model = model - - # Dump model checkpoint - self.save_checkpoint(str(save_path)) - - # Configure optimizer and scheduler - optimizers, _, _ = self.init_optimizers(model) - - if len(optimizers) != 1: - raise MisconfigurationException( - f'`model.configure_optimizers()` returned {len(optimizers)}, but' - ' learning rate finder only works with single optimizer') - configure_optimizers = model.configure_optimizers - model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0]) - - # Fit, lr & loss logged in callback - self.fit(model, - train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders) - - # Prompt if we stopped early - if self.global_step != num_training: - log.info('LR finder stopped early due to diverging loss.') - - # Transfer results from callback to lr finder object - lr_finder.results.update({'lr': self.callbacks[0].lrs, - 'loss': self.callbacks[0].losses}) - - # Reset model state - self.restore(str(save_path), on_gpu=self.on_gpu) - os.remove(save_path) - - # Finish by resetting variables so trainer is ready to fit model - self._lr_finder_restore_params(model) - if self.progress_bar_callback: - self.progress_bar_callback.enable() - - return lr_finder - - def _lr_finder_dump_params(self, model): - # Prevent going into infinite loop - self._params = { - 'auto_lr_find': self.auto_lr_find, - 'callbacks': self.callbacks, - 'logger': self.logger, - 'max_steps': self.max_steps, - 'progress_bar_refresh_rate': self.progress_bar_refresh_rate, - 'accumulate_grad_batches': self.accumulate_grad_batches, - 'checkpoint_callback': self.checkpoint_callback, - 'early_stop_callback': self.early_stop_callback, - 'enable_early_stop': self.enable_early_stop, - 'progress_bar_callback': self.progress_bar_callback, - 'configure_optimizers': model.configure_optimizers, - } - - def _lr_finder_restore_params(self, model): - self.auto_lr_find = self._params['auto_lr_find'] - self.logger = self._params['logger'] - self.callbacks = self._params['callbacks'] - self.max_steps = self._params['max_steps'] - self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate'] - self.accumulate_grad_batches = self._params['accumulate_grad_batches'] - self.checkpoint_callback = self._params['checkpoint_callback'] - self.early_stop_callback = self._params['early_stop_callback'] - self.enable_early_stop = self._params['enable_early_stop'] - self.progress_bar_callback = self._params['progress_bar_callback'] - model.configure_optimizers = self._params['configure_optimizers'] - - -class _LRFinder(object): - """ LR finder object. This object stores the results of Trainer.lr_find(). - - Args: - mode: either `linear` or `exponential`, how to increase lr after each step - - lr_min: lr to start search from - - lr_max: lr to stop seach - - num_training: number of steps to take between lr_min and lr_max - - Example:: - # Run lr finder - lr_finder = trainer.lr_find(model) - - # Results stored in - lr_finder.results - - # Plot using - lr_finder.plot() - - # Get suggestion - lr = lr_finder.suggestion() - """ - def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): - assert mode in ('linear', 'exponential'), \ - 'mode should be either `linear` or `exponential`' - - self.mode = mode - self.lr_min = lr_min - self.lr_max = lr_max - self.num_training = num_training - - self.results = {} - - def _get_new_optimizer(self, optimizer: torch.optim.Optimizer): - """ Construct a new `configure_optimizers()` method, that has a optimizer - with initial lr set to lr_min and a scheduler that will either - linearly or exponentially increase the lr to lr_max in num_training steps. - - Args: - optimizer: instance of `torch.optim.Optimizer` - - """ - new_lrs = [self.lr_min] * len(optimizer.param_groups) - for param_group, new_lr in zip(optimizer.param_groups, new_lrs): - param_group["lr"] = new_lr - param_group["initial_lr"] = new_lr - - args = (optimizer, self.lr_max, self.num_training) - scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args) - - def configure_optimizers(): - return [optimizer], [{'scheduler': scheduler, - 'interval': 'step'}] - - return configure_optimizers - - def plot(self, suggest: bool = False, show: bool = False): - """ Plot results from lr_find run - Args: - suggest: if True, will mark suggested lr to use with a red point - - show: if True, will show figure - """ - import matplotlib.pyplot as plt - - lrs = self.results["lr"] - losses = self.results["loss"] - - fig, ax = plt.subplots() - - # Plot loss as a function of the learning rate - ax.plot(lrs, losses) - if self.mode == 'exponential': - ax.set_xscale("log") - ax.set_xlabel("Learning rate") - ax.set_ylabel("Loss") - - if suggest: - _ = self.suggestion() - if self._optimal_idx: - ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], - markersize=10, marker='o', color='red') - - if show: - plt.show() - - return fig - - def suggestion(self): - """ This will propose a suggestion for choice of initial learning rate - as the point with the steepest negative gradient. - - Returns: - lr: suggested initial learning rate to use - - """ - try: - min_grad = (np.gradient(np.array(self.results["loss"]))).argmin() - self._optimal_idx = min_grad - return self.results["lr"][min_grad] - except Exception: - log.warning('Failed to compute suggesting for `lr`.' - ' There might not be enough points.') - self._optimal_idx = None - - -class _LRCallback(Callback): - """ Special callback used by the learning rate finder. This callbacks log - the learning rate before each batch and log the corresponding loss after - each batch. """ - def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98): - self.num_training = num_training - self.beta = beta - self.losses = [] - self.lrs = [] - self.avg_loss = 0.0 - self.best_loss = 0.0 - self.progress_bar_refresh_rate = progress_bar_refresh_rate - self.progress_bar = None - - def on_batch_start(self, trainer, pl_module): - """ Called before each training batch, logs the lr that will be used """ - if self.progress_bar_refresh_rate and self.progress_bar is None: - self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training) - - self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) - - def on_batch_end(self, trainer, pl_module): - """ Called when the training batch ends, logs the calculated loss """ - if self.progress_bar: - self.progress_bar.update() - - current_loss = trainer.running_loss.last().item() - current_step = trainer.global_step + 1 # remove the +1 in 1.0 - - # Avg loss (loss with momentum) + smoothing - self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss - smoothed_loss = self.avg_loss / (1 - self.beta**current_step) - - # Check if we diverging - if current_step > 1 and smoothed_loss > 4 * self.best_loss: - trainer.max_steps = current_step # stop signal - if self.progress_bar: - self.progress_bar.close() - - # Save best loss for diverging checking - if smoothed_loss < self.best_loss or current_step == 1: - self.best_loss = smoothed_loss - - self.losses.append(smoothed_loss) - - -class _LinearLR(_LRScheduler): - """Linearly increases the learning rate between two boundaries - over a number of iterations. - Arguments: - - optimizer: wrapped optimizer. - - end_lr: the final learning rate. - - num_iter: the number of iterations over which the test occurs. - - last_epoch: the index of last epoch. Default: -1. - """ - - def __init__(self, - optimizer: torch.optim.Optimizer, - end_lr: float, - num_iter: int, - last_epoch: int = -1): - self.end_lr = end_lr - self.num_iter = num_iter - super(_LinearLR, self).__init__(optimizer, last_epoch) - - def get_lr(self): - curr_iter = self.last_epoch + 1 - r = curr_iter / self.num_iter - - if self.last_epoch > 0: - val = [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] - else: - val = [base_lr for base_lr in self.base_lrs] - self._lr = val - return val - - @property - def lr(self): - return self._lr - - -class _ExponentialLR(_LRScheduler): - """Exponentially increases the learning rate between two boundaries - over a number of iterations. - - Arguments: - - optimizer: wrapped optimizer. - - end_lr: the final learning rate. - - num_iter: the number of iterations over which the test occurs. - - last_epoch: the index of last epoch. Default: -1. - """ - - def __init__(self, - optimizer: torch.optim.Optimizer, - end_lr: float, - num_iter: int, - last_epoch: int = -1): - self.end_lr = end_lr - self.num_iter = num_iter - super(_ExponentialLR, self).__init__(optimizer, last_epoch) - - def get_lr(self): - curr_iter = self.last_epoch + 1 - r = curr_iter / self.num_iter - - if self.last_epoch > 0: - val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] - else: - val = [base_lr for base_lr in self.base_lrs] - self._lr = val - return val - - @property - def lr(self): - return self._lr - - -def _nested_hasattr(obj, path): - parts = path.split(".") - for part in parts: - if hasattr(obj, part): - obj = getattr(obj, part) - else: - return False - else: - return True - - -def _nested_setattr(obj, path, val): - parts = path.split(".") - for part in parts[:-1]: - if hasattr(obj, part): - obj = getattr(obj, part) - setattr(obj, parts[-1], val) diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index fa9dfed7da02fe..b924675d8505c8 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -1,45 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect -from abc import ABC, abstractmethod +from abc import ABC +from typing import Optional from pytorch_lightning.core.lightning import LightningModule class TrainerModelHooksMixin(ABC): - def is_function_implemented(self, f_name): - model = self.get_model() - f_op = getattr(model, f_name, None) - return callable(f_op) + lightning_module: LightningModule - def is_overriden(self, method_name: str, model: LightningModule = None) -> bool: + def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool: + # note: currently unused - kept as it is public if model is None: - model = self.get_model() - super_object = LightningModule - - if not hasattr(model, method_name): - # in case of calling deprecated method - return False - - instance_attr = getattr(model, method_name) - if not instance_attr: - return False - super_attr = getattr(super_object, method_name) - - # when code pointers are different, it was implemented - if hasattr(instance_attr, 'patch_loader_code'): - # cannot pickle __code__ so cannot verify if PatchDataloader - # exists which shows dataloader methods have been overwritten. - # so, we hack it by using the string representation - is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__) - else: - is_overriden = instance_attr.__code__ is not super_attr.__code__ - return is_overriden + model = self.lightning_module + f_op = getattr(model, f_name, None) + return callable(f_op) - def has_arg(self, f_name, arg_name): - model = self.get_model() + def has_arg(self, f_name: str, arg_name: str) -> bool: + model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters - - @abstractmethod - def get_model(self): - """Warning: this is just empty shell for code implemented in other class.""" diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 8dd77f7971a481..a247fb92cd22f8 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -1,111 +1,149 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch from torch import optim from torch.optim.optimizer import Optimizer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class TrainerOptimizersMixin(ABC): - def init_optimizers( - self, - model: LightningModule - ) -> Tuple[List, List, List]: - optim_conf = model.configure_optimizers() + _lightning_optimizers: Optional[List[LightningOptimizer]] + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + self._lightning_optimizers = None + optim_conf = model.configure_optimizers() if optim_conf is None: - rank_zero_warn('`LightningModule.configure_optimizers` returned `None`, ' - 'this fit will run with no optimizer', UserWarning) + rank_zero_warn( + '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer', + UserWarning, + ) optim_conf = _MockOptimizer() + optimizers, lr_schedulers, optimizer_frequencies = [], [], [] + monitor = None + # single output, single optimizer if isinstance(optim_conf, Optimizer): - return [optim_conf], [], [] - + optimizers = [optim_conf] # two lists, optimizer + lr schedulers - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ - and isinstance(optim_conf[0], list): - optimizers, lr_schedulers = optim_conf - lr_schedulers = self.configure_schedulers(lr_schedulers) - return optimizers, lr_schedulers, [] - + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): + opt, sch = optim_conf + optimizers = opt + lr_schedulers = sch if isinstance(sch, list) else [sch] # single dictionary elif isinstance(optim_conf, dict): - optimizer = optim_conf["optimizer"] - lr_scheduler = optim_conf.get("lr_scheduler", []) - if lr_scheduler: - lr_schedulers = self.configure_schedulers([lr_scheduler]) - else: - lr_schedulers = [] - return [optimizer], lr_schedulers, [] - + optimizers = [optim_conf["optimizer"]] + monitor = optim_conf.get('monitor', None) + lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else [] # multiple dictionaries - elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - # take only lr wif exists and ot they are defined - not None - lr_schedulers = [ - opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler") - ] - # take only freq wif exists and ot they are defined - not None + lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict] optimizer_frequencies = [ - opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") is not None + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None ] - - # clean scheduler list - if lr_schedulers: - lr_schedulers = self.configure_schedulers(lr_schedulers) # assert that if frequencies are present, they are given for all optimizers if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): raise ValueError("A frequency must be given to each optimizer.") - return optimizers, lr_schedulers, optimizer_frequencies - # single list or tuple, multiple optimizer elif isinstance(optim_conf, (list, tuple)): - return list(optim_conf), [], [] - + optimizers = list(optim_conf) # unknown configuration else: - raise ValueError( + raise MisconfigurationException( 'Unknown configuration for model optimizers.' - ' Output from `model.configure_optimizers()` should either be:' - ' * single output, single `torch.optim.Optimizer`' - ' * single output, list of `torch.optim.Optimizer`' - ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)' - ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)' - ' * two outputs, first being a list of `torch.optim.Optimizer` second being' - ' a list of `torch.optim.lr_scheduler`' - ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)') - - def configure_schedulers(self, schedulers: list): - # Convert each scheduler into dict sturcture with relevant information + ' Output from `model.configure_optimizers()` should either be:\n' + ' * `torch.optim.Optimizer`\n' + ' * [`torch.optim.Optimizer`]\n' + ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n' + ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n' + ' * A list of the previously described dict format, with an optional "frequency" key (int)' + ) + + lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor) + _validate_scheduler_optimizer(optimizers, lr_schedulers) + + return optimizers, lr_schedulers, optimizer_frequencies + + def convert_to_lightning_optimizers(self): + + def _convert_to_lightning_optimizer(trainer, optimizer): + if not isinstance(optimizer, LightningOptimizer): + optimizer = LightningOptimizer(optimizer) + optimizer._on_trainer_init(trainer) + return optimizer + + self._lightning_optimizers = { + opt_idx: _convert_to_lightning_optimizer(self, opt) + for opt_idx, opt in enumerate(self.optimizers) + } + + def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None): + # Convert each scheduler into dict structure with relevant information lr_schedulers = [] - default_config = {'interval': 'epoch', # default every epoch - 'frequency': 1, # default every epoch/batch - 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler - 'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau + default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): + # check provided keys + extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] + if extra_keys: + rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning) if 'scheduler' not in scheduler: - raise ValueError(f'Lr scheduler should have key `scheduler`', - ' with item being a lr scheduler') - scheduler['reduce_on_plateau'] = isinstance( - scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) + raise MisconfigurationException( + 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' + ) + if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'): + raise MisconfigurationException( + f'The "interval" key in lr scheduler dict must be "step" or "epoch"' + f' but is "{scheduler["interval"]}"' + ) + scheduler['reduce_on_plateau'] = isinstance( + scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau + ) + if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None: + raise MisconfigurationException( + 'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.' + ' For example: {"optimizer": optimizer, "lr_scheduler":' + ' {"scheduler": scheduler, "monitor": "your_loss"}}' + ) lr_schedulers.append({**default_config, **scheduler}) - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): - lr_schedulers.append({**default_config, 'scheduler': scheduler, - 'reduce_on_plateau': True}) - + if monitor is None: + raise MisconfigurationException( + '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.' + ' For example:' + ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' + ) + lr_schedulers.append({ + **default_config, 'scheduler': scheduler, + 'reduce_on_plateau': True, + 'monitor': monitor + }) elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): lr_schedulers.append({**default_config, 'scheduler': scheduler}) else: - raise ValueError(f'Input {scheduler} to lr schedulers ' - 'is a invalid input.') + raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') return lr_schedulers @@ -135,3 +173,22 @@ def zero_grad(self): def __repr__(self): return 'No Optimizer' + + +def _validate_scheduler_optimizer(optimizers, lr_schedulers): + if any(sch['scheduler'].optimizer not in optimizers for sch in lr_schedulers): + raise MisconfigurationException( + "Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`." + ) + + +def _get_default_scheduler_config() -> Dict[str, Any]: + return { + 'scheduler': None, + 'name': None, # no custom name + 'interval': 'epoch', # after epoch is over + 'frequency': 1, # every epoch/batch + 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler + 'monitor': None, # value to monitor for ReduceLROnPlateau + 'strict': True, # enforce that the monitor exists for ReduceLROnPlateau + } diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py new file mode 100644 index 00000000000000..b33f41cb2ea487 --- /dev/null +++ b/pytorch_lightning/trainer/predict_loop.py @@ -0,0 +1,113 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.warnings import WarningCache + + +class PredictLoop(object): + + def __init__(self, trainer): + self.trainer = trainer + self.max_batches = None + self.num_dataloaders = None + self.warning_cache = WarningCache() + + def on_trainer_init(self): + self.trainer.num_predict_batches = [] + + def get_predict_dataloaders(self): + self.trainer.reset_predict_dataloader(self.trainer.lightning_module) + + dataloaders = self.trainer.predict_dataloaders + max_batches = self.trainer.num_predict_batches + + return dataloaders, max_batches + + def should_skip_predict(self, max_batches): + return sum(max_batches) == 0 + + def on_predict_model_eval(self, *_, **__): + model_ref = self.trainer.lightning_module + model_ref.on_predict_model_eval() + + def setup(self, model, max_batches, dataloaders): + self.trainer.call_hook("on_predict_start") + + # copy properties for forward overrides + self.trainer.model_connector.copy_trainer_model_properties(model) + + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) + + self.max_batches = max_batches + self.num_dataloaders = self._get_num_dataloaders(dataloaders) + self._predictions = [[] for _ in range(self.num_dataloaders)] + + self.trainer._progress_bar_callback.on_predict_start(self.trainer, self.trainer.lightning_module) + + def _get_num_dataloaders(self, dataloaders): + # case where user does: + # return dl1, dl2 + length = len(dataloaders) + if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): + length = len(dataloaders[0]) + return length + + def predict_step(self, batch, batch_idx, dataloader_idx): + # configure args + args = [batch, batch_idx] + if self.num_dataloaders: + args.append(dataloader_idx) + + model_ref = self.trainer.lightning_module + + model_ref._current_fx_name = "predict" + predictions = self.trainer.accelerator.predict_step(args) + + if predictions is None: + self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") + + self._predictions[dataloader_idx].append(predictions) + self.trainer._progress_bar_callback.on_predict_batch_end( + self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx + ) + return + + def on_predict_epoch_end(self): + self.trainer.profiler.describe() + + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) + + results = self._predictions + + def _convert_to_numpy(v): + return v.cpu().numpy() + + results = apply_to_collection(results, torch.Tensor, _convert_to_numpy) + + if len(results) == 1: + return results[0] + + return results + + def on_predict_start(self): + # hook + self.trainer.call_hook("on_predict_start") + + def on_predict_end(self): + # hook + self.trainer.call_hook("on_predict_end") diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py new file mode 100644 index 00000000000000..315e3c60c05579 --- /dev/null +++ b/pytorch_lightning/trainer/properties.py @@ -0,0 +1,506 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from abc import ABC +from argparse import ArgumentParser, Namespace +from typing import cast, List, Optional, Type, TypeVar, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn +from pytorch_lightning.utilities.argparse import ( + add_argparse_args, + from_argparse_args, + parse_argparser, + parse_env_variables, +) +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.model_helpers import is_overridden + + +class TrainerProperties(ABC): + + _default_root_dir: str + _lightning_optimizers = None + _progress_bar_callback: ProgressBarBase + _running_stage: Optional[RunningStage] = None + _state: TrainerState + _weights_save_path: str + + accelerator_connector: AcceleratorConnector + callbacks: List[Callback] + checkpoint_connector: CheckpointConnector + limit_val_batches: int + logger: LightningLoggerBase + logger_connector: LoggerConnector + + @property + def accelerator(self) -> Accelerator: + return self.accelerator_connector.accelerator + + @property + def distributed_backend(self) -> Optional[str]: + # for backward compatibility + return self.accelerator_connector.distributed_backend + + @property + def training_type_plugin(self) -> TrainingTypePlugin: + return self.accelerator.training_type_plugin + + @property + def precision_plugin(self) -> PrecisionPlugin: + return self.accelerator.precision_plugin + + @property + def global_rank(self) -> int: + return self.accelerator.training_type_plugin.global_rank + + @property + def local_rank(self) -> int: + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "local_rank", 0) + + @property + def node_rank(self) -> int: + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "node_rank", 0) + + @property + def world_size(self) -> int: + # some training types define a world size + return getattr(self.accelerator.training_type_plugin, "world_size", 1) + + @property + def _distrib_type(self) -> DistributedType: + return self.accelerator_connector._distrib_type + + @property + def _device_type(self) -> DeviceType: + return self.accelerator_connector._device_type + + @property + def num_nodes(self) -> int: + return self.accelerator_connector.num_nodes + + @property + def num_processes(self) -> int: + return self.accelerator_connector.num_processes + + @property + def root_gpu(self) -> Optional[int]: + return self.accelerator_connector.root_gpu + + @property + def tpu_cores(self) -> int: + return self.accelerator_connector.tpu_cores + + @property + def num_gpus(self) -> int: + return self.accelerator_connector.num_gpus + + @property + def data_parallel_device_ids(self) -> Optional[List[int]]: + return self.accelerator_connector.parallel_device_ids + + @property + def log_dir(self) -> Optional[str]: + if self.logger is None: + dirpath = self.default_root_dir + else: + dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir') + + dirpath = self.accelerator.broadcast(dirpath) + return dirpath + + @property + def use_amp(self) -> bool: + return self.precision == 16 + + @property + def callback_metrics(self) -> dict: + return self.logger_connector.callback_metrics + + @callback_metrics.setter + def callback_metrics(self, x: dict) -> None: + self.logger_connector.callback_metrics = x + + @property + def logged_metrics(self) -> dict: + return self.logger_connector.logged_metrics + + @logged_metrics.setter + def logged_metrics(self, x: dict) -> None: + self.logger_connector.logged_metrics = x + + @property + def progress_bar_metrics(self) -> dict: + return self.logger_connector.progress_bar_metrics + + @progress_bar_metrics.setter + def progress_bar_metrics(self, x: dict) -> None: + self.logger_connector.progress_bar_metrics = x + + @property + def state(self) -> TrainerState: + return self._state + + @state.setter + def state(self, state: TrainerState) -> None: + self._state = state + + @property + def interrupted(self) -> bool: + return self._state == TrainerState.INTERRUPTED + + @property + def is_global_zero(self) -> bool: + return self.global_rank == 0 + + @property + def slurm_job_id(self) -> Optional[int]: + job_id = os.environ.get('SLURM_JOB_ID') + if job_id: + try: + job_id = int(job_id) + except ValueError: + job_id = None + + # in interactive mode, don't make logs use the same job id + in_slurm_interactive_mode = os.environ.get('SLURM_JOB_NAME') == 'bash' + if in_slurm_interactive_mode: + job_id = None + return job_id + + @classmethod + def default_attributes(cls) -> dict: + init_signature = inspect.signature(cls) + + args = {} + for param_name in init_signature.parameters: + value = init_signature.parameters[param_name].default + args[param_name] = value + + return args + + @classmethod + def get_deprecated_arg_names(cls) -> List: + """Returns a list with deprecated Trainer arguments.""" + depr_arg_names = [] + for name, val in cls.__dict__.items(): + if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): + depr_arg_names.extend(val) + return depr_arg_names + + @classmethod + def from_argparse_args(cls: Type['_T'], args: Union[Namespace, ArgumentParser], **kwargs) -> '_T': + return from_argparse_args(cls, args, **kwargs) + + @classmethod + def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + return parse_argparser(cls, arg_parser) + + @classmethod + def match_env_arguments(cls) -> Namespace: + return parse_env_variables(cls) + + @classmethod + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + return add_argparse_args(cls, parent_parser, **kwargs) + + @property + def gpus(self) -> Optional[Union[List[int], str, int]]: + return self.accelerator_connector.gpus + + @property + def data_parallel(self) -> bool: + return self._distrib_type in ( + DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2 + ) + + @property + def progress_bar_callback(self) -> Optional[ProgressBarBase]: + return self._progress_bar_callback + + @property + def progress_bar_dict(self) -> dict: + """ Read-only for progress bar metrics. """ + ref_model = self.lightning_module + ref_model = cast(LightningModule, ref_model) + + standard_metrics = ref_model.get_progress_bar_dict() + logged_metrics = self.progress_bar_metrics + duplicates = list(standard_metrics.keys() & logged_metrics.keys()) + if duplicates: + rank_zero_warn( + f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and" + f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. " + f" If this is undesired, change the name or override `get_progress_bar_dict()`" + f" in `LightingModule`.", UserWarning + ) + all_metrics = dict(**standard_metrics) + all_metrics.update(**logged_metrics) + return all_metrics + + @property + def disable_validation(self) -> bool: + """ Check if validation is disabled during training. """ + return not self.enable_validation + + @property + def enable_validation(self) -> bool: + """ Check if we should run validation during training. """ + model_ref = self.lightning_module + val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0 + return val_loop_enabled + + @property + def default_root_dir(self) -> str: + """ + The default location to save artifacts of loggers, checkpoints etc. + It is used as a fallback if logger or checkpoint callback do not define specific save paths. + """ + if get_filesystem(self._default_root_dir).protocol == "file": + return os.path.normpath(self._default_root_dir) + return self._default_root_dir + + @property + def weights_save_path(self) -> str: + """ + The default root location to save weights (checkpoints), e.g., when the + :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path. + """ + if get_filesystem(self._weights_save_path).protocol == "file": + return os.path.normpath(self._weights_save_path) + return self._weights_save_path + + @property + def early_stopping_callback(self) -> Optional[EarlyStopping]: + """ + The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` + callback in the Trainer.callbacks list, or ``None`` if it doesn't exist. + """ + callbacks = self.early_stopping_callbacks + return callbacks[0] if len(callbacks) > 0 else None + + @property + def early_stopping_callbacks(self) -> List[EarlyStopping]: + """ + A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` + found in the Trainer.callbacks list. + """ + return [c for c in self.callbacks if isinstance(c, EarlyStopping)] + + @property + def checkpoint_callback(self) -> Optional[ModelCheckpoint]: + """ + The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + callback in the Trainer.callbacks list, or ``None`` if it doesn't exist. + """ + callbacks = self.checkpoint_callbacks + return callbacks[0] if len(callbacks) > 0 else None + + @property + def checkpoint_callbacks(self) -> List[ModelCheckpoint]: + """ + A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + found in the Trainer.callbacks list. + """ + return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + + def save_checkpoint(self, filepath, weights_only: bool = False) -> None: + self.checkpoint_connector.save_checkpoint(filepath, weights_only) + + @property + def model(self) -> torch.nn.Module: + """ + The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + To access the pure LightningModule, use + :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. + """ + return self.accelerator.model + + @model.setter + def model(self, model: torch.nn.Module) -> None: + """ + Setter for the model, pass-through to accelerator and plugin where the model reference is stored. + Used by the Tuner to reset the state of Trainer and Accelerator. + + Args: + model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending + on the backend. + """ + self.accelerator.model = model + + @property + def lightning_optimizers(self) -> List[LightningOptimizer]: + if self._lightning_optimizers is None: + self.convert_to_lightning_optimizers() + return self._lightning_optimizers + + @property + def lightning_module(self) -> LightningModule: + return self.accelerator.lightning_module + + @property + def optimizers(self) -> Optional[List[Optimizer]]: + return self.accelerator.optimizers + + @optimizers.setter + def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: + # Necessary to rewrap optimizers to lightning + # They will be re-created when accessing + # the `lightning_optimizers` trainer property + self._lightning_optimizers = None + + self.accelerator.optimizers = new_optims + + @property + def lr_schedulers(self) -> Optional[list]: + return self.accelerator.lr_schedulers + + @lr_schedulers.setter + def lr_schedulers(self, new_schedulers: Optional[list]) -> None: + self.accelerator.lr_schedulers = new_schedulers + + @property + def optimizer_frequencies(self) -> list: + return self.accelerator.optimizer_frequencies + + @optimizer_frequencies.setter + def optimizer_frequencies(self, new_freqs: list) -> None: + self.accelerator.optimizer_frequencies = new_freqs + + @property + def amp_backend(self) -> Optional[str]: + return self.accelerator.amp_backend + + @property + def precision(self) -> Union[str, int]: + return self.accelerator.precision + + @property + def scaler(self): + return self.accelerator.scaler + + # TODO: refactor this so that it can be done in LightningOptimizer + def __getstate__(self): + # remove lightning_optimizers + self._lightning_optimizers = None + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + + @property + def distributed_sampler_kwargs(self) -> Optional[dict]: + if isinstance(self.training_type_plugin, ParallelPlugin): + return self.training_type_plugin.distributed_sampler_kwargs + + @property + def training(self) -> bool: + return self._running_stage == RunningStage.TRAINING + + @training.setter + def training(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TRAINING + elif self.training: + self._running_stage = None + + @property + def testing(self) -> bool: + return self._running_stage == RunningStage.TESTING + + @testing.setter + def testing(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TESTING + elif self.testing: + self._running_stage = None + + @property + def predicting(self) -> bool: + return self._running_stage == RunningStage.PREDICTING + + @predicting.setter + def predicting(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.PREDICTING + elif self.predicting: + self._running_stage = None + + @property + def tuning(self) -> bool: + return self._running_stage == RunningStage.TUNING + + @tuning.setter + def tuning(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.TUNING + elif self.tuning: + self._running_stage = None + + @property + def validating(self) -> bool: + return self._running_stage == RunningStage.VALIDATING + + @validating.setter + def validating(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.VALIDATING + elif self.validating: + self._running_stage = None + + @property + def evaluating(self) -> bool: + return self._running_stage and self._running_stage.evaluating + + @property + def sanity_checking(self) -> bool: + return self._running_stage == RunningStage.SANITY_CHECKING + + @sanity_checking.setter + def sanity_checking(self, val: bool) -> None: + if val: + self._running_stage = RunningStage.SANITY_CHECKING + elif self.sanity_checking: + self._running_stage = None + + @property + def _setup_state(self) -> TrainerState: + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" + return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + + @property + def _teardown_state(self) -> Optional[TrainerState]: + if self.state.running: + return self._setup_state + + +# Used to represent the concrete type TrainerProperties class methods are called on. +_T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py new file mode 100644 index 00000000000000..b1f188ab047fe2 --- /dev/null +++ b/pytorch_lightning/trainer/states.py @@ -0,0 +1,68 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities import LightningEnum + + +class TrainerState(LightningEnum): + """ State for the :class:`~pytorch_lightning.trainer.trainer.Trainer` + to indicate what is currently or was executed. It follows the user-called + functions such as `trainer.fit()` and `trainer.test(). + + >>> # you can compare the type with a string + >>> TrainerState.FITTING == 'fit' + True + >>> # which is case insensitive + >>> TrainerState.FINISHED == 'FINISHED' + True + """ + INITIALIZING = 'initializing' # trainer creation + FITTING = 'fit' # trainer.fit() + VALIDATING = 'validate' # trainer.validate() + TESTING = 'test' # trainer.test() + PREDICTING = 'predict' # trainer.predict() + TUNING = 'tune' # trainer.tune() + FINISHED = 'finished' + INTERRUPTED = 'interrupted' + + @property + def stopped(self) -> bool: + return self in (self.FINISHED, self.INTERRUPTED) + + @property + def running(self) -> bool: + return self in (self.FITTING, self.VALIDATING, self.TESTING, self.PREDICTING, self.TUNING) + + +class RunningStage(LightningEnum): + """Current running stage. + + This stage complements :class:`TrainerState` for example to indicate that + `RunningStage.VALIDATING` will be set both during `TrainerState.FITTING` + and `TrainerState.VALIDATING`. It follows the internal code logic. + + >>> # you can match the Enum with string + >>> RunningStage.TRAINING == 'train' + True + """ + TRAINING = 'train' + SANITY_CHECKING = 'sanity_check' + VALIDATING = 'validate' + TESTING = 'test' + PREDICTING = 'predict' + TUNING = 'tune' + + @property + def evaluating(self) -> bool: + return self in (self.VALIDATING, self.TESTING) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index dc29e8d36a08dc..f884306dc09c87 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -1,4 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import Any, Callable, Optional, Union + import torch +from torch import Tensor +from torch.utils.data import Dataset + +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.data import get_len +from pytorch_lightning.utilities.exceptions import MisconfigurationException class TensorRunningAccum(object): @@ -23,14 +48,14 @@ class TensorRunningAccum(object): def __init__(self, window_length: int): self.window_length = window_length - self.memory = torch.Tensor(self.window_length) + self.memory = None self.current_idx: int = 0 - self.last_idx: int = None + self.last_idx: Optional[int] = None self.rotated: bool = False def reset(self) -> None: """Empty the accumulator.""" - self = TensorRunningAccum(self.window_length) + self.__init__(self.window_length) def last(self): """Get the last added element.""" @@ -39,6 +64,9 @@ def last(self): def append(self, x): """Add an element to the accumulator.""" + if self.memory is None: + self.memory = torch.zeros(self.window_length, *x.shape) + # ensure same device and type if self.memory.device != x.device or self.memory.type() != x.type(): x = x.to(self.memory) @@ -74,3 +102,407 @@ def _agg_memory(self, how: str): return getattr(self.memory, how)() else: return getattr(self.memory[:self.current_idx], how)() + + +class PredictionCollection(object): + + def __init__(self, global_rank: int, world_size: int): + self.global_rank = global_rank + self.world_size = world_size + self.predictions = {} + self.num_predictions = 0 + + def _add_prediction(self, name, values, filename): + if filename not in self.predictions: + self.predictions[filename] = {name: values} + elif name not in self.predictions[filename]: + self.predictions[filename][name] = values + elif isinstance(values, Tensor): + self.predictions[filename][name] = torch.cat((self.predictions[filename][name], values)) + elif isinstance(values, list): + self.predictions[filename][name].extend(values) + + def add(self, predictions): + + if predictions is None: + return + + for filename, pred_dict in predictions.items(): + for feature_name, values in pred_dict.items(): + self._add_prediction(feature_name, values, filename) + + def to_disk(self) -> None: + """Write predictions to file(s). + """ + for filepath, predictions in self.predictions.items(): + fs = get_filesystem(filepath) + # normalize local filepaths only + if fs.protocol == "file": + filepath = os.path.realpath(filepath) + if self.world_size > 1: + stem, extension = os.path.splitext(filepath) + filepath = f"{stem}_rank_{self.global_rank}{extension}" + dirpath = os.path.split(filepath)[0] + fs.mkdirs(dirpath, exist_ok=True) + + # Convert any tensor values to list + predictions = {k: v if not isinstance(v, Tensor) else v.tolist() for k, v in predictions.items()} + + # Check if all features for this file add up to same length + feature_lens = {k: len(v) for k, v in predictions.items()} + if len(set(feature_lens.values())) != 1: + raise ValueError("Mismatching feature column lengths found in stored EvalResult predictions.") + + # Switch predictions so each entry has its own dict + outputs = [] + for values in zip(*predictions.values()): + output_element = {k: v for k, v in zip(predictions.keys(), values)} + outputs.append(output_element) + + # Write predictions for current file to disk + with fs.open(filepath, "wb") as fp: + torch.save(outputs, fp) + + +class CycleIterator(object): + """ + Iterator for restarting a dataloader if it runs out of samples + """ + + def __init__(self, loader: Any, length: Optional[int] = None): + """ + + Args: + loader: the loader to restart for cyclic (and optionally infinite) sampling + length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration + if None: infinite + + """ + if length is None: + length = float('inf') + + self.length = length + self.loader = loader + self._loader_iter = None + self.counter = 0 + + def __iter__(self) -> Any: + """ + Creates the internal iterator and returns self + + Returns: + CycleIterator: self + + """ + self.counter = 0 + self._loader_iter = iter(self.loader) + return self + + def __next__(self) -> Any: + """ + Fetches the next batch from internal dataloader and restarts + it if necessary + + Returns: + Any: the resulting batch + + Raises: + StopIteration: if more then :attr:`length` batches have been returned + + """ + # Note: if self.length is `inf`, then the iterator will never stop + if self.counter >= self.__len__(): + raise StopIteration + + try: + return next(self._loader_iter) + + except StopIteration: + self._loader_iter = iter(self.loader) + return next(self._loader_iter) + + finally: + self.counter += 1 + + def __len__(self) -> Union[int, float]: + return self.length + + +class CombinedDataset(object): + """ + Combine multiple datasets and compute their statistics + """ + COMPUTE_FUNCS = {'min_size': min, 'max_size_cycle': max} + + def __init__(self, datasets: Union[Sequence, Mapping], mode: str = 'min_size'): + """ + + Args: + datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset, + Iterable or even None. + mode: whether to use the minimum number of batches in all samples or the maximum + number of batches in all samples. + + """ + self.datasets = datasets + if mode not in self.COMPUTE_FUNCS.keys(): + raise MisconfigurationException( + f'You have selected unsupported mode "{mode}",' + f' please select one the: {list(self.COMPUTE_FUNCS.keys())}.' + ) + self.mode = mode + + @property + def max_len(self) -> Union[int, float]: + return self._calc_num_data(self.datasets, 'max_size_cycle') + + @property + def min_len(self) -> Union[int, float]: + return self._calc_num_data(self.datasets, 'min_size') + + @staticmethod + def _calc_num_data(datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: + """ + Compute the length of `CombinedDataset` according to the `mode`. + + Args: + datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset, + Iterable or even None. + mode: Determine `CombinedDataset`'s length is the maximum or minimum of + the datasets. + + Returns: + length: the length of `CombinedDataset` + + """ + if mode not in CombinedDataset.COMPUTE_FUNCS.keys(): + raise MisconfigurationException(f"Invalid Mode: {mode}") + + # extract the lengths + all_lengths = apply_to_collection( + datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping) + ) + + compute_func = CombinedDataset.COMPUTE_FUNCS[mode] + + if isinstance(all_lengths, (int, float)): + length = all_lengths + else: + length = _nested_calc_num_data(all_lengths, compute_func) + + return length + + def __len__(self) -> int: + """Return the minimum length of the datasets.""" + return self._calc_num_data(self.datasets, self.mode) + + +class CombinedLoader(object): + """ + Combines different dataloaders and allows sampling in parallel. + + Supported modes are 'min_size', which raises StopIteration after the shortest loader + (the one with the lowest number of batches) is done, and 'max_size_cycle` which raises + StopIteration after the longest loader (the one with most batches) is done, while cycling + through the shorter loaders. + + Examples: + >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), + ... 'b': torch.utils.data.DataLoader(range(15), batch_size=5)} + >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle') + >>> for item in combined_loader: + ... print(item) + {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} + {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} + {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} + >>> combined_loader = CombinedLoader(loaders, 'min_size') + >>> for item in combined_loader: + ... print(item) + {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} + {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} + + """ + SUPPORTED_MODES = ('min_size', 'max_size_cycle') + + def __init__(self, loaders: Any, mode: str = 'min_size'): + """ + + Args: + loaders: the loaders to sample from. Can be all kind of collection + mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and + 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones. + + """ + self.loaders = loaders + + datasets = apply_to_collection( + self.loaders, Iterable, getattr, 'dataset', None, wrong_dtype=(Sequence, Mapping) + ) + # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader + self.dataset = CombinedDataset(datasets, mode) + + if mode not in self.SUPPORTED_MODES: + raise MisconfigurationException(f"Invalid Mode: {mode}") + + self.mode = mode + + if self.mode == 'max_size_cycle': + self._wrap_loaders_max_size_cycle() + + @property + def sampler(self) -> Union[Iterable, Sequence, Mapping]: + """Return a collections of samplers extracting from loaders.""" + return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, wrong_dtype=(Sequence, Mapping)) + + def _wrap_loaders_max_size_cycle(self) -> Any: + """ + Wraps all loaders to make sure they are cycled until the longest loader is exhausted + + Returns: + the wrapped loaders + + """ + all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) + + if isinstance(all_lengths, (int, float)): + length = all_lengths + + elif isinstance(all_lengths, Mapping): + length = max(all_lengths.values()) + + elif isinstance(all_lengths, Sequence): + length = max(all_lengths) + + if isinstance(self.loaders, Mapping): + self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()}) + + elif isinstance(self.loaders, Sequence): + self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders]) + + # dataloaders are iterable but not sequence + elif isinstance(self.loaders, Iterable): + # only one dataloader, just keep it the same. + pass + else: + raise ValueError(f'Invalid Datatype for loaders: {type(self.loaders).__name__}') + + def __iter__(self) -> Any: + """ + Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. + """ + return CombinedLoaderIterator(self.loaders) + + @staticmethod + def _calc_num_batches(loaders: Any) -> Union[int, float]: + """ + Compute the length (aka the number of batches) of `CombinedLoader`. + + Args: + loaders: a collections of loaders. + + Returns: + length: the minimum length of loaders + + """ + all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) + + if isinstance(all_lengths, (int, float)): + return all_lengths + + else: + return _nested_calc_num_data(all_lengths, min) + + def __len__(self) -> int: + return self._calc_num_batches(self.loaders) + + +class CombinedLoaderIterator(object): + """ + Custom Iterator returning data from multple loaders, and allows sampling in parallel + """ + + def __init__(self, loaders: Any): + """ + + Args: + loaders: the loaders to sample from. Can be all kind of collection + + """ + self.loaders = loaders + self._loader_iters = None + + @property + def loader_iters(self) -> Any: + """ + Get the `_loader_iters` and create one if it is None. + """ + if self._loader_iters is None: + self._loader_iters = self.create_loader_iters(self.loaders) + + return self._loader_iters + + def __iter__(self) -> Any: + return self + + def __next__(self) -> Any: + """ + Fetches the next batch from multiple data loaders + + Returns: + a collections of batch data + + """ + return self.request_next_batch(self.loader_iters) + + @staticmethod + def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: + """ + Return the batch of data from multiple iterators. + + Args: + loader_iters: a collections of iterators + + Returns + Any: a collections of batch data + + """ + return apply_to_collection(loader_iters, Iterator, next) + + @staticmethod + def create_loader_iters( + loaders: Union[Any, Iterator, Sequence, Mapping] + ) -> Union[Any, Iterator, Sequence, Mapping]: + """ + Create and return a collection of iterators from loaders. + + Args: + loaders: a collections of loaders + + Returns + a collections of iterators + + """ + # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences + return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping)) + + +def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable): + + if isinstance(data, int): + return data + + if isinstance(data, Mapping): + data = list(data.values()) + + if not isinstance(data, Sequence): + raise TypeError(f'Expected data to be int, Sequence or Mapping, but got {type(data).__name__}') + + new_data = [] + + for x in data: + if isinstance(x, (Mapping, Sequence)): + new_data.append(_nested_calc_num_data(x, compute_func)) + else: + new_data.append(x) + + return compute_func(new_data) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0353fae2bff7f0..78d9602f8e5293 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,1080 +1,1151 @@ -import inspect -import os -from argparse import ArgumentParser -from typing import Union, Optional, List, Dict, Tuple, Iterable, Any +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trainer to automate the training.""" +import logging +import warnings +from itertools import count +from pathlib import Path +from traceback import print_exc +from typing import Any, Dict, Iterable, List, Optional, Union import torch -import torch.distributed as torch_distrib -import torch.multiprocessing as mp from torch.utils.data import DataLoader -from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, ProgressBarBase +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler -from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin -from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin +from pytorch_lightning.plugins import Plugin +from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin +from pytorch_lightning.trainer.configuration_validator import ConfigValidator +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector +from pytorch_lightning.trainer.connectors.data_connector import DataConnector +from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector +from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.model_connector import ModelConnector +from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector +from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector +from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector +from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9 -from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin -from pytorch_lightning.trainer.distrib_parts import ( - TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus) -from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin +from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedTrainerAttributes +from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin -from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.trainer.training_io import TrainerIOMixin -from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin +from pytorch_lightning.trainer.predict_loop import PredictLoop +from pytorch_lightning.trainer.properties import TrainerProperties +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin -from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin +from pytorch_lightning.tuner.tuning import Tuner +from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities import parsing - - -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True - -try: - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.xla_multiprocessing as xmp -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True +from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.model_helpers import is_overridden + +log = logging.getLogger(__name__) +# warnings to ignore in trainer +warnings.filterwarnings( + 'ignore', message='torch.distributed.reduce_op is deprecated, ' + 'please use torch.distributed.ReduceOp instead' +) class Trainer( - TrainerIOMixin, + TrainerProperties, + TrainerCallbackHookMixin, + TrainerModelHooksMixin, TrainerOptimizersMixin, - TrainerAMPMixin, - TrainerDPMixin, - TrainerDDPMixin, TrainerLoggingMixin, - TrainerModelHooksMixin, TrainerTrainingTricksMixin, TrainerDataLoadingMixin, - TrainerEvaluationLoopMixin, - TrainerTrainLoopMixin, - TrainerCallbackConfigMixin, - TrainerCallbackHookMixin, - TrainerLRFinderMixin, - TrainerDeprecatedAPITillVer0_8, - TrainerDeprecatedAPITillVer0_9, + DeprecatedDistDeviceAttributes, + DeprecatedTrainerAttributes, ): - DEPRECATED_IN_0_8 = ( - 'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', - 'add_row_log_interval', 'nb_sanity_val_steps', 'tng_tqdm_dic', - ) - DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict') + @_defaults_from_env_vars def __init__( - self, - logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, - checkpoint_callback: Union[ModelCheckpoint, bool] = True, - early_stop_callback: Optional[Union[EarlyStopping, bool]] = False, - callbacks: Optional[List[Callback]] = None, - default_root_dir: Optional[str] = None, - gradient_clip_val: float = 0, - process_position: int = 0, - num_nodes: int = 1, - num_processes: int = 1, - gpus: Optional[Union[List[int], str, int]] = None, - auto_select_gpus: bool = False, - num_tpu_cores: Optional[int] = None, - log_gpu_memory: Optional[str] = None, - progress_bar_refresh_rate: int = 1, - overfit_pct: float = 0.0, - track_grad_norm: int = -1, - check_val_every_n_epoch: int = 1, - fast_dev_run: bool = False, - accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, - max_epochs: int = 1000, - min_epochs: int = 1, - max_steps: Optional[int] = None, - min_steps: Optional[int] = None, - train_percent_check: float = 1.0, - val_percent_check: float = 1.0, - test_percent_check: float = 1.0, - val_check_interval: float = 1.0, - log_save_interval: int = 100, - row_log_interval: int = 10, - add_row_log_interval=None, # backward compatible, todo: remove in v0.8.0 - distributed_backend: Optional[str] = None, - precision: int = 32, - print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0 - weights_summary: Optional[str] = 'full', - weights_save_path: Optional[str] = None, - num_sanity_val_steps: int = 5, - truncated_bptt_steps: Optional[int] = None, - resume_from_checkpoint: Optional[str] = None, - profiler: Optional[BaseProfiler] = None, - benchmark: bool = False, - reload_dataloaders_every_epoch: bool = False, - auto_lr_find: Union[bool, str] = False, - replace_sampler_ddp: bool = True, - progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True, - amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0 - default_save_path=None, # backward compatible, todo: remove in v0.8.0 - gradient_clip=None, # backward compatible, todo: remove in v0.8.0 - nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0 - max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0 - min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0 - use_amp=None, # backward compatible, todo: remove in v0.9.0 - show_progress_bar=None, # backward compatible, todo: remove in v0.9.0 - nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0 - terminate_on_nan: bool = False, - **kwargs + self, + logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, + checkpoint_callback: bool = True, + callbacks: Optional[Union[List[Callback], Callback]] = None, + default_root_dir: Optional[str] = None, + gradient_clip_val: float = 0, + process_position: int = 0, + num_nodes: int = 1, + num_processes: int = 1, + gpus: Optional[Union[List[int], str, int]] = None, + auto_select_gpus: bool = False, + tpu_cores: Optional[Union[List[int], str, int]] = None, + log_gpu_memory: Optional[str] = None, + progress_bar_refresh_rate: Optional[int] = None, + overfit_batches: Union[int, float] = 0.0, + track_grad_norm: Union[int, float, str] = -1, + check_val_every_n_epoch: int = 1, + fast_dev_run: Union[int, bool] = False, + accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, + max_epochs: Optional[int] = None, + min_epochs: Optional[int] = None, + max_steps: Optional[int] = None, + min_steps: Optional[int] = None, + limit_train_batches: Union[int, float] = 1.0, + limit_val_batches: Union[int, float] = 1.0, + limit_test_batches: Union[int, float] = 1.0, + limit_predict_batches: Union[int, float] = 1.0, + val_check_interval: Union[int, float] = 1.0, + flush_logs_every_n_steps: int = 100, + log_every_n_steps: int = 50, + accelerator: Optional[Union[str, Accelerator]] = None, + sync_batchnorm: bool = False, + precision: int = 32, + weights_summary: Optional[str] = 'top', + weights_save_path: Optional[str] = None, + num_sanity_val_steps: int = 2, + truncated_bptt_steps: Optional[int] = None, + resume_from_checkpoint: Optional[Union[Path, str]] = None, + profiler: Optional[Union[BaseProfiler, str]] = None, + benchmark: bool = False, + deterministic: bool = False, + reload_dataloaders_every_epoch: bool = False, + auto_lr_find: Union[bool, str] = False, + replace_sampler_ddp: bool = True, + terminate_on_nan: bool = False, + auto_scale_batch_size: Union[str, bool] = False, + prepare_data_per_node: bool = True, + plugins: Optional[Union[Plugin, str, list]] = None, + amp_backend: str = 'native', + amp_level: str = 'O2', + distributed_backend: Optional[str] = None, + move_metrics_to_cpu: bool = False, + multiple_trainloader_mode: str = 'max_size_cycle', + stochastic_weight_avg: bool = False ): r""" - Customize every aspect of training via flags Args: - logger: Logger (or iterable collection of loggers) for experiment tracking. - checkpoint_callback: Callback for checkpointing. + accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...). + Can also take in an accelerator object for custom hardware. - early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): + accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict. - callbacks: Add a list of callbacks. + amp_backend: The mixed precision backend to use ("native" or "apex") - default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed + amp_level: The optimization level to use (O1, O2, etc...). - default_save_path: - .. warning:: .. deprecated:: 0.7.3 + auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder, + trying to optimize initial learning for faster convergence. trainer.tune() method will + set the suggested learning rate in self.lr or self.learning_rate in the LightningModule. + To use a different key set a string instead of True with the key name. - Use `default_root_dir` instead. Will remove 0.9.0. + auto_scale_batch_size: If set to True, will `initially` run a batch size + finder trying to find the largest batch size that fits into memory. + The result will be stored in self.batch_size in the LightningModule. + Additionally, can be set to either `power` that estimates the batch size through + a power search or `binsearch` that estimates the batch size through a binary search. - gradient_clip_val: 0 means don't clip. + auto_select_gpus: If enabled and `gpus` is an integer, pick available + gpus automatically. This is especially useful when + GPUs are configured to be in "exclusive mode", such + that only one process at a time can access them. - gradient_clip: - .. warning:: .. deprecated:: 0.7.0 + benchmark: If true enables cudnn.benchmark. - Use `gradient_clip_val` instead. Will remove 0.9.0. + callbacks: Add a callback or list of callbacks. - process_position: orders the progress bar when running multiple models on same machine. + checkpoint_callback: If ``True``, enable checkpointing. + It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. - num_nodes: number of GPU nodes for distributed training. + check_val_every_n_epoch: Check val every n train epochs. - nb_gpu_nodes: - .. warning:: .. deprecated:: 0.7.0 + default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. + Default: ``os.getcwd()``. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' - Use `num_nodes` instead. Will remove 0.9.0. + deterministic: If true enables cudnn.deterministic. - gpus: Which GPUs to train on. + distributed_backend: deprecated. Please use 'accelerator' - auto_select_gpus: + fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) + of train, val and test to find any bugs (ie: a sort of unit test). - If enabled and `gpus` is an integer, pick available - gpus automatically. This is especially useful when - GPUs are configured to be in "exclusive mode", such - that only one process at a time can access them. + flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). - num_tpu_cores: How many TPU cores to train on (1 or 8). + gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node - log_gpu_memory: None, 'min_max', 'all'. Might slow performance + gradient_clip_val: 0 means don't clip. - show_progress_bar: - .. warning:: .. deprecated:: 0.7.2 + limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches) - Set `progress_bar_refresh_rate` to postive integer to enable. Will remove 0.9.0. + limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches) - progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. - Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. + limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches) - overfit_pct: How much of training-, validation-, and test dataset to check. + limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches) - track_grad_norm: -1 no tracking. Otherwise tracks that norm + logger: Logger (or iterable collection of loggers) for experiment tracking. - check_val_every_n_epoch: Check val every n train epochs. + log_gpu_memory: None, 'min_max', 'all'. Might slow performance - fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). + log_every_n_steps: How often to log within steps (defaults to every 50 steps). - accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict. + prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. + Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data + + process_position: orders the progress bar when running multiple models on same machine. - max_epochs: Stop training once this number of epochs is reached. + progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. + Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means + a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.). + + profiler: To profile individual steps during training and assist in identifying bottlenecks. - max_nb_epochs: - .. warning:: .. deprecated:: 0.7.0 + overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int). - Use `max_epochs` instead. Will remove 0.9.0. + plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. - min_epochs: Force training for at least these many epochs + precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or + TPUs. - min_nb_epochs: - .. warning:: .. deprecated:: 0.7.0 + max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). + If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000. - Use `min_epochs` instead. Will remove 0.9.0. + min_epochs: Force training for at least these many epochs. Disabled by default (None). + If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1. max_steps: Stop training after this number of steps. Disabled by default (None). min_steps: Force training for at least these number of steps. Disabled by default (None). - train_percent_check: How much of training dataset to check. - - val_percent_check: How much of validation dataset to check. - - test_percent_check: How much of test dataset to check. + num_nodes: number of GPU nodes for distributed training. - val_check_interval: How often within one training epoch to check the validation set + num_processes: number of processes for distributed training with distributed_backend="ddp_cpu" - log_save_interval: Writes logs to disk this often + num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. + Set it to `-1` to run all batches in all validation dataloaders. - row_log_interval: How often to add logging rows (does not write to disk) + reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch. - add_row_log_interval: - .. warning:: .. deprecated:: 0.7.0 + replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this + will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for + train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it, + you can set ``replace_sampler_ddp=False`` and add your own distributed sampler. - Use `row_log_interval` instead. Will remove 0.9.0. + resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is + no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, + training will start from the beginning of the next epoch. - distributed_backend: The distributed backend to use. + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. - use_amp: - .. warning:: .. deprecated:: 0.7.0 + terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the + end of each training batch, if any of the parameters or the loss are NaN or +/-inf. - Use `precision` instead. Will remove 0.9.0. + tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1] - precision: Full precision (32), half precision (16). + track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. - print_nan_grads: - .. warning:: .. deprecated:: 0.7.2 + truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer + sequence. - Has no effect. When detected, NaN grads will be printed automatically. - Will remove 0.9.0. + val_check_interval: How often to check the validation set. Use float to check within a training epoch, + use int to check every n steps (batches). weights_summary: Prints a summary of the weights when training begins. weights_save_path: Where to save weights if specified. Will override default_root_dir - for checkpoints only. Use this if for whatever reason you need the checkpoints - stored in a different place than the logs written in `default_root_dir`. + for checkpoints only. Use this if for whatever reason you need the checkpoints + stored in a different place than the logs written in `default_root_dir`. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + Defaults to `default_root_dir`. - amp_level: The optimization level to use (O1, O2, etc...). + move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu. + This can save some gpu memory, but can make training slower. Use with attention. - num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine. + multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders. + In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, + and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets + reload when reaching the minimum length of datasets. - nb_sanity_val_steps: - .. warning:: .. deprecated:: 0.7.0 + stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA) + _` - Use `num_sanity_val_steps` instead. Will remove 0.8.0. + """ + super().__init__() + + distributed_backend = distributed_backend or accelerator + + # init connectors + self.dev_debugger = InternalDebugger(self) + self.config_validator = ConfigValidator(self) + self.data_connector = DataConnector(self) + self.optimizer_connector = OptimizerConnector(self) + + self.accelerator_connector = AcceleratorConnector( + num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, + replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins + ) + self.logger_connector = LoggerConnector(self, log_gpu_memory) + self.model_connector = ModelConnector(self) + self.callback_connector = CallbackConnector(self) + self.debugging_connector = DebuggingConnector(self) + self.training_tricks_connector = TrainingTricksConnector(self) + self.profile_connector = ProfilerConnector(self) + self.checkpoint_connector = CheckpointConnector(self) + self.slurm_connector = SLURMConnector(self) + self.tuner = Tuner(self) + self.train_loop = TrainLoop(self, multiple_trainloader_mode) + self.evaluation_loop = EvaluationLoop(self) + self.predict_loop = PredictLoop(self) - truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of + # training state + if weights_summary is not None and weights_summary not in ModelSummary.MODES: + raise MisconfigurationException( + f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}" + ) + self.weights_summary = weights_summary + self.shown_warnings = set() + + # init callbacks + # Declare attributes to be set in callback_connector on_trainer_init + self.callback_connector.on_trainer_init( + callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir, + weights_save_path, resume_from_checkpoint, stochastic_weight_avg + ) + + # hook + self.on_init_start() - resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here. + # init optimizer + lr scheduler related flags + self.optimizer_connector.on_trainer_init() + + # init data flags + self.data_connector.on_trainer_init( + check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node + ) + + # init training tricks + self.training_tricks_connector.on_trainer_init( + gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan + ) + self.train_loop.on_trainer_init( + max_epochs, + min_epochs, + max_steps, + min_steps, + num_sanity_val_steps, + ) + self.evaluation_loop.on_trainer_init() + + # configure tuner + self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size) - profiler: To profile individual steps during training and assist in + # configure profiler + self.profile_connector.on_trainer_init(profiler) + + # init logger flags + self.logger_connector.on_trainer_init( + logger, + flush_logs_every_n_steps, + log_every_n_steps, + move_metrics_to_cpu, + ) + + # init debugging flags + self.debugging_connector.on_init_start( + limit_train_batches, + limit_val_batches, + limit_test_batches, + limit_predict_batches, + val_check_interval, + overfit_batches, + fast_dev_run, + ) - reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch + # Callback system + self.on_init_end() - auto_lr_find: If set to True, will `initially` run a learning rate finder, - trying to optimize initial learning for faster convergence. Sets learning - rate in self.hparams.lr | self.hparams.learning_rate in the lightning module. - To use a different key, set a string instead of True with the key name. + def fit( + self, + model: LightningModule, + train_dataloader: Any = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Runs the full optimization routine. - replace_sampler_ddp: Explicitly enables or disables sampler replacement. - If not specified this will toggled automatically ddp is used + Args: + datamodule: A instance of :class:`LightningDataModule`. - benchmark: If true enables cudnn.benchmark. + model: Model to fit. - terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the - end of each training batch, if any of the parameters or the loss are NaN or +/-inf. - """ + train_dataloader: Either a single PyTorch DataLoader or a collection of these + (list, dict, nested lists and dicts). In the case of multiple dataloaders, please + see this :ref:`page ` - # Init callbacks - self.callbacks = callbacks or [] - self.on_init_start() + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped - # benchmarking - self.benchmark = benchmark - torch.backends.cudnn.benchmark = self.benchmark - - # Transfer params - self.num_nodes = num_nodes - # Backward compatibility, TODO: remove in v0.8.0 - if nb_gpu_nodes is not None: - rank_zero_warn("Argument `nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.num_gpu_nodes = nb_gpu_nodes - self.log_gpu_memory = log_gpu_memory - - self.gradient_clip_val = gradient_clip_val - # Backward compatibility, TODO: remove in v0.8.0 - if gradient_clip is not None: - rank_zero_warn("Argument `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.gradient_clip = gradient_clip - - self.check_val_every_n_epoch = check_val_every_n_epoch - self.track_grad_norm = track_grad_norm - self.on_gpu = True if (gpus and torch.cuda.is_available()) else False - - # tpu config - self.on_tpu = num_tpu_cores is not None - self.num_tpu_cores = num_tpu_cores - assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8' - - if num_processes != 1 and distributed_backend != "ddp_cpu": - rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.") - self.num_processes = num_processes - - self.process_position = process_position - self.weights_summary = weights_summary + """ + # we reuse fit for other functions. When already set, it shouldn't be modified. + if not self.state.running: + self.state = TrainerState.FITTING + if self._running_stage is None: + self.training = True - self.max_epochs = max_epochs - # Backward compatibility, TODO: remove in v0.8.0 - if max_nb_epochs is not None: - rank_zero_warn("Argument `max_nb_epochs` has renamed to `max_epochs` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.max_nb_epochs = max_nb_epochs - - self.min_epochs = min_epochs - # Backward compatibility, TODO: remove in v0.8.0 - if min_nb_epochs is not None: - rank_zero_warn("Argument `min_nb_epochs` has renamed to `min_epochs` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.min_nb_epochs = min_nb_epochs - - self.max_steps = max_steps - self.min_steps = min_steps - - self.num_sanity_val_steps = num_sanity_val_steps - # Backward compatibility, TODO: remove in v0.8.0 - if nb_sanity_val_steps is not None: - rank_zero_warn("Argument `nb_sanity_val_steps` has renamed to " - "`num_sanity_val_steps` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - self.nb_sanity_val_steps = nb_sanity_val_steps - - # Backward compatibility, TODO: remove in v0.9.0 - if print_nan_grads: - rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0." - " NaN grads will be printed automatically when detected.", DeprecationWarning) - - self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch - - self.auto_lr_find = auto_lr_find - self.replace_sampler_ddp = replace_sampler_ddp - - self.truncated_bptt_steps = truncated_bptt_steps - self.resume_from_checkpoint = resume_from_checkpoint - self.terminate_on_nan = terminate_on_nan - self.shown_warnings = set() + # set local properties on the model + self.model_connector.copy_trainer_model_properties(model) + + # ---------------------------- + # LINK DATA + # ---------------------------- + # setup data, etc... + self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) + + # hook + self.data_connector.prepare_data(model) + self.callback_connector._attach_model_callbacks(model, self) + + # ---------------------------- + # SET UP TRAINING + # ---------------------------- + self.call_hook("on_before_accelerator_backend_setup", model) + self.accelerator.connect(model) + self.accelerator.setup_environment() + self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + self.call_configure_sharded_model(model) # allow user to setup in model sharded environment + self.accelerator.setup(self, model) # note: this sets up self.lightning_module + + # ---------------------------- + # INSPECT THE CORE LOOPS + # ---------------------------- + f""" + Lightning internal flow looks like this: + {Trainer.fit} or {Trainer.test} or {Trainer.predict} || + | || + create accelerator || + | || + {self.dispatch} || + | || LIGHTNING + {self.accelerator.start_training} || + or {self.accelerator.start_evaluating} || + or {self.accelerator.start_predicting} || FLOW + | || + {self.run_stage} || + | || DIRECTION + {self.run_train} || + or {self.run_evaluation} || + or {self.run_predict} || + | || + results \/ + This is used to guide readers to the core loops: train, test, predict. + {self.run_predict} is the simplest to understand, use `Go to Definition` to read it :) + Search for `start_training` or `start_evaluating` or `start_predicting` in + `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. + """ # noqa: W605 + + # ---------------------------- + # TRAIN + # ---------------------------- + # hook + if self.state == TrainerState.FITTING: + self.call_hook("on_fit_start") + + # plugin will setup fitting (e.g. ddp will launch child processes) + self.pre_dispatch() + + # dispatch `start_training` or `start_evaluating` or `start_predicting` + self.dispatch() + + # plugin will finalized fitting (e.g. ddp_spawn will load trained model) + self.post_dispatch() + + # ---------------------------- + # POST-Training CLEAN UP + # ---------------------------- + # hook + if self.state == TrainerState.FITTING: + self.call_hook('on_fit_end') + + # teardown + self.call_teardown_hook(model) + + if self.state != TrainerState.INTERRUPTED: + self.state = TrainerState.FINISHED + self._running_stage = None - self.fast_dev_run = fast_dev_run - if self.fast_dev_run: - self.num_sanity_val_steps = 0 - self.max_epochs = 1 - log.info('Running in fast_dev_run mode: will run a full train,' - ' val and test loop using a single batch') - - # set default save path if user didn't provide one - self.default_root_dir = default_root_dir - - # Backward compatibility, TODO: remove in v0.8.0 - if default_save_path is not None: - self.default_root_dir = default_save_path - - if self.default_root_dir is None: - self.default_root_dir = os.getcwd() - - # training bookeeping - self.total_batch_idx = 0 - self.running_loss = TensorRunningAccum(window_length=20) - self.batch_idx = 0 - self.progress_bar_metrics = {} - self.callback_metrics = {} - self.num_val_batches = 0 - self.num_training_batches = 0 - self.num_test_batches = 0 - self.train_dataloader = None - self.test_dataloaders = None - self.val_dataloaders = None + # return 1 when finished + # used for testing or when we need to know that training succeeded + return self.accelerator.results or 1 - # training state - self.model = None - self.testing = False - self.disable_validation = False - self.lr_schedulers = [] - self.optimizers = None - self.optimizer_frequencies = [] - self.global_step = 0 - self.current_epoch = 0 - self.interrupted = False + def pre_dispatch(self): + self.accelerator.pre_dispatch(self) - # configure logger - self.configure_logger(logger) + # log hyper-parameters + if self.logger is not None: + # save exp to get started (this is where the first experiment logs are written) + self.logger.log_hyperparams(self.lightning_module.hparams_initial) + self.logger.log_graph(self.lightning_module) + self.logger.save() - # configure profiler - if profiler is True: - profiler = SimpleProfiler() - self.profiler = profiler or PassThroughProfiler() + def post_dispatch(self): + self.accelerator.post_dispatch(self) + self.accelerator.teardown() - # configure early stop callback - # creates a default one if none passed in - self.configure_early_stopping(early_stop_callback) + def dispatch(self): + if self.evaluating: + self.accelerator.start_evaluating(self) + elif self.predicting: + self.accelerator.start_predicting(self) + else: + self.accelerator.start_training(self) - # configure checkpoint callback - self.checkpoint_callback = checkpoint_callback - self.weights_save_path = weights_save_path + def run_stage(self): + results = None - # accumulated grads - self.accumulate_grad_batches = accumulate_grad_batches - self.configure_accumulated_gradients(accumulate_grad_batches) + self.profile_connector.setup() - # for gpus allow int, string and gpu list - if auto_select_gpus and isinstance(gpus, int): - self.gpus = pick_multiple_gpus(gpus) + if self.evaluating: + results = self.run_evaluate() + elif self.predicting: + results = self.run_predict() else: - self.gpus = gpus - - self.data_parallel_device_ids = parse_gpu_ids(self.gpus) - self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids) - self.root_device = torch.device("cpu") - - # tpu state flags - self.use_tpu = False - self.tpu_local_core_rank = None - self.tpu_global_core_rank = None - - # distributed backend choice - self.distributed_backend = distributed_backend - self.set_distributed_mode(distributed_backend) - - # override dist backend when using tpus - if self.on_tpu: - self.init_tpu() - self.current_tpu_idx = None - - # init flags for SLURM+ddp to work - self.proc_rank = 0 - self.world_size = 1 - self.node_rank = 0 - self.configure_slurm_ddp(self.num_nodes) - - # nvidia setup - self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) - - # backward compatibility - if show_progress_bar is not None: - self.show_progress_bar = show_progress_bar - - self.progress_bar_refresh_rate = progress_bar_refresh_rate - self.progress_bar_callback = None - self.configure_progress_bar() - - # logging - self.log_save_interval = log_save_interval - self.val_check_interval = val_check_interval - - # backward compatibility - if add_row_log_interval is not None: - rank_zero_warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - if not row_log_interval: # in case you did not set the proper value - row_log_interval = add_row_log_interval - self.row_log_interval = row_log_interval - - # how much of the data to use - self.overfit_pct = overfit_pct - self.determine_data_use_amount(train_percent_check, val_percent_check, - test_percent_check, overfit_pct) - - # AMP init - # These are the only lines needed after v0.8.0 - # we wrap the user's forward with autocast and give it back at the end of fit - self.autocast_original_forward = None - self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - self.precision = precision - if self.use_native_amp and self.precision == 16: - self.scaler = torch.cuda.amp.GradScaler() - - # TODO: remove for v0.8.0 - self.amp_level = amp_level - self.init_amp(use_amp) + self.run_train() + return results - # Callback system - self.on_init_end() + def _pre_training_routine(self): + # wait for all to join if on distributed + self.accelerator.barrier("setup_training") - @property - def slurm_job_id(self) -> int: - try: - job_id = os.environ['SLURM_JOB_ID'] - job_id = int(job_id) + # register auto-resubmit when on SLURM + self.slurm_connector.register_slurm_signal_handlers() - # in interactive mode, don't make logs use the same job id - in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash' - if in_slurm_interactive_mode: - job_id = None + # -------------------------- + # Pre-train + # -------------------------- + # on pretrain routine start + ref_model = self.lightning_module - except Exception: - job_id = None - return job_id + self.on_pretrain_routine_start() + ref_model.on_pretrain_routine_start() - @classmethod - def default_attributes(cls): - init_signature = inspect.signature(Trainer) + # print model summary + if self.is_global_zero and self.weights_summary is not None and not self.testing: + ref_model.summarize(mode=self.weights_summary) - args = {} - for param_name in init_signature.parameters: - value = init_signature.parameters[param_name].default - args[param_name] = value + # restore training and model before hpc is called + self.checkpoint_connector.restore_weights() - return args + # on pretrain routine end + self.on_pretrain_routine_end() + ref_model.on_pretrain_routine_end() - @classmethod - def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the Trainer signature and returns argument names, types and default values. + def run_train(self) -> None: - Returns: - List with tuples of 3 values: - (argument name, set with argument types, argument default value). - - Examples: - >>> args = Trainer.get_init_arguments_and_types() - >>> import pprint - >>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - [('accumulate_grad_batches', - (, typing.Dict[int, int], typing.List[list]), - 1), - ... - ('callbacks', - (typing.List[pytorch_lightning.callbacks.base.Callback], - ), - None), - ('check_val_every_n_epoch', (,), 1), - ... - ('max_epochs', (,), 1000), - ... - ('precision', (,), 32), - ('print_nan_grads', (,), False), - ('process_position', (,), 0), - ('profiler', - (, - ), - None), - ... - """ - trainer_default_params = inspect.signature(cls).parameters - name_type_default = [] - for arg in trainer_default_params: - arg_type = trainer_default_params[arg].annotation - arg_default = trainer_default_params[arg].default - try: - arg_types = tuple(arg_type.__args__) - except AttributeError: - arg_types = (arg_type,) - - name_type_default.append((arg, arg_types, arg_default)) - - return name_type_default - - @classmethod - def get_deprecated_arg_names(cls) -> List: - """Returns a list with deprecated Trainer arguments.""" - depr_arg_names = [] - for name, val in cls.__dict__.items(): - if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): - depr_arg_names.extend(val) - return depr_arg_names - - @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - r"""Extends existing argparse by default `Trainer` attributes. + self._pre_training_routine() - Args: - parent_parser: - The custom cli arguments parser, which will be extended by - the Trainer default arguments. + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() - Only arguments of the allowed types (str, float, int, bool) will - extend the `parent_parser`. - """ - parser = ArgumentParser(parents=[parent_parser], add_help=False, ) + self.run_sanity_check(self.lightning_module) - blacklist = ['kwargs'] - depr_arg_names = cls.get_deprecated_arg_names() + blacklist + self.checkpoint_connector.has_trained = False - allowed_types = (str, float, int, bool) + # enable train mode + model = self.lightning_module + model.train() + torch.set_grad_enabled(True) - # TODO: get "help" from docstring :) - for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types() - if at[0] not in depr_arg_names): + # reload data when needed + self.train_loop.reset_train_val_dataloaders(model) - for allowed_type in (at for at in allowed_types if at in arg_types): - if allowed_type is bool: - def allowed_type(x): - return bool(parsing.strtobool(x)) + # hook + self.train_loop.on_train_start() - if arg == 'gpus': - allowed_type = Trainer.allowed_type - arg_default = Trainer.arg_default + try: + if self.train_loop.should_skip_training(): + return + # run all epochs + epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(self.current_epoch) + for epoch in epochs: + + # hook + self.train_loop.on_train_epoch_start(epoch) + + with self.profiler.profile("run_training_epoch"): + # run train epoch + self.train_loop.run_training_epoch() + + if self.max_steps and self.max_steps <= self.global_step: + return + + # early stopping + met_min_epochs = (epoch >= self.min_epochs - 1) if self.min_epochs else True + met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + + if self.should_stop: + if met_min_epochs and met_min_steps: + return + else: + log.info( + 'Trainer was signaled to stop but required minimum epochs' + f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' + ' not been met. Training will continue...' + ) + + # hook + self.train_loop.on_train_end() + + except KeyboardInterrupt: + rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') + # user could press Ctrl+c many times... only shutdown once + if not self.interrupted: + self.state = TrainerState.INTERRUPTED + self.on_keyboard_interrupt() + except (RuntimeError, AssertionError): + # if an exception is raised, the finally block is executed and can hide the actual exception + # that was initially raised if `on_train_end` also raises an exception. we want to avoid that + # for assertions and other runtime errors so we aren't misled while debugging + print_exc() + finally: + # hook + self.train_loop.on_train_end() + + def run_evaluation(self, on_epoch=False): + if not (self.evaluating or self.sanity_checking): + rank_zero_warn( + f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}." + " This should not happen normally. Setting it to `RunningStage.VALIDATING`", RuntimeWarning + ) + self.validating = True - parser.add_argument( - f'--{arg}', - default=arg_default, - type=allowed_type, - dest=arg, - help='autogenerated by pl.Trainer' - ) - break + # reset cached results + self.logger_connector.reset() - return parser + # prepare dataloaders + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() - def allowed_type(x): - if ',' in x: - return str(x) - else: - return int(x) + # check if we want to skip this evaluation + if self.evaluation_loop.should_skip_evaluation(max_batches): + return [], [] - def arg_default(x): - if ',' in x: - return str(x) - else: - return int(x) + # enable eval mode + no grads + self.evaluation_loop.on_evaluation_model_eval() + # ref model + model = self.lightning_module + model.zero_grad() + torch.set_grad_enabled(False) - @classmethod - def from_argparse_args(cls, args, **kwargs): + # hook + self.evaluation_loop.on_evaluation_start() - params = vars(args) - params.update(**kwargs) + # set up the eval loop + self.evaluation_loop.setup(model, max_batches, dataloaders) - return cls(**params) + # hook + self.evaluation_loop.on_evaluation_epoch_start() - @property - def num_gpus(self) -> int: - gpus = self.data_parallel_device_ids - if gpus is None: - return 0 - return len(gpus) + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + # bookkeeping + dl_outputs = [] + dataloader = self.accelerator.process_dataloader(dataloader) + dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] - @property - def data_parallel(self) -> bool: - return self.use_dp or self.use_ddp or self.use_ddp2 + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue - @property - def progress_bar_dict(self) -> dict: - """ Read-only for progress bar metrics. """ - ref_model = self.model if not self.data_parallel else self.model.module - return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics) + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break - # ----------------------------- - # MODEL TRAINING - # ----------------------------- - def fit( - self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None - ): - r""" - Runs the full optimization routine. + # hook + self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - Args: - model: Model to fit. + # lightning module methods + with self.profiler.profile("evaluation_step_and_end"): + output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx) + output = self.evaluation_loop.evaluation_step_end(output) - train_dataloader: A Pytorch - DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + # hook + store predictions + self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) - val_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + # log batch metrics + self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx) - Example:: + # track epoch level outputs + dl_outputs = self.track_output_for_epoch_end(dl_outputs, output) - # Option 1, - # Define the train_dataloader() and val_dataloader() fxs - # in the lightningModule - # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY - trainer = Trainer() - model = LightningModule() - trainer.fit(model) + # store batch level output per dataloader + self.evaluation_loop.outputs.append(dl_outputs) - # Option 2 - # in production cases we might want to pass different datasets to the same model - # Recommended for PRODUCTION SYSTEMS - train, val = DataLoader(...), DataLoader(...) - trainer = Trainer() - model = LightningModule() - trainer.fit(model, train_dataloader=train, val_dataloader=val) + # lightning module method + deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() - # Option 1 & 2 can be mixed, for example the training set can be - # defined as part of the model, and validation can then be feed to .fit() + # hook + self.evaluation_loop.on_evaluation_epoch_end() - """ - # bind logger and other properties - model.logger = self.logger - self.copy_trainer_model_properties(model) - - # clean hparams - if hasattr(model, 'hparams'): - parsing.clean_namespace(model.hparams) - - # set up the passed in dataloaders (if needed) - self.__attach_dataloaders(model, train_dataloader, val_dataloaders) - - # check that model is configured correctly - self.check_model_configuration(model) - - # download the data and do whatever transforms we need - # do before any spawn calls so that the model can assign properties - # only on proc 0 because no spawn has happened yet - model.prepare_data() - - # Run learning rate finder: - if self.auto_lr_find: - self._run_lr_finder_internally(model) - - # route to appropriate start method - # when using multi-node or DDP within a node start each module in a separate process - if self.use_ddp2: - task = int(os.environ['SLURM_LOCALID']) - self.ddp_train(task, model) - - elif self.use_ddp: - if self.is_slurm_managing_tasks: - task = int(os.environ['SLURM_LOCALID']) - self.ddp_train(task, model) - else: - self.__set_random_port() - # track for predict - self.model = model - # train - mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,)) - # load weights if not interrupted - if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'): - self.load_spawn_weights(model) - self.model = model - - # 1 gpu or dp option triggers training using DP module - # easier to avoid NCCL issues - elif self.use_dp: - self.dp_train(model) - - elif self.use_horovod: - self.horovod_train(model) - - elif self.single_gpu: - self.single_gpu_train(model) - - elif self.use_tpu: # pragma: no-cover - log.info(f'training on {self.num_tpu_cores} TPU cores') - - # COLAB_GPU is an env var available by default in Colab environments. - start_method = 'fork' if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') else 'spawn' - - # track for predict - self.model = model - - # train - xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method) - - # load weights if not interrupted - self.load_spawn_weights(model) - self.model = model - - # ON CPU - else: - # run through amp wrapper - if self.use_amp: - raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option') + # update epoch-level lr_schedulers + if on_epoch: + self.optimizer_connector.update_learning_rates(interval='epoch') - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) + # hook + self.evaluation_loop.on_evaluation_end() - self.run_pretrain_routine(model) + # log epoch metrics + eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end() - # return 1 when finished - # used for testing or when we need to know that training succeeded - return 1 + # save predictions to disk + self.evaluation_loop.predictions.to_disk() - def __set_random_port(self): - """ - When running DDP NOT managed by SLURM, the ports might collide - :return: - """ - try: - default_port = os.environ['MASTER_PORT'] - except Exception: - import random - default_port = random.randint(10000, 19000) - os.environ['MASTER_PORT'] = str(default_port) + # enable train mode again + self.evaluation_loop.on_evaluation_model_train() - def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): - # when dataloader is passed via fit, patch the train_dataloader - # functions to overwrite with these implementations - if train_dataloader is not None: - model.train_dataloader = _PatchDataLoader(train_dataloader) + torch.set_grad_enabled(True) - if val_dataloaders is not None: - model.val_dataloader = _PatchDataLoader(val_dataloaders) + return eval_loop_results, deprecated_eval_results - if test_dataloaders is not None: - model.test_dataloader = _PatchDataLoader(test_dataloaders) + def track_output_for_epoch_end(self, outputs, output): + if output is not None: + if isinstance(output, Result): + output = output.detach() + if self.move_metrics_to_cpu: + output = output.cpu() + elif isinstance(output, dict): + output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu) + elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu: + output = output.cpu() + outputs.append(output) + return outputs - def run_pretrain_routine(self, model: LightningModule): - """Sanity check a few things before starting actual training. + def run_evaluate(self): + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() - Args: - model: The model to run sanity test on. - """ - ref_model = model - if self.data_parallel: - ref_model = model.module + assert self.evaluating - # give model convenience properties - ref_model.trainer = self + with self.profiler.profile(f"run_{self._running_stage}_evaluation"): + eval_loop_results, _ = self.run_evaluation() - # set local properties on the model - self.copy_trainer_model_properties(ref_model) + if len(eval_loop_results) == 0: + return 1 - # log hyper-parameters - if self.logger is not None: - # save exp to get started - if hasattr(ref_model, "hparams"): - self.logger.log_hyperparams(ref_model.hparams) + # remove the tensors from the eval results + for i, result in enumerate(eval_loop_results): + if isinstance(result, dict): + for k, v in result.items(): + if isinstance(v, torch.Tensor): + result[k] = v.cpu().item() - self.logger.save() + return eval_loop_results - if self.use_ddp or self.use_ddp2: - torch_distrib.barrier() + def run_predict(self): + self.predict_loop.on_predict_start() - # wait for all models to restore weights - if self.on_tpu and XLA_AVAILABLE: - # wait for all processes to catch up - torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine") + # prepare dataloaders + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() - elif self.use_horovod: - # wait for all processes to catch up - hvd.join() + # check if we want to skip this evaluation + if self.predict_loop.should_skip_predict(max_batches): + return [] - # register auto-resubmit when on SLURM - self.register_slurm_signal_handlers() + # ref model + model = self.lightning_module - # print model summary - # TODO: remove self.testing condition because model.summarize() is wiping out the weights - if self.proc_rank == 0 and self.weights_summary is not None and not self.testing: - if self.weights_summary in ['full', 'top']: - ref_model.summarize(mode=self.weights_summary) - else: - raise MisconfigurationException("weights_summary can be None, 'full' or 'top'") - - # track model now. - # if cluster resets state, the model will update with the saved weights - self.model = model - - # set up checkpoint callback - self.configure_checkpoint_callback() - - # restore training and model before hpc call - self.restore_weights(model) - - # when testing requested only run test and return - if self.testing: - # only load test dataloader for testing - # self.reset_test_dataloader(ref_model) - self.run_evaluation(test_mode=True) - return + # enable eval mode + no grads + self.predict_loop.on_predict_model_eval() + model.zero_grad() + torch.set_grad_enabled(False) + + # set up the eval loop + self.predict_loop.setup(model, max_batches, dataloaders) + + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + dataloader = self.accelerator.process_dataloader(dataloader) + dl_max_batches = self.predict_loop.max_batches[dataloader_idx] + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue - # check if we should run validation during training - self.disable_validation = not (self.is_overriden('validation_step') and self.val_percent_check > 0) \ - and not self.fast_dev_run + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break + + # lightning module methods + with self.profiler.profile("predict_step"): + self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) + + results = self.predict_loop.on_predict_epoch_end() + self.predict_loop.on_predict_end() + + # re-enable grads + torch.set_grad_enabled(True) + + return results + + def run_sanity_check(self, ref_model): + using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) + should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 # run tiny validation (if validation defined) # to make sure program won't crash during val - if not self.disable_validation and self.num_sanity_val_steps > 0: - self.reset_val_dataloader(ref_model) + if should_sanity_check: + stage = self._running_stage + self.sanity_checking = True # hook and callback - ref_model.on_sanity_check_start() self.on_sanity_check_start() - eval_results = self._evaluate(model, - self.val_dataloaders, - self.num_sanity_val_steps, - False) - _, _, _, callback_metrics, _ = self.process_output(eval_results) + # run eval step + _, eval_results = self.run_evaluation() self.on_sanity_check_end() - # verify that early stop has conditioned on a metric that exists - if self.enable_early_stop: - self.early_stop_callback._validate_condition_metric(callback_metrics) + self._running_stage = stage + + def validate( + self, + model: Optional[LightningModule] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Perform one evaluation epoch over the validation set. + + Args: + model: The model to validate. + + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. + + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. + + verbose: If True, prints the validation results. + + datamodule: A instance of :class:`LightningDataModule`. + + Returns: + The dictionary with final validation results returned by validation_epoch_end. + If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. + """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_evaluate = verbose + + self.state = TrainerState.VALIDATING + self.validating = True + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' + ) + + model_provided = model is not None + model = model or self.lightning_module + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) - # clear cache before training - if self.on_gpu: - torch.cuda.empty_cache() + if not model_provided: + self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) - # CORE TRAINING LOOP - self.train() + # run validate + results = self.fit(model) + + assert self.state.stopped + self.validating = False + + return results def test( - self, - model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None + self, + model: Optional[LightningModule] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, ): r""" - - Separates from fit to make sure you never run on your test set until you want to. + Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. Args: model: The model to test. - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. - Example:: + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. - # Option 1 - # run test after fitting - test = DataLoader(...) - trainer = Trainer() - model = LightningModule() + verbose: If True, prints the test results. - trainer.fit(model) - trainer.test(test_dataloaders=test) + datamodule: A instance of :class:`LightningDataModule`. - # Option 2 - # run test from a loaded model - test = DataLoader(...) - model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') - trainer = Trainer() - trainer.test(model, test_dataloaders=test) + Returns: + Returns a list of dictionaries, one for each test dataloader containing their respective metrics. """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_evaluate = verbose + self.state = TrainerState.TESTING self.testing = True - if test_dataloaders is not None: - if model: - self.__attach_dataloaders(model, test_dataloaders=test_dataloaders) - else: - self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders) - - # give proper warnings if user only passed in loader without hooks - self.check_testing_model_configuration(model if model else self.model) - - if model is not None: - self.model = model - self.fit(model) - elif self.use_ddp or self.use_tpu: # pragma: no-cover - # attempt to load weights from a spawn - path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') - test_model = self.model - if os.path.exists(path): - test_model = self.load_spawn_weights(self.model) - - self.fit(test_model) - else: - self.run_evaluation(test_mode=True) + # If you supply a datamodule you can't supply test_dataloaders + if test_dataloaders and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') + + model_provided = model is not None + model = model or self.lightning_module + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + + if not model_provided: + self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + + # run test + results = self.fit(model) + + assert self.state.stopped self.testing = False - def check_model_configuration(self, model: LightningModule): + return results + + def __load_ckpt_weights( + self, + model, + ckpt_path: Optional[str] = None, + ) -> Optional[str]: + if ckpt_path is None: + return + + fn = self.state.value + + if ckpt_path == 'best': + # if user requests the best checkpoint but we don't have it, error + if not self.checkpoint_callback.best_model_path: + if self.fast_dev_run: + raise MisconfigurationException( + f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do' + f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.' + ) + raise MisconfigurationException( + f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' + ) + # load best weights + ckpt_path = self.checkpoint_callback.best_model_path + + if not ckpt_path: + raise MisconfigurationException( + f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' + f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' + ) + + # only one process running at this point for TPUs, as spawn isn't triggered yet + if self._device_type != DeviceType.TPU: + self.training_type_plugin.barrier() + + ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt['state_dict']) + + return ckpt_path + + def predict( + self, + model: Optional[LightningModule] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + ): r""" - Checks that the model is configured correctly before training is started. + + Separates from fit to make sure you never run on your predictions set until you want to. + + This will call the model forward function to compute predictions. Args: - model: The model to test. + model: The model to predict on. + dataloaders: Either a single + Pytorch Dataloader or a list of them, specifying inference samples. + + datamodule: A instance of :class:`LightningDataModule`. + + Returns: + Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ - # Check training_step, train_dataloader, configure_optimizer methods - if not self.is_overriden('training_step', model): - raise MisconfigurationException( - 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' - ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') - if not self.is_overriden('train_dataloader', model): - raise MisconfigurationException( - 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' - ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') + # -------------------- + # SETUP HOOK + # -------------------- + # If you supply a datamodule you can't supply dataloaders - if not self.is_overriden('configure_optimizers', model): - raise MisconfigurationException( - 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' - ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') - - # Check val_dataloader, validation_step and validation_epoch_end - if self.is_overriden('val_dataloader', model): - if not self.is_overriden('validation_step', model): - raise MisconfigurationException('You have passed in a `val_dataloader()`' - ' but have not defined `validation_step()`.') - else: - if not self.is_overriden('validation_epoch_end', model): - rank_zero_warn( - 'You have defined a `val_dataloader()` and have defined a `validation_step()`,' - ' you may also want to define `validation_epoch_end()` for accumulating stats.', - RuntimeWarning - ) - else: - if self.is_overriden('validation_step', model): - raise MisconfigurationException('You have defined `validation_step()`,' - ' but have not passed in a val_dataloader().') - - # Check test_dataloader, test_step and test_epoch_end - if self.is_overriden('test_dataloader', model): - if not self.is_overriden('test_step', model): - raise MisconfigurationException('You have passed in a `test_dataloader()`' - ' but have not defined `test_step()`.') - else: - if not self.is_overriden('test_epoch_end', model): - rank_zero_warn( - 'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to' - ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning - ) + model = model or self.lightning_module - def check_testing_model_configuration(self, model: LightningModule): + self.state = TrainerState.PREDICTING + self.predicting = True - has_test_step = self.is_overriden('test_step', model) - has_test_epoch_end = self.is_overriden('test_epoch_end', model) - gave_test_loader = hasattr(model, 'test_dataloader') and model.test_dataloader() + if dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' + ) - if gave_test_loader and not has_test_step: - raise MisconfigurationException('You passed in a `test_dataloader` but did not implement `test_step()`') + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model, datamodule) + # Attach dataloaders (if given) + self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - if has_test_step and not gave_test_loader: - raise MisconfigurationException('You defined `test_step()` but did not implement' - ' `test_dataloader` nor passed in `.fit(test_dataloaders`.') + results = self.fit(model) - if has_test_step and gave_test_loader and not has_test_epoch_end: - rank_zero_warn( - 'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to' - ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning - ) + assert self.state.stopped + self.predicting = False + return results -class _PatchDataLoader(object): - r""" - Callable object for patching dataloaders passed into trainer.fit(). - Use this class to override model.*_dataloader() and be pickle-compatible. + def tune( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Runs routines to tune hyperparameters before training. - Args: - dataloader: Dataloader object to return when called. + Args: + datamodule: A instance of :class:`LightningDataModule`. - """ + model: Model to tune. - def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): - self.dataloader = dataloader + train_dataloader: A Pytorch DataLoader with training samples. If the model has + a predefined train_dataloader method this will be skipped. - # cannot pickle __code__ so cannot verify if PatchDataloader - # exists which shows dataloader methods have been overwritten. - # so, we hack it by using the string representation - self.patch_loader_code = str(self.__call__.__code__) + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped - def __call__(self) -> Union[List[DataLoader], DataLoader]: - return self.dataloader + """ + self.state = TrainerState.TUNING + self.tuning = True + + self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) + + assert self.state.stopped + self.tuning = False + + def call_setup_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = self._setup_state + + if self.datamodule is not None: + called = getattr(self.datamodule, f'has_setup_{state}') + if not called: + self.datamodule.setup(stage=state) + + self.setup(model, stage=state) + model.setup(stage=state) + + def call_configure_sharded_model(self, model: LightningModule) -> None: + # Call configure sharded model hook if accelerator requests. In some cases + # we will not call the hook; the hook has initialized the sharded model for example. + if self.accelerator.call_configure_sharded_model_hook: + with self.accelerator.model_sharded_context(): + model.configure_sharded_model() + self.configure_sharded_model(model) + self.accelerator.call_configure_sharded_model_hook = False + + def call_teardown_hook(self, model: LightningModule) -> None: + state = self._teardown_state + + if self.datamodule is not None: + called = getattr(self.datamodule, f'has_teardown_{state}') + if not called: + self.datamodule.teardown(stage=state) + + self.profiler.teardown(stage=state) + self.teardown(stage=state) + model.teardown(stage=state) + + def _reset_result_and_set_hook_fx_name(self, hook_name): + # on_before_zero_grad is called within training_step + if "batch_start" in hook_name or "on_before_zero_grad" in hook_name: + return True + model_ref = self.lightning_module + if model_ref is not None: + # used to track current hook name called + model_ref._results = Result() + model_ref._current_hook_fx_name = hook_name + return False + + def _cache_logged_metrics(self): + model_ref = self.lightning_module + if model_ref is not None: + # capture logging for this hook + self.logger_connector.cache_logged_metrics() + + def call_hook(self, hook_name, *args, **kwargs): + # set hook_name to model + reset Result obj + skip = self._reset_result_and_set_hook_fx_name(hook_name) + + # always profile hooks + with self.profiler.profile(hook_name): + + # first call trainer hook + if hasattr(self, hook_name): + trainer_hook = getattr(self, hook_name) + trainer_hook(*args, **kwargs) + + # next call hook in lightningModule + output = None + model_ref = self.lightning_module + if is_overridden(hook_name, model_ref): + hook_fx = getattr(model_ref, hook_name) + output = hook_fx(*args, **kwargs) + + # if the PL module doesn't have the hook then call the accelerator + # used to auto-reduce things for the user with Results obj + elif hasattr(self.accelerator, hook_name): + accelerator_hook = getattr(self.accelerator, hook_name) + output = accelerator_hook(*args, **kwargs) + + if not skip: + self._cache_logged_metrics() + return output diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py deleted file mode 100644 index 4f474b761e94fb..00000000000000 --- a/pytorch_lightning/trainer/training_io.py +++ /dev/null @@ -1,500 +0,0 @@ -""" -Lightning can automate saving and loading checkpoints -===================================================== - -Checkpointing is enabled by default to the current working directory. -To change the checkpoint path pass in:: - - Trainer(default_root_dir='/your/path/to/save/checkpoints') - - -To modify the behavior of checkpointing pass in your own callback. - -.. code-block:: python - - from pytorch_lightning.callbacks import ModelCheckpoint - - # DEFAULTS used by the Trainer - checkpoint_callback = ModelCheckpoint( - filepath=os.getcwd(), - save_top_k=1, - verbose=True, - monitor='val_loss', - mode='min', - prefix='' - ) - - trainer = Trainer(checkpoint_callback=checkpoint_callback) - - -Restoring training session --------------------------- - -You might want to not only load a model but also continue training it. Use this method to -restore the trainer state as well. This will continue from the epoch and global step you last left off. -However, the dataloaders will start from the first batch again (if you shuffled it shouldn't matter). - -Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint. - -.. code-block:: python - - from pytorch_lightning import Trainer - - trainer = Trainer( - resume_from_checkpoint=PATH - ) - - # this fit call loads model weights and trainer state - # the trainer continues seamlessly from where you left off - # without having to do anything else. - trainer.fit(model) - - -The trainer restores: - -- global_step -- current_epoch -- All optimizers -- All lr_schedulers -- Model weights - -You can even change the logic of your model as long as the weights and "architecture" of -the system isn't different. If you add a layer, for instance, it might not work. - -At a rough level, here's what happens inside Trainer :py:mod:`pytorch_lightning.base_module.model_saving.py`: - -.. code-block:: python - - self.global_step = checkpoint['global_step'] - self.current_epoch = checkpoint['epoch'] - - # restore the optimizers - optimizer_states = checkpoint['optimizer_states'] - for optimizer, opt_state in zip(self.optimizers, optimizer_states): - optimizer.load_state_dict(opt_state) - - # restore the lr schedulers - lr_schedulers = checkpoint['lr_schedulers'] - for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): - scheduler['scheduler'].load_state_dict(lrs_state) - - # uses the model you passed into trainer - model.load_state_dict(checkpoint['state_dict']) - -""" - -import os -import re -import signal -from abc import ABC -from argparse import Namespace -from subprocess import call -from typing import Union - -import torch -import torch.distributed as torch_distrib - -from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.overrides.data_parallel import ( - LightningDistributedDataParallel, - LightningDataParallel, -) -from pytorch_lightning.utilities import rank_zero_warn, parsing - -try: - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.xla_multiprocessing as xmp -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - - -class TrainerIOMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - model: LightningModule - on_gpu: bool - root_gpu: ... - resume_from_checkpoint: ... - use_ddp: bool - use_ddp2: bool - use_horovod: bool - checkpoint_callback: ... - proc_rank: int - weights_save_path: str - logger: Union[LightningLoggerBase, bool] - early_stop_callback: ... - lr_schedulers: ... - optimizers: ... - on_tpu: bool - num_training_batches: int - accumulate_grad_batches: int - - def get_model(self): - is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, - LightningDataParallel)) - model = self.model.module if is_dp_module else self.model - return model - - # -------------------- - # CHECK-POINTING - # -------------------- - def restore_weights(self, model: LightningModule): - """ - We attempt to restore weights in this order: - 1. HPC weights. - 2. if no HPC weights restore checkpoint_path weights - 3. otherwise don't restore weights - """ - # clear cache before restore - if self.on_gpu: - torch.cuda.empty_cache() - - # if script called from hpc resubmit, load weights - did_restore_hpc_weights = self.restore_hpc_weights_if_needed(model) - - # clear cache after restore - if self.on_gpu: - torch.cuda.empty_cache() - - if not did_restore_hpc_weights: - if self.resume_from_checkpoint is not None: - self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu) - - # wait for all models to restore weights - if self.use_ddp or self.use_ddp2: - # wait for all processes to catch up - torch_distrib.barrier() - - # wait for all models to restore weights - if self.on_tpu and XLA_AVAILABLE: - # wait for all processes to catch up - torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights") - - elif self.use_horovod: - # wait for all processes to catch up - hvd.join() - - # clear cache after restore - if self.on_gpu: - torch.cuda.empty_cache() - - # -------------------- - # HPC SIGNAL HANDLING - # -------------------- - def register_slurm_signal_handlers(self): - # see if we're using slurm (not interactive) - on_slurm = False - try: - job_name = os.environ['SLURM_JOB_NAME'] - if job_name != 'bash': - on_slurm = True - except Exception as e: - pass - - if on_slurm: - log.info('Set SLURM handle signals.') - signal.signal(signal.SIGUSR1, self.sig_handler) - signal.signal(signal.SIGTERM, self.term_handler) - - def sig_handler(self, signum, frame): # pragma: no-cover - if self.proc_rank == 0: - # save weights - log.info('handling SIGUSR1') - self.hpc_save(self.weights_save_path, self.logger) - - # find job id - job_id = os.environ['SLURM_JOB_ID'] - cmd = 'scontrol requeue {}'.format(job_id) - - # requeue job - log.info(f'requeing job {job_id}...') - result = call(cmd, shell=True) - - # print result text - if result == 0: - log.info(f'requeued exp {job_id}') - else: - log.warning('requeue failed...') - - # close experiment to avoid issues - self.logger.close() - - def term_handler(self, signum, frame): - # save - log.info("bypassing sigterm") - - # -------------------- - # MODEL SAVE CHECKPOINT - # -------------------- - def _atomic_save(self, checkpoint, filepath: str): - """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. - - This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once - saving is finished. - - Args: - checkpoint: The object to save. - Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` - accepts. - filepath: The path to which the checkpoint will be saved. - This points to the file that the checkpoint will be stored in. - """ - tmp_path = str(filepath) + ".part" - torch.save(checkpoint, tmp_path) - os.replace(tmp_path, filepath) - - def save_checkpoint(self, filepath): - checkpoint = self.dump_checkpoint() - - if self.proc_rank == 0: - # do the actual save - try: - self._atomic_save(checkpoint, filepath) - except AttributeError as e: - if 'hparams' in checkpoint: - del checkpoint['hparams'] - rank_zero_warn('warning, `hparams` dropped from checkpoint.' - f' An attribute is not picklable {e}') - - self._atomic_save(checkpoint, filepath) - - def restore(self, checkpoint_path: str, on_gpu: bool): - """ - Restore training state from checkpoint. - Also restores all training state like: - - epoch - - callbacks - - schedulers - - optimizer - """ - - # if on_gpu: - # checkpoint = torch.load(checkpoint_path) - # else: - # load on CPU first - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - - # load model state - model = self.get_model() - - # load the state_dict on the model automatically - model.load_state_dict(checkpoint['state_dict']) - - # give model a chance to load something - model.on_load_checkpoint(checkpoint) - - if on_gpu: - model.cuda(self.root_gpu) - - # restore amp scaling - if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: - self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) - - # load training state (affects trainer only) - self.restore_training_state(checkpoint) - - def dump_checkpoint(self): - checkpoint = { - 'epoch': self.current_epoch + 1, - 'global_step': self.global_step + 1, - } - - if self.checkpoint_callback is not None and self.checkpoint_callback is not False: - checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best - - if self.early_stop_callback is not None and self.checkpoint_callback is not False: - checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait - checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience - - # save optimizers - optimizer_states = [] - for i, optimizer in enumerate(self.optimizers): - optimizer_states.append(optimizer.state_dict()) - - checkpoint['optimizer_states'] = optimizer_states - - # save lr schedulers - lr_schedulers = [] - for scheduler in self.lr_schedulers: - lr_schedulers.append(scheduler['scheduler'].state_dict()) - - checkpoint['lr_schedulers'] = lr_schedulers - - # add the hparams and state_dict from the model - model = self.get_model() - - checkpoint['state_dict'] = model.state_dict() - - # restore native amp scaling - if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: - checkpoint['native_amp_scaling_state'] = self.scaler.state_dict() - - if hasattr(model, "hparams"): - parsing.clean_namespace(model.hparams) - is_namespace = isinstance(model.hparams, Namespace) - checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams - checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict' - else: - rank_zero_warn( - "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters." - ) - - # give the model a chance to add a few things - model.on_save_checkpoint(checkpoint) - - return checkpoint - - # -------------------- - # HPC IO - # -------------------- - def restore_hpc_weights_if_needed(self, model: LightningModule): - """If there is a set of hpc weights, use as signal to restore model.""" - did_restore = False - - # look for hpc weights - folderpath = self.weights_save_path - if os.path.exists(folderpath): - files = os.listdir(folderpath) - hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x] - - # if hpc weights exist restore model - if len(hpc_weight_paths) > 0: - self.hpc_load(folderpath, self.on_gpu) - did_restore = True - return did_restore - - def restore_training_state(self, checkpoint): - """ - Restore trainer state. - Model will get its change to update - :param checkpoint: - :return: - """ - if self.checkpoint_callback is not None and self.checkpoint_callback is not False: - self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] - - if self.early_stop_callback is not None and self.early_stop_callback is not False: - self.early_stop_callback.wait = checkpoint['early_stop_callback_wait'] - self.early_stop_callback.patience = checkpoint['early_stop_callback_patience'] - - self.global_step = checkpoint['global_step'] - self.current_epoch = checkpoint['epoch'] - - # Division deals with global step stepping once per accumulated batch - # Inequality deals with different global step for odd vs even num_training_batches - n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches - expected_steps = self.num_training_batches / n_accum - if self.num_training_batches != 0 and self.global_step % expected_steps > 1: - rank_zero_warn( - "You're resuming from a checkpoint that ended mid-epoch. " - "This can cause unreliable results if further training is done, " - "consider using an end of epoch checkpoint. " - ) - - # restore the optimizers - optimizer_states = checkpoint['optimizer_states'] - for optimizer, opt_state in zip(self.optimizers, optimizer_states): - optimizer.load_state_dict(opt_state) - - # move optimizer to GPU 1 weight at a time - # avoids OOM - if self.root_gpu is not None: - for state in optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.cuda(self.root_gpu) - - # restore the lr schedulers - lr_schedulers = checkpoint['lr_schedulers'] - for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): - scheduler['scheduler'].load_state_dict(lrs_state) - - # ---------------------------------- - # PRIVATE OPS - # ---------------------------------- - def hpc_save(self, folderpath: str, logger): - # make sure the checkpoint folder exists - os.makedirs(folderpath, exist_ok=True) - - # save logger to make sure we get all the metrics - logger.save() - - ckpt_number = self.max_ckpt_in_folder(folderpath) + 1 - - if not os.path.exists(folderpath): - os.makedirs(folderpath, exist_ok=True) - filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') - - # give model a chance to do something on hpc_save - model = self.get_model() - checkpoint = self.dump_checkpoint() - - model.on_hpc_save(checkpoint) - - # do the actual save - # TODO: fix for anything with multiprocess DP, DDP, DDP2 - try: - self._atomic_save(checkpoint, filepath) - except AttributeError as e: - if 'hparams' in checkpoint: - del checkpoint['hparams'] - rank_zero_warn('warning, `hparams` dropped from checkpoint.' - f' An attribute is not picklable {e}') - - self._atomic_save(checkpoint, filepath) - - return filepath - - def hpc_load(self, folderpath, on_gpu): - filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath)) - - # load on CPU first - checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage) - - # load model state - model = self.get_model() - - # load the state_dict on the model automatically - model.load_state_dict(checkpoint['state_dict']) - - # restore amp scaling - if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: - self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) - - if self.root_gpu is not None: - model.cuda(self.root_gpu) - - # load training state (affects trainer only) - self.restore_training_state(checkpoint) - - # call model hook - model.on_hpc_load(checkpoint) - - log.info(f'restored hpc model from: {filepath}') - - def max_ckpt_in_folder(self, path, name_key='ckpt_'): - files = os.listdir(path) - files = [x for x in files if name_key in x] - if len(files) == 0: - return 0 - - ckpt_vs = [] - for name in files: - name = name.split(name_key)[-1] - name = re.sub('[^0-9]', '', name) - ckpt_vs.append(int(name)) - - return max(ckpt_vs) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b2ce8599bc9a06..696f14742935ce 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -1,806 +1,871 @@ -""" -The lightning training loop handles everything except the actual computations of your model. - To decide what will happen in your training loop, define the `training_step` function. +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager, suppress +from copy import copy, deepcopy +from typing import Optional -Below are all the things lightning automates for you in the training loop. +import numpy as np +import torch -Accumulated gradients ---------------------- +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing +from pytorch_lightning.utilities.distributed import rank_zero_info +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.warnings import WarningCache + + +class TrainLoop: + + def __init__(self, trainer, multiple_trainloader_mode: str): + self.trainer = trainer + self.accumulated_loss = None + self.warning_cache = WarningCache() + self._teardown_already_run = False + self.running_loss = TensorRunningAccum(window_length=20) + self.automatic_optimization = True + self._curr_step_result = None + self._cur_grad_norm_dict = None + self._multiple_trainloader_mode = multiple_trainloader_mode + self._skip_backward = False + self.trainer._multiple_trainloader_mode = multiple_trainloader_mode + + def on_trainer_init( + self, + max_epochs: Optional[int], + min_epochs: Optional[int], + max_steps: Optional[int], + min_steps: Optional[int], + num_sanity_val_steps: int, + ) -> None: + self.trainer.global_step = 0 + self.trainer.current_epoch = 0 + self.trainer.should_stop = False + self.trainer._state = TrainerState.INITIALIZING + + self.trainer.total_batch_idx = 0 + self.trainer.batch_idx = 0 + self.trainer.num_training_batches = 0 + self.trainer.train_dataloader = None + + # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 + self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs + # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 + self.trainer.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs + self.trainer.max_steps = max_steps + self.trainer.min_steps = min_steps + + if num_sanity_val_steps == -1: + self.trainer.num_sanity_val_steps = float("inf") + else: + self.trainer.num_sanity_val_steps = num_sanity_val_steps -Accumulated gradients runs K small batches of size N before doing a backwards pass. - The effect is a large effective batch size of size KxN. + @property + def num_optimizers(self): + num_optimizers = len(self.get_optimizers_iterable()) + return num_optimizers -.. code-block:: python + def should_skip_training(self): + should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps + should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs + return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 - # DEFAULT (ie: no accumulated grads) - trainer = Trainer(accumulate_grad_batches=1) + def on_train_start(self): + # hook + self.trainer.call_hook("on_train_start") -Force training for min or max epochs ------------------------------------- + def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): + # clean hparams + if hasattr(model, "hparams"): + parsing.clean_namespace(model.hparams) -It can be useful to force training for a minimum number of epochs or limit to a max number + # links data to the trainer + self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) -.. code-block:: python + # check that model is configured correctly + self.trainer.config_validator.verify_loop_configurations(model) - # DEFAULT - trainer = Trainer(min_epochs=1, max_epochs=1000) + # attach model log function to callback + self.trainer.callback_connector.attach_model_logging_functions(model) -Force disable early stop ------------------------- + def on_train_end(self): + if self._teardown_already_run: + return + self._teardown_already_run = True -To disable early stopping pass None to the early_stop_callback + # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates + # when a checkpoint was saved at the last step + self.trainer.global_step -= 1 + self.check_checkpoint_callback(should_update=True, is_last=True) + self.trainer.global_step += 1 -.. code-block:: python + # hook + self.trainer.call_hook("on_train_end") - # DEFAULT - trainer = Trainer(early_stop_callback=None) + # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. + # It might be related to xla tensors blocked when moving the cpu + # kill loggers + if self.trainer.logger is not None: + self.trainer.logger.finalize("success") -Gradient Clipping ------------------ + # summarize profile results + self.trainer.profiler.describe() -Gradient clipping may be enabled to avoid exploding gradients. - Specifically, this will `clip the gradient norm computed over all model parameters - `together `_. + # give accelerators a chance to finish + self.trainer.accelerator.on_train_end() -.. code-block:: python + # reset bookkeeping + self.trainer._running_stage = None - # DEFAULT (ie: don't clip) - trainer = Trainer(gradient_clip_val=0) + def check_checkpoint_callback(self, should_update, is_last=False): + # TODO bake this logic into the ModelCheckpoint callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = self.trainer.checkpoint_callbacks - # clip gradients with norm above 0.5 - trainer = Trainer(gradient_clip_val=0.5) + if is_last and any(cb.save_last and cb.verbose for cb in callbacks): + rank_zero_info("Saving latest checkpoint...") -Inspect gradient norms ----------------------- + model = self.trainer.lightning_module -Looking at grad norms can help you figure out where training might be going wrong. + for cb in callbacks: + cb.on_validation_end(self.trainer, model) -.. code-block:: python + def check_early_stopping_callback(self, should_update): + # TODO bake this logic into the EarlyStopping callback + if should_update and self.trainer.checkpoint_connector.has_trained: + callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)] + model = self.trainer.lightning_module - # DEFAULT (-1 doesn't track norms) - trainer = Trainer(track_grad_norm=-1) + for cb in callbacks: + cb.on_validation_end(self.trainer, model) - # track the LP norm (P=2 here) - trainer = Trainer(track_grad_norm=2) + def on_train_epoch_start(self, epoch): -Set how much of the training set to check ------------------------------------------ + # update training progress in trainer + self.trainer.current_epoch = epoch -If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag. + model = self.trainer.lightning_module -train_percent_check will be overwritten by overfit_pct if `overfit_pct > 0` + # reset train dataloader + if epoch != 0 and self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_train_dataloader(model) -.. code-block:: python + # todo: specify the possible exception + with suppress(Exception): + # set seed for distributed sampler (enables shuffling for each epoch) + self.trainer.train_dataloader.sampler.set_epoch(epoch) - # DEFAULT - trainer = Trainer(train_percent_check=1.0) + # changing gradient according accumulation_scheduler + self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) - # check 10% only - trainer = Trainer(train_percent_check=0.1) + # stores accumulated grad fractions per batch + self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) -Packed sequences as inputs --------------------------- + # hook + self.trainer.call_hook("on_epoch_start") + self.trainer.call_hook("on_train_epoch_start") -When using PackedSequence, do 2 things: -1. return either a padded tensor in dataset or a list of variable length tensors -in the dataloader collate_fn (example above shows the list implementation). -2. Pack the sequence in forward or training and validation steps depending on use case. + def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): + # hook + self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx) + self.trainer.call_hook('on_batch_end') -.. code-block:: python + # figure out what to track for epoch end + self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) - # For use in dataloader - def collate_fn(batch): - x = [item[0] for item in batch] - y = [item[1] for item in batch] - return x, y + # reset batch logger internals + self.trainer.logger_connector.on_train_batch_end() - # In module - def training_step(self, batch, batch_idx): - x = rnn.pack_sequence(batch[0], enforce_sorted=False) - y = rnn.pack_sequence(batch[1], enforce_sorted=False) + def reset_train_val_dataloaders(self, model): + if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_train_dataloader(model) + if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: + self.trainer.reset_val_dataloader(model) -Truncated Backpropagation Through Time --------------------------------------- + def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): -There are times when multiple backwards passes are needed for each batch. - For example, it may save memory to use Truncated Backpropagation Through Time when training RNNs. + # track the outputs to reduce at the end of the epoch + for opt_idx, opt_outputs in enumerate(batch_end_outputs): + sample_output = opt_outputs[-1] -When this flag is enabled each batch is split into sequences of size truncated_bptt_steps - and passed to training_step(...) separately. A default splitting function is provided, - however, you can override it for more flexibility. See `tbptt_split_batch`. + # decide if we need to reduce at the end of the epoch automatically + auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + hook_overridden = ( + is_overridden("training_epoch_end", model=self.trainer.lightning_module) + or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) + ) -.. code-block:: python + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + if not (hook_overridden or auto_reduce_tng_result): + continue - # DEFAULT (single backwards pass per batch) - trainer = Trainer(truncated_bptt_steps=None) + # with 1 step (no tbptt) don't use a sequence at epoch end + if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): + opt_outputs = opt_outputs[0] - # (split batch into sequences of size 2) - trainer = Trainer(truncated_bptt_steps=2) + epoch_output[opt_idx].append(opt_outputs) + def get_optimizers_iterable(self): + """ + Generates an iterable with (idx, optimizer) for each optimizer. + """ + if not self.trainer.optimizer_frequencies: + # call training_step once per optimizer + return list(enumerate(self.trainer.optimizers)) -NaN detection and intervention ------------------------------- -When the `terminate_on_nan` flag is enabled, after every forward pass during training, Lightning will -check that + optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) + optimizers_loop_length = optimizer_freq_cumsum[-1] + current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length -1. the loss you return in `training_step` is finite (not NaN and not +/-inf) -2. the model parameters have finite values. + # find optimzier index by looking for the first {item > current_place} in the cumsum list + opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) + return [[opt_idx, self.trainer.optimizers[opt_idx]]] -Lightning will terminate the training loop with an error message if NaN or infinite -values are detected. If this happens, you should investigate numerically unstable operations -in your model. + def on_after_backward(self, training_step_output, batch_idx, untouched_loss): + is_result_obj = isinstance(training_step_output, Result) -.. code-block:: python + if is_result_obj: + training_step_output = training_step_output.detach() + else: + training_step_output.batch_loss = training_step_output.batch_loss.detach() - # DEFAULT (won't perform the NaN check) - trainer = Trainer(terminate_on_nan=False) + # insert after step hook + self.trainer.call_hook("on_after_backward") - # (NaN check each batch and terminate on NaN or infinite values) - trainer = Trainer(terminate_on_nan=True) + # when in dev debugging track the losses + self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) -""" + def _check_training_step_output(self, training_step_output): + if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: + if training_step_output.grad_fn is None: + # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... + raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") -from abc import ABC, abstractmethod -from typing import Callable -from typing import Union, List + def training_step(self, split_batch, batch_idx, opt_idx, hiddens): + # give the PL module a result for logging + model_ref = self.trainer.lightning_module -import numpy as np -from torch.utils.data import DataLoader -import torch + with self.trainer.profiler.profile("model_forward"): + args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) -from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import rank_zero_warn - -try: - from apex import amp -except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True - -try: - import torch_xla.distributed.parallel_loader as xla_pl - import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - - -class TrainerTrainLoopMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - max_epochs: int - min_epochs: int - on_gpu: bool - use_ddp: bool - use_dp: bool - use_ddp2: bool - use_horovod: bool - single_gpu: bool - use_tpu: bool - data_parallel_device_ids: ... - check_val_every_n_epoch: ... - num_training_batches: int - val_check_batch: ... - num_val_batches: int - disable_validation: bool - fast_dev_run: ... - accumulation_scheduler: ... - lr_schedulers: ... - enable_early_stop: ... - early_stop_callback: ... - callback_metrics: ... - logger: Union[LightningLoggerBase, bool] - global_step: int - testing: bool - log_save_interval: float - proc_rank: int - row_log_interval: float - truncated_bptt_steps: ... - optimizers: ... - optimizer_frequencies: ... - accumulate_grad_batches: int - track_grad_norm: ... - model: LightningModule - interrupted: bool - running_loss: ... - progress_bar_dict: ... - reduce_lr_on_plateau_scheduler: ... - profiler: ... - batch_idx: int - precision: ... - train_dataloader: DataLoader - reload_dataloaders_every_epoch: bool - max_steps: int - min_steps: int - total_batch_idx: int - checkpoint_callback: ... - terminate_on_nan: bool - - # Callback system - callbacks: List[Callback] - on_train_start: Callable - on_train_end: Callable - on_batch_start: Callable - on_batch_end: Callable - on_epoch_start: Callable - on_epoch_end: Callable - on_validation_end: Callable - - @abstractmethod - def get_model(self): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def is_function_implemented(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def run_evaluation(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def transfer_batch_to_gpu(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def transfer_batch_to_tpu(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def clip_gradients(self): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def detect_nan_tensors(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def is_overriden(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def add_progress_bar_metrics(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def log_metrics(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def process_output(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def reset_train_dataloader(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def reset_val_dataloader(self, model): - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def has_arg(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - - def train(self): - # get model - model = self.get_model() - - # load data - # if reload_dataloaders_every_epoch, this is moved to the epoch loop - if not self.reload_dataloaders_every_epoch: - self.reset_train_dataloader(model) - self.reset_val_dataloader(model) - - # Train start events - with self.profiler.profile('on_train_start'): - # callbacks - self.on_train_start() - # initialize early stop callback - if self.early_stop_callback is not None: - self.early_stop_callback.on_train_start(self, self.get_model()) - # model hooks - model.on_train_start() - - try: - # run all epochs - for epoch in range(self.current_epoch, self.max_epochs): - # reset train dataloader - if self.reload_dataloaders_every_epoch: - self.reset_train_dataloader(model) - # set seed for distributed sampler (enables shuffling for each epoch) - if (self.use_ddp or self.use_horovod) \ - and hasattr(self.train_dataloader.sampler, 'set_epoch'): - self.train_dataloader.sampler.set_epoch(epoch) - - # update training progress in trainer and model - model.current_epoch = epoch - self.current_epoch = epoch - - # changing gradient according accumulation_scheduler - self.accumulation_scheduler.on_epoch_start(self, self.get_model()) - - # stores accumulated grad fractions per batch - self.batch_loss_value = TensorRunningAccum( - window_length=self.accumulate_grad_batches - ) + # manually capture logged metrics + model_ref._current_fx_name = 'training_step' + model_ref._results = Result() + with self.trainer.profiler.profile("training_step"): + training_step_output = self.trainer.accelerator.training_step(args) + self.trainer.accelerator.post_training_step() - # ----------------- - # RUN TNG EPOCH - # ----------------- - self.run_training_epoch() + self.trainer.logger_connector.cache_logged_metrics() - # update LR schedulers - self.update_learning_rates(interval='epoch') + self._check_training_step_output(training_step_output) - if self.max_steps and self.max_steps == self.global_step: - self.run_training_teardown() - return + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) - # early stopping - met_min_epochs = epoch >= self.min_epochs - 1 - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( + training_step_output, split_batch + ) + is_result_obj = isinstance(training_step_output, Result) - # TODO wrap this logic into the callback - if self.enable_early_stop: - if (met_min_epochs and met_min_steps) or self.fast_dev_run: - should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model()) - # stop training - stop = should_stop and met_min_epochs - if stop: - self.run_training_teardown() - return + if training_step_output_for_epoch_end is None: + return None - self.run_training_teardown() + # enable empty loss when using manual opt + closure_loss = None + untouched_loss = None - except KeyboardInterrupt: - if self.proc_rank == 0: - log.info('Detected KeyboardInterrupt, attempting graceful shutdown...') - self.interrupted = True - self.run_training_teardown() + if self.automatic_optimization: + # accumulate loss + # (if accumulate_grad_batches = 1 no effect) + if is_result_obj: + closure_loss = training_step_output.minimize + else: + closure_loss = training_step_output.batch_loss + + closure_loss = closure_loss / self.trainer.accumulate_grad_batches + + # the loss will get scaled for amp. avoid any modifications to it + untouched_loss = closure_loss.detach().clone() + + # result + result = AttributeDict( + closure_loss=closure_loss, + loss=untouched_loss, + training_step_output=training_step_output, + training_step_output_for_epoch_end=training_step_output_for_epoch_end, + ) + return result + + def _process_training_step_output(self, training_step_output, split_batch): + training_step_output_for_epoch_end = training_step_output + + # enable validation_step return None + if training_step_output_for_epoch_end is None: + return None, None + + # ----------------------------------------- + # process hybrid (1.0) + # ----------------------------------------- + # no need for these checks in 1.0.0 + # TODO: remove checks in 1.0.0 + is_tensor = isinstance(training_step_output_for_epoch_end, torch.Tensor) + is_1_0_output = is_tensor or ("log" not in training_step_output and "progress_bar" not in training_step_output) + if is_1_0_output: + return self._process_training_step_output_1_0(training_step_output, split_batch) + + # ----------------------------------------- + # process old dict (deprecate 1.0) + # ----------------------------------------- + training_step_output = self.trainer.process_dict_result(training_step_output, train=True) + + training_step_output = AttributeDict( + batch_loss=training_step_output[0], + pbar_on_batch_end=training_step_output[1], + log_metrics=training_step_output[2], + ) + # if the user decides to finally reduce things in epoch_end, save raw output without graphs + if isinstance(training_step_output_for_epoch_end, torch.Tensor): + training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() + else: + training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end) + + return training_step_output_for_epoch_end, training_step_output + + def _process_training_step_output_1_0(self, training_step_output, split_batch): + result = self.trainer.lightning_module._results + + loss = None + hiddens = None + result["extra"] = {} + + # handle dict return + if isinstance(training_step_output, dict): + loss = training_step_output.pop("loss", None) + hiddens = training_step_output.pop("hiddens", None) + result["extra"] = training_step_output + + # handle scalar return + elif isinstance(training_step_output, torch.Tensor): + loss = training_step_output + + # map to results under the hood + result.minimize = loss + self.trainer.hiddens = hiddens + + # track batch for manual reduction with result + result.track_batch_size(len(split_batch)) + + # track metrics without grads for epoch reduction + training_step_output_for_epoch_end = copy(result) + training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() + if self.trainer.move_metrics_to_cpu: + training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu() + + # what flows back into the system + training_step_output = result + + return training_step_output_for_epoch_end, training_step_output + + def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): + model_ref = self.trainer.lightning_module + + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + + # native amp + lbfgs is a no go right now + if using_native_amp and is_lbfgs: + raise MisconfigurationException( + 'native PyTorch amp and lbfgs are not compatible.' + ' To request, please file a Github issue in PyTorch and tag @mcarilli' + ) + + # wraps into LightningOptimizer only for running step + optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + + # model hook + model_ref.optimizer_step( + self.trainer.current_epoch, + batch_idx, + optimizer, + opt_idx, + train_step_and_backward_closure, + on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, + using_native_amp=using_native_amp, + using_lbfgs=is_lbfgs, + ) + + def on_before_zero_grad(self, optimizer): + self.trainer.call_hook('on_before_zero_grad', optimizer) + + def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + + def track_and_norm_grad(self, optimizer): + # track gradient norms + grad_norm_dic = self._track_gradient_norm() + + # clip gradients + self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val) + self._cur_grad_norm_dict = grad_norm_dic + + def _track_gradient_norm(self): + grad_norm_dict = {} + if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: + if float(self.trainer.track_grad_norm) > 0: + model = self.trainer.lightning_module + grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) + return grad_norm_dict + + def tbptt_split_batch(self, batch): + splits = [batch] + if self.trainer.truncated_bptt_steps is not None: + model_ref = self.trainer.lightning_module + with self.trainer.profiler.profile("tbptt_split_batch"): + splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps) + return splits def run_training_epoch(self): + # modify dataloader if needed (ddp, etc...) + train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - # get model - model = self.get_model() + # track epoch output + epoch_output = [[] for _ in range(self.num_optimizers)] - # Epoch start events - with self.profiler.profile('on_epoch_start'): - # callbacks - self.on_epoch_start() + train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) + dataloader_idx = 0 + val_loop_called = False - # model hooks - if self.is_function_implemented('on_epoch_start'): - model.on_epoch_start() + for batch_idx, (batch, is_last_batch) in train_dataloader: - # track local dataloader so TPU can wrap each epoch - train_dataloader = self.train_dataloader + self.trainer.batch_idx = batch_idx - # on TPU we have to wrap it under the ParallelLoader - if self.use_tpu: - device = xm.xla_device() - train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) - train_dataloader = train_dataloader.per_device_loader(device) + # ------------------------------------ + # TRAINING_STEP + TRAINING_STEP_END + # ------------------------------------ + with self.trainer.profiler.profile("run_training_batch"): + batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) - # bookkeeping - outputs = [] + # when returning -1 from train_step, we end epoch early + if batch_output.signal == -1: + break + + batch_end_outputs = self.process_train_step_outputs(batch_output.training_step_output_for_epoch_end) + # hook + # TODO: add outputs to batches + self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx) + + # ----------------------------------------- + # SAVE METRICS TO LOGGERS + # ----------------------------------------- + self.trainer.logger_connector.log_train_step_metrics(batch_output) + + # ----------------------------------------- + # VALIDATE IF NEEDED + CHECKPOINT CALLBACK + # ----------------------------------------- + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) + if should_check_val: + self.trainer.validating = True + self.trainer.run_evaluation() + self.trainer.training = True + val_loop_called = True + + # ----------------------------------------- + # SAVE LOGGERS (ie: Tensorboard, etc...) + # ----------------------------------------- + self.save_loggers_on_train_batch_end() + + # update LR schedulers + monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) + self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.trainer.checkpoint_connector.has_trained = True + + # max steps reached, end training + if ( + self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1 + and self._accumulated_batches_reached() + ): + break + + # end epoch early + # stop when the flag is changed or we've gone past the amount + # requested in the batches + if self.trainer.should_stop: + break + + self.trainer.total_batch_idx += 1 - # run epoch - for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( - enumerate(_with_is_last(train_dataloader)), "get_train_batch" - ): # stop epoch if we limited the number of training batches - if batch_idx >= self.num_training_batches: + if self._num_training_batches_reached(is_last_batch): break - self.batch_idx = batch_idx + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() - model.global_step = self.global_step + # epoch end hook + self.on_train_epoch_end(epoch_output) - # --------------- - # RUN TRAIN STEP - # --------------- - _outputs = self.run_training_batch(batch, batch_idx) - batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs + # log epoch metrics + self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output, self.num_optimizers) - # only track outputs when user implementes training_epoch_end - # otherwise we will build up unecessary memory - if self.is_overriden('training_epoch_end', model=self.get_model()): - outputs.append(batch_output) + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) + should_train_only = self.trainer.disable_validation or should_skip_eval - # when returning -1 from train_step, we end epoch early - early_stop_epoch = batch_result == -1 - - # TODO: consolidate all actions that need to take place only after - # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment) - if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - # update lr - self.update_learning_rates(interval='step') - - # --------------- - # RUN VAL STEP - # --------------- - is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 - can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - can_check_val = not self.disable_validation and can_check_epoch - should_check_val = is_val_check_batch or early_stop_epoch - should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf')) - should_check_val = can_check_val and should_check_val - - # --------------- - # CHECKPOINTING, EARLY STOPPING - # --------------- - # fast_dev_run always forces val checking after train batch - if self.fast_dev_run or should_check_val: - self.run_evaluation(test_mode=self.testing) - self.call_checkpoint_callback() - self.call_early_stop_callback() - - # when logs should be saved - should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch - if should_save_log or self.fast_dev_run: - if self.proc_rank == 0 and self.logger is not None: - self.logger.save() - - # when metrics should be logged - should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch - if should_log_metrics or self.fast_dev_run: - # logs user requested information to logger - self.log_metrics(batch_step_metrics, grad_norm_dic) + # update epoch level lr_schedulers if no val loop outside train loop is triggered + if (val_loop_called and not should_check_val) or should_train_only: + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - # progress global step according to grads progress - if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - self.global_step += 1 - self.total_batch_idx += 1 + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) - # max steps reached, end training - if self.max_steps is not None and self.max_steps == self.global_step: - break + if should_check_val: + self.trainer.validating = True + self.trainer.run_evaluation(on_epoch=True) + self.trainer.training = True - # end epoch early - # stop when the flag is changed or we've gone past the amount - # requested in the batches - if early_stop_epoch or self.fast_dev_run: - break + # increment the global step once + # progress global step according to grads progress + self.increment_accumulated_grad_global_step() - if self.use_horovod: - hvd.join(hvd.local_rank() if self.on_gpu else -1) - - # process epoch outputs - model = self.get_model() - if self.is_overriden('training_epoch_end', model=model): - epoch_output = model.training_epoch_end(outputs) - _processed_outputs = self.process_output(epoch_output) - log_epoch_metrics = _processed_outputs[2] - callback_epoch_metrics = _processed_outputs[3] - self.log_metrics(log_epoch_metrics, {}) - self.callback_metrics.update(callback_epoch_metrics) - - # when no val loop is present or fast-dev-run still need to call checkpoints - if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val): - self.call_checkpoint_callback() - self.call_early_stop_callback() - - # Epoch end events - with self.profiler.profile('on_epoch_end'): - # callbacks - self.on_epoch_end() - # model hooks - if self.is_function_implemented('on_epoch_end'): - model.on_epoch_end() - - def run_training_batch(self, batch, batch_idx): + def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} - # track all metrics for callbacks - all_callback_metrics = [] + # bookkeeping + self.trainer.hiddens = None - # track metrics to log - all_log_metrics = [] + # track all outputs across time and num of optimizers + batch_outputs = [[] for _ in range(len(self.get_optimizers_iterable()))] if batch is None: - return 0, grad_norm_dic, {}, {} - - # Batch start events - with self.profiler.profile('on_batch_start'): - # callbacks - self.on_batch_start() - # hooks - if self.is_function_implemented('on_batch_start'): - response = self.get_model().on_batch_start(batch) - if response == -1: - return -1, grad_norm_dic, {}, {} + return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) - splits = [batch] - if self.truncated_bptt_steps is not None: - model_ref = self.get_model() - with self.profiler.profile('tbptt_split_batch'): - splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) + # hook + response = self.trainer.call_hook("on_batch_start") + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + + # hook + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) + if response == -1: + return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) + + # lightning module hook + splits = self.tbptt_split_batch(batch) - self.hiddens = None for split_idx, split_batch in enumerate(splits): - self.split_idx = split_idx - - for opt_idx, optimizer in self._get_optimizers_iterable(): - # make sure only the gradients of the current optimizer's paramaters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. - if len(self.optimizers) > 1: - for param in self.get_model().parameters(): - param.requires_grad = False - for group in optimizer.param_groups: - for param in group['params']: - param.requires_grad = True - - # wrap the forward step in a closure so second order methods work - def optimizer_closure(): - # forward pass - with self.profiler.profile('model_forward'): - if self.use_amp and self.use_native_amp: - with torch.cuda.amp.autocast(): - output_dict = self.training_forward(split_batch, batch_idx, - opt_idx, self.hiddens) - else: - output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens) - - # format and reduce outputs accordingly - processed_output = self.process_output(output_dict, train=True) - - closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output - - # accumulate loss - # (if accumulate_grad_batches = 1 no effect) - closure_loss = closure_loss / self.accumulate_grad_batches - - # backward pass - model_ref = self.get_model() - with self.profiler.profile('model_backward'): - model_ref.backward(self, closure_loss, optimizer, opt_idx) - - # track metrics for callbacks - all_callback_metrics.append(callback_metrics) - - # track progress bar metrics - self.add_progress_bar_metrics(progress_bar_metrics) - all_log_metrics.append(log_metrics) - - if self.use_horovod: - # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid - optimizer.synchronize() - - # insert after step hook - if self.is_function_implemented('on_after_backward'): - model_ref = self.get_model() - with self.profiler.profile('on_after_backward'): - model_ref.on_after_backward() - - return closure_loss, callback_metrics - - # calculate loss - loss, batch_output = optimizer_closure() - - # check if loss or model weights are nan - if self.terminate_on_nan: - self.detect_nan_tensors(loss) - # track total loss for logging (avoid mem leaks) - self.batch_loss_value.append(loss) + # create an iterable for optimizers and loop over them + for opt_idx, optimizer in self.prepare_optimizers(): + # toggle model params + set info to logger_connector + self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) + + if self.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + + # automatic_optimization=True: perform dpp sync only when performing optimizer_step + # automatic_optimization=False: don't block synchronization here + with self.block_ddp_sync_behaviour(): + self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ # gradient update with accumulated gradients - if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: - - # track gradient norms when requested - if batch_idx % self.row_log_interval == 0: - if self.track_grad_norm > 0: - model = self.get_model() - grad_norm_dic = model.grad_norm( - self.track_grad_norm) - - # clip gradients - if self.use_amp and self.use_native_amp: - self.scaler.unscale_(optimizer) - self.clip_gradients() - - # calls .step(), .zero_grad() - # override function to modify this behavior - model = self.get_model() - with self.profiler.profile('optimizer_step'): - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, - lambda: optimizer_closure()[0]) - - # calculate running loss for display - self.running_loss.append(self.batch_loss_value.mean()) - - # reset for next set of accumulated grads - self.batch_loss_value.reset() - - # Batch end events - with self.profiler.profile('on_batch_end'): - # callbacks - self.on_batch_end() - # model hooks - if self.is_function_implemented('on_batch_end'): - self.get_model().on_batch_end() - - # collapse all metrics into one dict - all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} - - # track all metrics for callbacks - self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()}) - - return 0, grad_norm_dic, all_log_metrics, batch_output - - def _get_optimizers_iterable(self): - if not self.optimizer_frequencies: - # call training_step once per optimizer - return list(enumerate(self.optimizers)) - optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies) - optimizers_loop_length = optimizer_freq_cumsum[-1] - current_place_in_loop = self.total_batch_idx % optimizers_loop_length + else: + if self.automatic_optimization: - # find optimzier index by looking for the first {item > current_place} in the cumsum list - opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) - return [(opt_idx, self.optimizers[opt_idx])] + def train_step_and_backward_closure(): + result = self.training_step_and_backward( + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) + return None if result is None else result.loss - def run_training_teardown(self): - # Train end events - with self.profiler.profile('on_train_end'): - # callbacks - self.on_train_end() - # model hooks - if self.is_function_implemented('on_train_end'): - self.get_model().on_train_end() + # optimizer step + self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) - if self.logger is not None: - self.logger.finalize("success") + else: + self._curr_step_result = self.training_step( + split_batch, batch_idx, opt_idx, self.trainer.hiddens + ) - # summarize profile results - self.profiler.describe() + if self._curr_step_result is None: + # user decided to skip optimization + # make sure to zero grad. + continue + + batch_outputs = self._process_closure_result( + batch_outputs=batch_outputs, + opt_idx=opt_idx, + ) + + # todo: Properly aggregate grad_norm accros opt_idx and split_idx + grad_norm_dic = self._cur_grad_norm_dict + self._cur_grad_norm_dict = None - def training_forward(self, batch, batch_idx, opt_idx, hiddens): + # update running loss + reset accumulated loss + self.update_running_loss() + + result = AttributeDict( + signal=0, + grad_norm_dic=grad_norm_dic, + training_step_output_for_epoch_end=batch_outputs, + ) + return result + + @contextmanager + def block_ddp_sync_behaviour(self, should_block_sync: bool = False): """ - Handle forward for each training case (distributed, single gpu, etc...) - :param batch: - :param batch_idx: - :return: + automatic_optimization = True + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + + automatic_optimization = False + do not block ddp gradient sync when using manual optimization + as gradients are needed within the training step + + Returns: + context manager with sync behaviour off + """ - # --------------- - # FORWARD - # --------------- + if ( + isinstance(self.trainer.training_type_plugin, ParallelPlugin) + and (self.automatic_optimization or should_block_sync) + ): + with self.trainer.training_type_plugin.block_backward_sync(): + yield None + else: + yield None + + def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: + opt_closure_result = self._curr_step_result + + if opt_closure_result is not None: + + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self.trainer.detect_nan_tensors(opt_closure_result.loss) + + # track all the outputs across all steps + batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 + batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) + + if self.automatic_optimization: + # track total loss for logging (avoid mem leaks) + self.accumulated_loss.append(opt_closure_result.loss) + + self._curr_step_result = None + + return batch_outputs + + def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + with self.trainer.profiler.profile("training_step_and_backward"): + # lightning module hook + result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) + self._curr_step_result = result + + if not self._skip_backward and self.automatic_optimization: + is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 + + if is_first_batch_to_accumulate: + self.on_before_zero_grad(optimizer) + self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) + + # backward pass + if result is not None: + with self.trainer.profiler.profile("backward"): + self.backward(result, optimizer, opt_idx) + + # hook - call this hook only + # when gradients have finished to accumulate + if not self.should_accumulate(): + self.on_after_backward(result.training_step_output, batch_idx, result.loss) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + self.trainer.detect_nan_tensors(result.loss) + + else: + self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...") + + if len(self.trainer.optimizers) > 1: + # revert back to previous state + self.trainer.lightning_module.untoggle_optimizer(opt_idx) + + return result + + def backward(self, result, optimizer, opt_idx, *args, **kwargs): + self.trainer.dev_debugger.track_event("backward_call") + + should_accumulate = self.should_accumulate() + + # backward can be called manually in the training loop + if isinstance(result, torch.Tensor): + self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) + else: + result.closure_loss = self.trainer.accelerator.backward( + result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs + ) + + if not self.should_accumulate(): + # track gradients + self.track_and_norm_grad(optimizer=optimizer) + + def update_train_loop_lr_schedulers(self, monitor_metrics=None): + num_accumulated_batches_reached = self._accumulated_batches_reached() + num_training_batches_reached = self._num_training_batches_reached() + + if num_accumulated_batches_reached or num_training_batches_reached: + # update lr + self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) + + def on_train_epoch_end(self, epoch_output): + # inform logger the batch loop has finished + self.trainer.logger_connector.on_train_epoch_end() + + self.trainer.call_hook('on_train_epoch_end', epoch_output) + self.trainer.call_hook('on_epoch_end') + + def increment_accumulated_grad_global_step(self): + num_accumulated_batches_reached = self._accumulated_batches_reached() + num_training_batches_reached = self._num_training_batches_reached() + + # progress global step according to grads progress + if num_accumulated_batches_reached or num_training_batches_reached: + self.trainer.global_step += 1 + + def _accumulated_batches_reached(self): + return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 + + def _num_training_batches_reached(self, is_last_batch=False): + return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch + + def should_accumulate(self): + # checks if backward or backward + optimizer step (via closure) + accumulation_done = self._accumulated_batches_reached() + is_final_batch = self._num_training_batches_reached() + return not (accumulation_done or is_final_batch) + + def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): + # decide if we should run validation + is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 + is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + can_check_val = self.trainer.enable_validation and is_val_check_epoch + is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") + epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 + + should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop + or is_last_batch_for_infinite_dataset + ) if on_epoch else (is_val_check_batch and not epoch_end_val_check) + + return should_check_val and can_check_val + + def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step args = [batch, batch_idx] - if len(self.optimizers) > 1: - if self.has_arg('training_step', 'optimizer_idx'): + if len(self.trainer.optimizers) > 1: + if self.trainer.has_arg("training_step", "optimizer_idx"): + if not self.automatic_optimization: + self.warning_cache.warn( + "`training_step` hook signature has changed in v1.3." + " `optimizer_idx` argument has been removed in case of manual optimization. Support for" + " the old signature will be removed in v1.5", DeprecationWarning + ) args.append(opt_idx) - else: - num_opts = len(self.optimizers) + elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization: raise ValueError( - f'Your LightningModule defines {num_opts} optimizers but ' - f'training_step is missing the "optimizer_idx" argument.' + f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" + ' `training_step` is missing the `optimizer_idx` argument.' ) # pass hiddens if using tbptt - if self.truncated_bptt_steps is not None: + if self.trainer.truncated_bptt_steps is not None: args.append(hiddens) - # distributed forward - if self.use_ddp or self.use_ddp2 or self.use_dp: - output = self.model(*args) - - # Horovod - elif self.use_horovod and self.on_gpu: - batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) - args[0] = batch - output = self.model.training_step(*args) - - # single GPU forward - elif self.single_gpu: - gpu_id = 0 - if isinstance(self.data_parallel_device_ids, list): - gpu_id = self.data_parallel_device_ids[0] - - # Don't copy the batch since there is a single gpu that the batch could - # be referenced from and if there are multiple optimizers the batch will - # wind up copying it to the same device repeatedly. - batch = self.transfer_batch_to_gpu(batch, gpu_id) - args[0] = batch - output = self.model.training_step(*args) - - # TPU support - elif self.use_tpu: - batch = self.transfer_batch_to_tpu(batch) - args[0] = batch - output = self.model.training_step(*args) - - # CPU forward - else: - output = self.model.training_step(*args) - - # allow any mode to define training_step_end - # do something will all the dp outputs (like softmax) - if self.is_overriden('training_step_end'): - model_ref = self.get_model() - with self.profiler.profile('training_step_end'): - output = model_ref.training_step_end(output) - - # allow any mode to define training_end - # TODO: remove in 1.0.0 - if self.is_overriden('training_end'): - model_ref = self.get_model() - with self.profiler.profile('training_end'): - output = model_ref.training_end(output) + return args - rank_zero_warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.' - ' Use training_epoch_end instead', DeprecationWarning) + def save_loggers_on_train_batch_end(self): + # when loggers should save to disk + should_flush_logs = self.trainer.logger_connector.should_flush_logs + if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: + self.trainer.logger.save() - return output - - def update_learning_rates(self, interval: str): - """Update learning rates. - - Args: - interval: either 'epoch' or 'step'. + def process_train_step_outputs(self, all_train_step_outputs): """ - if not self.lr_schedulers: - return - - for lr_scheduler in self.lr_schedulers: - current_idx = self.batch_idx if interval == 'step' else self.current_epoch - current_idx += 1 # account for both batch and epoch starts from 0 - # Take step if call to update_learning_rates matches the interval key and - # the current step modulo the schedulers frequency is zero - if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0: - # If instance of ReduceLROnPlateau, we need to pass validation loss - if lr_scheduler['reduce_on_plateau']: - monitor_key = lr_scheduler['monitor'] - monitor_val = self.callback_metrics.get(monitor_key) - if monitor_val is None: - avail_metrics = ','.join(list(self.callback_metrics.keys())) - raise MisconfigurationException( - f'ReduceLROnPlateau conditioned on metric {monitor_key}' - f' which is not available. Available metrics are: {avail_metrics}.' - ' Condition can be set using `monitor` key in lr scheduler dict' - ) - lr_scheduler['scheduler'].step(monitor_val) - else: - lr_scheduler['scheduler'].step() - - def call_checkpoint_callback(self): - if self.checkpoint_callback is not None: - self.checkpoint_callback.on_validation_end(self, self.get_model()) - - def call_early_stop_callback(self): - if self.early_stop_callback: - self.early_stop_callback.on_epoch_end(self, self.get_model()) - - -def _with_is_last(iterable): - """Pass through values from the given iterable with an added boolean indicating if this is the last item. - See `https://stackoverflow.com/a/1630350 `_""" - it = iter(iterable) - last = next(it) - for val in it: - # yield last and has next - yield last, False - last = val - # yield last, no longer has next - yield last, True + Figure out what needs to be tracked/logged at the end of the epoch + """ + # the training step outputs a list per optimizer. The list contains the outputs at each time step + # when no TBPTT is used, then the list has 1 item per batch + # when TBPTT IS used, then the list has n items (1 per time step) + # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer + return [opt_idx_out for opt_idx_out in all_train_step_outputs if len(opt_idx_out)] + + def prepare_optimizers(self): + # in manual optimization we loop over all optimizers at once + optimizers = self.get_optimizers_iterable() + if not self.automatic_optimization: + optimizers = [optimizers[0]] + return optimizers + + def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): + # set split_idx to trainer for tracking + self.trainer.split_idx = split_idx + + # make sure only the gradients of the current optimizer's parameters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. + if self.automatic_optimization and len(self.trainer.optimizers) > 1: + model = self.trainer.lightning_module + model.toggle_optimizer(optimizer, opt_idx) + + # use to track metrics internally + self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) + + def update_running_loss(self): + accumulated_loss = self.accumulated_loss.mean() + + if accumulated_loss is not None: + # calculate running loss for display + self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + + # reset for next set of accumulated grads + self.accumulated_loss.reset() diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 0d86d53b7bbc4b..2795dd4f0af30b 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -1,68 +1,48 @@ -import math -import sys -from abc import ABC, abstractmethod +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC import torch from torch import Tensor -from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import GradientAccumulationScheduler +from pytorch_lightning.core.lightning import LightningModule EPSILON = 1e-6 EPSILON_FP16 = 1e-5 +log = logging.getLogger(__name__) class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - gradient_clip_val: ... - precision: ... - - @abstractmethod - def get_model(self): - """Warning: this is just empty shell for code implemented in other class.""" - - def clip_gradients(self): - - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md - if self.gradient_clip_val > 0: - model = self.get_model() - parameters = model.parameters() - max_norm = float(self.gradient_clip_val) - norm_type = float(2.0) - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - if norm_type == math.inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - else: - device = parameters[0].device - total_norm = torch.zeros([], device=device if parameters else None) - for p in parameters: - param_norm = p.grad.data.pow(norm_type).sum() - total_norm.add_(param_norm) - total_norm = (total_norm ** (1. / norm_type)) - eps = EPSILON_FP16 if self.precision == 16 else EPSILON - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) - for p in parameters: - p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device))) + lightning_module: LightningModule def print_nan_gradients(self) -> None: - model = self.get_model() + model = self.lightning_module for param in model.parameters(): if (param.grad is not None) and torch.isnan(param.grad.float()).any(): log.info(param, param.grad) def detect_nan_tensors(self, loss: Tensor) -> None: - model = self.get_model() + model = self.lightning_module # check if loss is nan if not torch.isfinite(loss).all(): - raise ValueError( - 'The loss returned in `training_step` is nan or inf.' - ) + raise ValueError('The loss returned in `training_step` is nan or inf.') # check if a network weight is nan for name, param in model.named_parameters(): if not torch.isfinite(param).all(): @@ -71,12 +51,3 @@ def detect_nan_tensors(self, loss: Tensor) -> None: f'Detected nan and/or inf values in `{name}`.' ' Check your forward pass for numerically unstable operations.' ) - - def configure_accumulated_gradients(self, accumulate_grad_batches): - if isinstance(accumulate_grad_batches, dict): - self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches) - elif isinstance(accumulate_grad_batches, int): - schedule = {1: accumulate_grad_batches} - self.accumulation_scheduler = GradientAccumulationScheduler(schedule) - else: - raise TypeError("Gradient accumulation supports only int and dict types") diff --git a/pytorch_lightning/tuner/__init__.py b/pytorch_lightning/tuner/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/pytorch_lightning/tuner/auto_gpu_select.py b/pytorch_lightning/tuner/auto_gpu_select.py new file mode 100644 index 00000000000000..3bd1ce52b52f47 --- /dev/null +++ b/pytorch_lightning/tuner/auto_gpu_select.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def pick_multiple_gpus(nb): + ''' + Raises: + MisconfigurationException: + If ``gpus`` is set to 0, when ``auto_select_gpus=True``. + ''' + if nb == 0: + raise MisconfigurationException( + r"auto_select_gpus=True, gpus=0 is not a valid configuration.\ + Please select a valid number of GPU resources when using auto_select_gpus." + ) + + nb = torch.cuda.device_count() if nb == -1 else nb + + picked = [] + for _ in range(nb): + picked.append(pick_single_gpu(exclude_gpus=picked)) + + return picked + + +def pick_single_gpu(exclude_gpus: list): + ''' + Raises: + RuntimeError: + If you try to allocate a GPU, when no GPUs are available. + ''' + for i in range(torch.cuda.device_count()): + if i in exclude_gpus: + continue + # Try to allocate on device: + device = torch.device(f"cuda:{i}") + try: + torch.ones(1).to(device) + except RuntimeError: + continue + return i + raise RuntimeError("No GPUs available.") diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py new file mode 100644 index 00000000000000..9c5e966c14cc1c --- /dev/null +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -0,0 +1,295 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +import logging +import os +from typing import Optional, Tuple + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.data import has_len +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error +from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr + +log = logging.getLogger(__name__) + + +def scale_batch_size( + trainer, + model: LightningModule, + mode: str = 'power', + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + batch_arg_name: str = 'batch_size', + **fit_kwargs +): + r""" + Will iteratively try to find the largest batch size for a given model + that does not give an out of memory (OOM) error. + + Args: + trainer: The Trainer + model: Model to fit. + + mode: string setting the search mode. Either `power` or `binsearch`. + If mode is `power` we keep multiplying the batch size by 2, until + we get an OOM error. If mode is 'binsearch', we will initially + also keep multiplying by 2 and after encountering an OOM error + do a binary search between the last successful batch size and the + batch size that failed. + + steps_per_trial: number of steps to run with a given batch size. + Idealy 1 should be enough to test if a OOM error occurs, + however in practise a few are needed + + init_val: initial batch size to start the search with + + max_trials: max number of increase in batch size done before + algorithm is terminated + + batch_arg_name: name of the attribute that stores the batch size. + It is expected that the user has provided a model or datamodule that has a hyperparameter + with that name. We will look for this attribute name in the following places + + - ``model`` + - ``model.hparams`` + - ``model.datamodule`` + - ``trainer.datamodule`` (the datamodule passed to the tune method) + + **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader + or datamodule. + + Raises: + MisconfigurationException: + If field ``batch_arg_name`` is not found in ``model`` and ``model.hparams``, or + if batch scaling feature is used with dataloaders passed directly to ``.fit()``. + ValueError: + If mode in method ``scale_batch_size`` is neither ``power`` nor ``binsearch``. + """ + if trainer.fast_dev_run: + rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning) + return + + if not lightning_hasattr(model, batch_arg_name): + raise MisconfigurationException(f'Field {batch_arg_name} not found in both `model` and `model.hparams`') + if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams: + rank_zero_warn( + f'Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!' + f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.' + f' If this is not the intended behavior, please remove either one.' + ) + + if hasattr(model.train_dataloader, 'patch_loader_code'): + raise MisconfigurationException( + 'The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`.' + ' Please disable the feature or incorporate the dataloader into the model.' + ) + + # Arguments we adjust during the batch size finder, save for restoring + __scale_batch_dump_params(trainer) + + # Set to values that are required by the algorithm + __scale_batch_reset_params(trainer, model, steps_per_trial) + + # Save initial model, that is loaded after batch size is found + save_path = os.path.join(trainer.default_root_dir, 'scale_batch_size_temp_model.ckpt') + trainer.save_checkpoint(str(save_path)) + + if trainer.progress_bar_callback: + trainer.progress_bar_callback.disable() + + # Initially we just double in size until an OOM is encountered + new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val + if mode == 'power': + new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) + elif mode == 'binsearch': + new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) + else: + raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch') + + garbage_collection_cuda() + log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}') + + # Restore initial state of model + if trainer.is_global_zero: + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) + fs = get_filesystem(str(save_path)) + if fs.exists(save_path): + fs.rm(save_path) + + # Finish by resetting variables so trainer is ready to fit model + __scale_batch_restore_params(trainer) + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() + + return new_size + + +def __scale_batch_dump_params(trainer): + # Prevent going into infinite loop + trainer.__dumped_params = { + 'auto_lr_find': trainer.auto_lr_find, + 'current_epoch': trainer.current_epoch, + 'max_steps': trainer.max_steps, + 'weights_summary': trainer.weights_summary, + 'logger': trainer.logger, + 'callbacks': trainer.callbacks, + 'checkpoint_callback': trainer.checkpoint_callback, + 'auto_scale_batch_size': trainer.auto_scale_batch_size, + 'limit_train_batches': trainer.limit_train_batches, + 'model': trainer.model, + } + + +def __scale_batch_reset_params(trainer, model, steps_per_trial): + trainer.auto_scale_batch_size = None # prevent recursion + trainer.auto_lr_find = False # avoid lr find being called multiple times + trainer.current_epoch = 0 + trainer.max_steps = steps_per_trial # take few steps + trainer.weights_summary = None # not needed before full run + trainer.logger = DummyLogger() + trainer.callbacks = [] # not needed before full run + trainer.limit_train_batches = 1.0 + trainer.optimizers, trainer.schedulers = [], [] # required for saving + trainer.model = model # required for saving + + +def __scale_batch_restore_params(trainer): + trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] + trainer.current_epoch = trainer.__dumped_params['current_epoch'] + trainer.max_steps = trainer.__dumped_params['max_steps'] + trainer.weights_summary = trainer.__dumped_params['weights_summary'] + trainer.logger = trainer.__dumped_params['logger'] + trainer.callbacks = trainer.__dumped_params['callbacks'] + trainer.auto_scale_batch_size = trainer.__dumped_params['auto_scale_batch_size'] + trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches'] + trainer.model = trainer.__dumped_params['model'] + del trainer.__dumped_params + + +def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): + """ Batch scaling mode where the size is doubled at each iteration until an + OOM error is encountered. """ + for _ in range(max_trials): + garbage_collection_cuda() + trainer.global_step = 0 # reset after each try + try: + # Try fit + trainer.fit(model, **fit_kwargs) + # Double in size + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') + except RuntimeError as exception: + # Only these errors should trigger an adjustment + if is_oom_error(exception): + # If we fail in power mode, half the size and return + garbage_collection_cuda() + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed') + break + else: + raise # some other error not memory related + + if not changed: + break + return new_size + + +def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): + """ Batch scaling mode where the size is initially is doubled at each iteration + until an OOM error is encountered. Hereafter, the batch size is further + refined using a binary search """ + high = None + count = 0 + while True: + garbage_collection_cuda() + trainer.global_step = 0 # reset after each try + try: + # Try fit + trainer.fit(model, **fit_kwargs) + count += 1 + if count > max_trials: + break + # Double in size + low = new_size + if high: + if high - low <= 1: + break + midval = (high + low) // 2 + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded') + else: + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') + + if not changed: + break + + except RuntimeError as exception: + # Only these errors should trigger an adjustment + if is_oom_error(exception): + # If we fail in power mode, half the size and return + garbage_collection_cuda() + high = new_size + midval = (high + low) // 2 + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed') + if high - low <= 1: + break + else: + raise # some other error not memory related + + return new_size + + +def _adjust_batch_size( + trainer, + batch_arg_name: str = 'batch_size', + factor: float = 1.0, + value: Optional[int] = None, + desc: Optional[str] = None +) -> Tuple[int, bool]: + """ Helper function for adjusting the batch size. + + Args: + trainer: instance of pytorch_lightning.Trainer + + batch_arg_name: name of the field where batch_size is stored. + + factor: value which the old batch size is multiplied by to get the + new batch size + + value: if a value is given, will override the batch size with this value. + Note that the value of `factor` will not have an effect in this case + + desc: either `succeeded` or `failed`. Used purely for logging + + Returns: + The new batch size for the next trial and a bool that signals whether the + new value is different than the previous batch size. + """ + model = trainer.lightning_module + batch_size = lightning_getattr(model, batch_arg_name) + new_size = value if value is not None else int(batch_size * factor) + if desc: + log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') + + if not _is_valid_batch_size(new_size, trainer.train_dataloader): + new_size = min(new_size, len(trainer.train_dataloader.dataset)) + + changed = new_size != batch_size + lightning_setattr(model, batch_arg_name, new_size) + return new_size, changed + + +def _is_valid_batch_size(current_size, dataloader): + return not has_len(dataloader) or current_size <= len(dataloader) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py new file mode 100644 index 00000000000000..e3ccef9aa76e20 --- /dev/null +++ b/pytorch_lightning/tuner/lr_finder.py @@ -0,0 +1,513 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import logging +import os +from functools import wraps +from typing import Callable, List, Optional, Sequence, Union + +import numpy as np +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr + +# check if ipywidgets is installed before importing tqdm.auto +# to ensure it won't fail and a progress bar is displayed +if importlib.util.find_spec('ipywidgets') is not None: + from tqdm.auto import tqdm +else: + from tqdm import tqdm + +log = logging.getLogger(__name__) + + +def _determine_lr_attr_name(trainer, model: LightningModule) -> str: + if isinstance(trainer.auto_lr_find, str): + if not lightning_hasattr(model, trainer.auto_lr_find): + raise MisconfigurationException( + f'`auto_lr_find` was set to {trainer.auto_lr_find}, however' + ' could not find this as a field in `model` or `model.hparams`.' + ) + return trainer.auto_lr_find + + attr_options = ('lr', 'learning_rate') + for attr in attr_options: + if lightning_hasattr(model, attr): + return attr + + raise MisconfigurationException( + 'When `auto_lr_find=True`, either `model` or `model.hparams` should' + f' have one of these fields: {attr_options} overridden.' + ) + + +def lr_find( + trainer, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + min_lr: float = 1e-8, + max_lr: float = 1, + num_training: int = 100, + mode: str = 'exponential', + early_stop_threshold: float = 4.0, + datamodule: Optional[LightningDataModule] = None, + update_attr: bool = False, +): + r""" + ``lr_find`` enables the user to do a range test of good initial learning rates, + to reduce the amount of guesswork in picking a good starting learning rate. + + Args: + model: Model to do range testing for + + train_dataloader: A PyTorch + ``DataLoader`` with training samples. If the model has + a predefined train_dataloader method, this will be skipped. + + min_lr: minimum learning rate to investigate + + max_lr: maximum learning rate to investigate + + num_training: number of learning rates to test + + mode: Search strategy to update learning rate after each batch: + + - ``'exponential'`` (default): Will increase the learning rate exponentially. + - ``'linear'``: Will increase the learning rate linearly. + + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + datamodule: An optional ``LightningDataModule`` which holds the training + and validation dataloader(s). Note that the ``train_dataloader`` and + ``val_dataloaders`` parameters cannot be used at the same time as + this parameter, or a ``MisconfigurationException`` will be raised. + + update_attr: Whether to update the learning rate attribute or not. + + Raises: + MisconfigurationException: + If learning rate/lr in ``model`` or ``model.hparams`` isn't overriden when ``auto_lr_find=True``, or + if you are using `more than one optimizer` with learning rate finder. + + Example:: + + # Setup model and trainer + model = MyModelClass(hparams) + trainer = pl.Trainer() + + # Run lr finder + lr_finder = trainer.tuner.lr_find(model, ...) + + # Inspect results + fig = lr_finder.plot(); fig.show() + suggested_lr = lr_finder.suggestion() + + # Overwrite lr and create new model + hparams.lr = suggested_lr + model = MyModelClass(hparams) + + # Ready to train with new learning rate + trainer.fit(model) + + """ + if trainer.fast_dev_run: + rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) + return + + # Determine lr attr + if update_attr: + lr_attr_name = _determine_lr_attr_name(trainer, model) + + save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') + + __lr_finder_dump_params(trainer, model) + + # Prevent going into infinite loop + trainer.auto_lr_find = False + + # Initialize lr finder object (stores results) + lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) + + # Use special lr logger callback + trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] + + # No logging + trainer.logger = DummyLogger() + + # Max step set to number of iterations + trainer.max_steps = num_training + + # Disable standard progress bar for fit + if trainer.progress_bar_callback: + trainer.progress_bar_callback.disable() + + # Required for saving the model + trainer.optimizers, trainer.schedulers = [], [], + trainer.model = model + + # Dump model checkpoint + trainer.save_checkpoint(str(save_path)) + + # Configure optimizer and scheduler + model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) + + # Fit, lr & loss logged in callback + trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + + # Prompt if we stopped early + if trainer.global_step != num_training: + log.info('LR finder stopped early due to diverging loss.') + + # Transfer results from callback to lr finder object + lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses}) + lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose + + # Reset model state + if trainer.is_global_zero: + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) + fs = get_filesystem(str(save_path)) + if fs.exists(save_path): + fs.rm(save_path) + + # Finish by resetting variables so trainer is ready to fit model + __lr_finder_restore_params(trainer, model) + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() + + # Update lr attr if required + if update_attr: + lr = lr_finder.suggestion() + + # TODO: log lr.results to self.logger + lightning_setattr(model, lr_attr_name, lr) + log.info(f'Learning rate set to {lr}') + + return lr_finder + + +def __lr_finder_dump_params(trainer, model): + # Prevent going into infinite loop + trainer.__dumped_params = { + 'auto_lr_find': trainer.auto_lr_find, + 'callbacks': trainer.callbacks, + 'logger': trainer.logger, + 'max_steps': trainer.max_steps, + 'checkpoint_callback': trainer.checkpoint_callback, + 'configure_optimizers': model.configure_optimizers, + } + + +def __lr_finder_restore_params(trainer, model): + trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] + trainer.logger = trainer.__dumped_params['logger'] + trainer.callbacks = trainer.__dumped_params['callbacks'] + trainer.max_steps = trainer.__dumped_params['max_steps'] + model.configure_optimizers = trainer.__dumped_params['configure_optimizers'] + del trainer.__dumped_params + + +class _LRFinder(object): + """ LR finder object. This object stores the results of Trainer.lr_find(). + + Args: + mode: either `linear` or `exponential`, how to increase lr after each step + + lr_min: lr to start search from + + lr_max: lr to stop search + + num_training: number of steps to take between lr_min and lr_max + + Example:: + # Run lr finder + lr_finder = trainer.lr_find(model) + + # Results stored in + lr_finder.results + + # Plot using + lr_finder.plot() + + # Get suggestion + lr = lr_finder.suggestion() + """ + + def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): + assert mode in ('linear', 'exponential'), \ + 'mode should be either `linear` or `exponential`' + + self.mode = mode + self.lr_min = lr_min + self.lr_max = lr_max + self.num_training = num_training + + self.results = {} + self._total_batch_idx = 0 # for debug purpose + + def _exchange_scheduler(self, configure_optimizers: Callable): + """ Decorate configure_optimizers methods such that it returns the users + originally specified optimizer together with a new scheduler that + that takes care of the learning rate search. + """ + + @wraps(configure_optimizers) + def func(): + # Decide the structure of the output from configure_optimizers + # Same logic as method `init_optimizers` in trainer/optimizers.py + optim_conf = configure_optimizers() + if isinstance(optim_conf, Optimizer): + optimizers = [optim_conf] + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ + and isinstance(optim_conf[0], list): + optimizers, _ = optim_conf + elif isinstance(optim_conf, dict): + optimizers = [optim_conf["optimizer"]] + elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + elif isinstance(optim_conf, (list, tuple)): + optimizers = [optim_conf] + + if len(optimizers) != 1: + raise MisconfigurationException( + f'`model.configure_optimizers()` returned {len(optimizers)}, but' + ' learning rate finder only works with single optimizer' + ) + + optimizer = optimizers[0] + + new_lrs = [self.lr_min] * len(optimizer.param_groups) + for param_group, new_lr in zip(optimizer.param_groups, new_lrs): + param_group["lr"] = new_lr + param_group["initial_lr"] = new_lr + + args = (optimizer, self.lr_max, self.num_training) + scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args) + + return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] + + return func + + def plot(self, suggest: bool = False, show: bool = False): + """ Plot results from lr_find run + Args: + suggest: if True, will mark suggested lr to use with a red point + + show: if True, will show figure + """ + import matplotlib.pyplot as plt + + lrs = self.results["lr"] + losses = self.results["loss"] + + fig, ax = plt.subplots() + + # Plot loss as a function of the learning rate + ax.plot(lrs, losses) + if self.mode == 'exponential': + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + if suggest: + _ = self.suggestion() + if self._optimal_idx: + ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker='o', color='red') + + if show: + plt.show() + + return fig + + def suggestion(self, skip_begin: int = 10, skip_end: int = 1): + """ This will propose a suggestion for choice of initial learning rate + as the point with the steepest negative gradient. + + Returns: + lr: suggested initial learning rate to use + skip_begin: how many samples to skip in the beginning. Prevent too naive estimates + skip_end: how many samples to skip in the end. Prevent too optimistic estimates + + """ + try: + loss = np.array(self.results["loss"][skip_begin:-skip_end]) + loss = loss[np.isfinite(loss)] + min_grad = np.gradient(loss).argmin() + self._optimal_idx = min_grad + skip_begin + return self.results["lr"][self._optimal_idx] + # todo: specify the possible exception + except Exception: + log.exception('Failed to compute suggesting for `lr`. There might not be enough points.') + self._optimal_idx = None + + +class _LRCallback(Callback): + """ Special callback used by the learning rate finder. This callbacks log + the learning rate before each batch and log the corresponding loss after + each batch. + + Args: + num_training: number of iterations done by the learning rate finder + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than ``early_stop_threshold*best_loss`` + then the search is stopped. To disable, set to ``None``. + progress_bar_refresh_rate: rate to refresh the progress bar for + the learning rate finder + beta: smoothing value, the loss being logged is a running average of + loss values logged until now. ``beta`` controls the forget rate i.e. + if ``beta=0`` all past information is ignored. + + """ + + def __init__( + self, + num_training: int, + early_stop_threshold: float = 4.0, + progress_bar_refresh_rate: int = 0, + beta: float = 0.98 + ): + self.num_training = num_training + self.early_stop_threshold = early_stop_threshold + self.beta = beta + self.losses = [] + self.lrs = [] + self.avg_loss = 0.0 + self.best_loss = 0.0 + self.progress_bar_refresh_rate = progress_bar_refresh_rate + self.progress_bar = None + + def on_batch_start(self, trainer, pl_module): + """ Called before each training batch, logs the lr that will be used """ + if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + return + + if self.progress_bar_refresh_rate and self.progress_bar is None: + self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training) + + self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + """ Called when the training batch ends, logs the calculated loss """ + if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + return + + if self.progress_bar: + self.progress_bar.update() + + current_loss = trainer.train_loop.running_loss.last().item() + current_step = trainer.global_step + + # Avg loss (loss with momentum) + smoothing + self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss + smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1)) + + # Check if we diverging + if self.early_stop_threshold is not None: + if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: + trainer.max_steps = current_step # stop signal + if self.progress_bar: + self.progress_bar.close() + + # Save best loss for diverging checking + if smoothed_loss < self.best_loss or current_step == 1: + self.best_loss = smoothed_loss + + self.losses.append(smoothed_loss) + + +class _LinearLR(_LRScheduler): + """Linearly increases the learning rate between two boundaries + over a number of iterations. + Arguments: + + optimizer: wrapped optimizer. + + end_lr: the final learning rate. + + num_iter: the number of iterations over which the test occurs. + + last_epoch: the index of last epoch. Default: -1. + """ + last_epoch: int + base_lrs: Sequence + + def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): + self.end_lr = end_lr + self.num_iter = num_iter + super(_LinearLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + curr_iter = self.last_epoch + 1 + r = curr_iter / self.num_iter + + if self.last_epoch > 0: + val = [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + else: + val = [base_lr for base_lr in self.base_lrs] + self._lr = val + return val + + @property + def lr(self): + return self._lr + + +class _ExponentialLR(_LRScheduler): + """Exponentially increases the learning rate between two boundaries + over a number of iterations. + + Arguments: + + optimizer: wrapped optimizer. + + end_lr: the final learning rate. + + num_iter: the number of iterations over which the test occurs. + + last_epoch: the index of last epoch. Default: -1. + """ + last_epoch: int + base_lrs: Sequence + + def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): + self.end_lr = end_lr + self.num_iter = num_iter + super(_ExponentialLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + curr_iter = self.last_epoch + 1 + r = curr_iter / self.num_iter + + if self.last_epoch > 0: + val = [base_lr * (self.end_lr / base_lr)**r for base_lr in self.base_lrs] + else: + val = [base_lr for base_lr in self.base_lrs] + self._lr = val + return val + + @property + def lr(self): + return self._lr diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py new file mode 100644 index 00000000000000..b9fa9afe0e77ec --- /dev/null +++ b/pytorch_lightning/tuner/tuning.py @@ -0,0 +1,155 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from torch.utils.data import DataLoader + +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus +from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size +from pytorch_lightning.tuner.lr_finder import lr_find + + +class Tuner: + + def __init__(self, trainer): + self.trainer = trainer + + def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): + self.trainer.auto_lr_find = auto_lr_find + self.trainer.auto_scale_batch_size = auto_scale_batch_size + + def setup_trainer( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: LightningDataModule = None, + ): + self.trainer.model_connector.copy_trainer_model_properties(model) + # setup data, etc... + self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) + # hook + self.trainer.data_connector.prepare_data(model) + + def tune(self, model, train_dataloader, val_dataloaders, datamodule): + # Run auto batch size scaling + if self.trainer.auto_scale_batch_size: + if isinstance(self.trainer.auto_scale_batch_size, bool): + self.trainer.auto_scale_batch_size = 'power' + self.scale_batch_size( + model, + mode=self.trainer.auto_scale_batch_size, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + ) + + # Run learning rate finder: + if self.trainer.auto_lr_find: + self.lr_find(model, update_attr=True) + + self.trainer.state = TrainerState.FINISHED + + def scale_batch_size( + self, + model, + mode: str = 'power', + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + batch_arg_name: str = 'batch_size', + **fit_kwargs + ): + r""" + Will iteratively try to find the largest batch size for a given model + that does not give an out of memory (OOM) error. + + Args: + model: Model to fit. + + mode: string setting the search mode. Either `power` or `binsearch`. + If mode is `power` we keep multiplying the batch size by 2, until + we get an OOM error. If mode is 'binsearch', we will initially + also keep multiplying by 2 and after encountering an OOM error + do a binary search between the last successful batch size and the + batch size that failed. + + steps_per_trial: number of steps to run with a given batch size. + Idealy 1 should be enough to test if a OOM error occurs, + however in practise a few are needed + + init_val: initial batch size to start the search with + + max_trials: max number of increase in batch size done before + algorithm is terminated + + batch_arg_name: name of the attribute that stores the batch size. + It is expected that the user has provided a model or datamodule that has a hyperparameter + with that name. We will look for this attribute name in the following places + + - ``model`` + - ``model.hparams`` + - ``model.datamodule`` + - ``trainer.datamodule`` (the datamodule passed to the tune method) + + **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader + or datamodule. + + """ + self.setup_trainer(model, **fit_kwargs) + return scale_batch_size( + self.trainer, + model, + mode, + steps_per_trial, + init_val, + max_trials, + batch_arg_name, + **fit_kwargs, + ) + + def lr_find( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + min_lr: float = 1e-8, + max_lr: float = 1, + num_training: int = 100, + mode: str = 'exponential', + early_stop_threshold: float = 4.0, + datamodule: Optional[LightningDataModule] = None, + update_attr: bool = False, + ): + self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) + return lr_find( + self.trainer, + model, + train_dataloader, + val_dataloaders, + min_lr, + max_lr, + num_training, + mode, + early_stop_threshold, + datamodule, + update_attr, + ) + + def pick_multiple_gpus(self, num_gpus: int): + return pick_multiple_gpus(num_gpus) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index c8bc28052398b2..03981b0042eac9 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -1,3 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """General utilities""" -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn +import numpy +from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 +from pytorch_lightning.utilities.distributed import ( # noqa: F401 + AllGatherGrad, + rank_zero_deprecation, + rank_zero_info, + rank_zero_only, + rank_zero_warn, +) +from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum # noqa: F401 +from pytorch_lightning.utilities.imports import ( # noqa: F401 + _APEX_AVAILABLE, + _BOLTS_AVAILABLE, + _DEEPSPEED_AVAILABLE, + _FAIRSCALE_AVAILABLE, + _FAIRSCALE_PIPE_AVAILABLE, + _GROUP_AVAILABLE, + _HOROVOD_AVAILABLE, + _HYDRA_AVAILABLE, + _HYDRA_EXPERIMENTAL_AVAILABLE, + _IS_INTERACTIVE, + _module_available, + _NATIVE_AMP_AVAILABLE, + _OMEGACONF_AVAILABLE, + _RPC_AVAILABLE, + _TORCH_GREATER_EQUAL_1_6, + _TORCH_GREATER_EQUAL_1_7, + _TORCH_LOWER_EQUAL_1_4, + _TORCH_QUANTIZE_AVAILABLE, + _TORCHTEXT_AVAILABLE, + _TORCHVISION_AVAILABLE, + _XLA_AVAILABLE, +) +from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 +from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401 + +_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + +FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps +FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps +FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py new file mode 100644 index 00000000000000..e100a803bcd003 --- /dev/null +++ b/pytorch_lightning/utilities/apply_func.py @@ -0,0 +1,175 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import operator +from abc import ABC +from collections.abc import Mapping, Sequence +from copy import copy +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch + +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE + +if _TORCHTEXT_AVAILABLE: + if _compare_version("torchtext", operator.ge, "0.9.0"): + from torchtext.legacy.data import Batch + else: + from torchtext.data import Batch +else: + Batch = type(None) + + +def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): + if device is None: + raise MisconfigurationException("device (torch.device) should be provided.") + return torch.tensor(value, dtype=dtype, device=device) + + +def from_numpy(value, device: torch.device = None): + if device is None: + raise MisconfigurationException("device (torch.device) should be provided.") + return torch.from_numpy(value).to(device) + + +CONVERSION_DTYPES = [ + # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group + (bool, partial(to_dtype_tensor, dtype=torch.uint8)), + (int, partial(to_dtype_tensor, dtype=torch.int)), + (float, partial(to_dtype_tensor, dtype=torch.float)), + (np.ndarray, from_numpy), +] + + +def apply_to_collection( + data: Any, + dtype: Union[type, tuple], + function: Callable, + *args, + wrong_dtype: Optional[Union[type, tuple]] = None, + **kwargs +) -> Any: + """ + Recursively applies a function to all elements of a certain dtype. + + Args: + data: the collection to apply the function to + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + wrong_dtype: the given function won't be applied if this type is specified and the given collections is of + the :attr:`wrong_type` even if it is of type :attr`dtype` + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + the resulting collection + """ + elem_type = type(data) + + # Breaking condition + if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): + return function(data, *args, **kwargs) + + # Recursively apply to collection items + if isinstance(data, Mapping): + return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) + + if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + + if isinstance(data, Sequence) and not isinstance(data, str): + return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + + # data is neither of dtype, nor a collection + return data + + +class TransferableDataType(ABC): + """ + A custom type for data that can be moved to a torch device via `.to(...)`. + Example: + >>> isinstance(dict, TransferableDataType) + False + >>> isinstance(torch.rand(2, 3), TransferableDataType) + True + >>> class CustomObject: + ... def __init__(self): + ... self.x = torch.rand(2, 2) + ... def to(self, device): + ... self.x = self.x.to(device) + ... return self + >>> isinstance(CustomObject(), TransferableDataType) + True + """ + + @classmethod + def __subclasshook__(cls, subclass): + if cls is TransferableDataType: + to = getattr(subclass, "to", None) + return callable(to) + return NotImplemented + + +def move_data_to_device(batch: Any, device: torch.device): + """ + Transfers a collection of data to the given device. Any object that defines a method + ``to(device)`` will be moved and all other objects in the collection will be left untouched. + + Args: + batch: A tensor or collection of tensors or anything that has a method `.to(...)`. + See :func:`apply_to_collection` for a list of supported collection types. + device: The device to which the data should be moved + + Return: + the same collection but with all contained tensors residing on the new device. + + See Also: + - :meth:`torch.Tensor.to` + - :class:`torch.device` + """ + + def batch_to(data): + # try to move torchtext data first + if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): + + # Shallow copy because each Batch has a reference to Dataset which contains all examples + device_data = copy(data) + for field, field_value in data.dataset.fields.items(): + if field_value is None: + continue + device_field = move_data_to_device(getattr(data, field), device) + setattr(device_data, field, device_field) + return device_data + + kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} + return data.to(device, **kwargs) + + dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType + return apply_to_collection(batch, dtype=dtype, function=batch_to) + + +def convert_to_tensors(data, device: torch.device = None): + if device is None: + raise MisconfigurationException("device (torch.device) should be provided.") + + for src_dtype, conversion_func in CONVERSION_DTYPES: + data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) + + def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device): + return t.to(device).contiguous() + + data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device)) + return data diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py new file mode 100644 index 00000000000000..46d88184ee1904 --- /dev/null +++ b/pytorch_lightning/utilities/argparse.py @@ -0,0 +1,298 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from argparse import _ArgumentGroup, ArgumentParser, Namespace +from contextlib import suppress +from typing import Any, Dict, List, Tuple, Union + +from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str + + +def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + """Create an instance from CLI arguments. + Eventually use varibles from OS environement which are defined as "PL__" + + Args: + cls: Lightning class + args: The parser or namespace to take arguments from. Only known arguments will be + parsed and passed to the :class:`Trainer`. + **kwargs: Additional keyword arguments that may override ones in the parser or namespace. + These must be valid Trainer arguments. + + Example: + >>> from pytorch_lightning import Trainer + >>> parser = ArgumentParser(add_help=False) + >>> parser = Trainer.add_argparse_args(parser) + >>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP + >>> args = Trainer.parse_argparser(parser.parse_args("")) + >>> trainer = Trainer.from_argparse_args(args, logger=False) + """ + if isinstance(args, ArgumentParser): + args = cls.parse_argparser(args) + + params = vars(args) + + # we only want to pass in valid Trainer args, the rest may be user specific + valid_kwargs = inspect.signature(cls.__init__).parameters + trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) + trainer_kwargs.update(**kwargs) + + return cls(**trainer_kwargs) + + +def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + """Parse CLI arguments, required for custom bool types.""" + args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser + + types_default = {arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)} + + modified_args = {} + for k, v in vars(args).items(): + if k in types_default and v is None: + # We need to figure out if the None is due to using nargs="?" or if it comes from the default value + arg_types, arg_default = types_default[k] + if bool in arg_types and isinstance(arg_default, bool): + # Value has been passed as a flag => It is currently None, so we need to set it to True + # We always set to True, regardless of the default value. + # Users must pass False directly, but when passing nothing True is assumed. + # i.e. the only way to disable something that defaults to True is to use the long form: + # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, + # which then becomes True here. + + v = True + + modified_args[k] = v + return Namespace(**modified_args) + + +def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: + """Parse environment arguments if they are defined. + + Example: + >>> from pytorch_lightning import Trainer + >>> parse_env_variables(Trainer) + Namespace() + >>> import os + >>> os.environ["PL_TRAINER_GPUS"] = '42' + >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23' + >>> parse_env_variables(Trainer) + Namespace(gpus=42) + >>> del os.environ["PL_TRAINER_GPUS"] + """ + cls_arg_defaults = get_init_arguments_and_types(cls) + + env_args = {} + for arg_name, _, _ in cls_arg_defaults: + env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()} + val = os.environ.get(env) + if not (val is None or val == ''): + # todo: specify the possible exception + with suppress(Exception): + # converting to native types like int/float/bool + val = eval(val) + env_args[arg_name] = val + return Namespace(**env_args) + + +def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: + r"""Scans the class signature and returns argument names, types and default values. + + Returns: + List with tuples of 3 values: + (argument name, set with argument types, argument default value). + + Examples: + + >>> from pytorch_lightning import Trainer + >>> args = get_init_arguments_and_types(Trainer) + + """ + cls_default_params = inspect.signature(cls).parameters + name_type_default = [] + for arg in cls_default_params: + arg_type = cls_default_params[arg].annotation + arg_default = cls_default_params[arg].default + try: + arg_types = tuple(arg_type.__args__) + except AttributeError: + arg_types = (arg_type, ) + + name_type_default.append((arg, arg_types, arg_default)) + + return name_type_default + + +def get_abbrev_qualified_cls_name(cls): + assert isinstance(cls, type), repr(cls) + if cls.__module__.startswith("pytorch_lightning."): + # Abbreviate. + return f"pl.{cls.__name__}" + else: + # Fully qualified. + return f"{cls.__module__}.{cls.__qualname__}" + + +def add_argparse_args( + cls, + parent_parser: ArgumentParser, + *, + use_argument_group=True, +) -> ArgumentParser: + r"""Extends existing argparse by default attributes for ``cls``. + + Args: + cls: Lightning class + parent_parser: + The custom cli arguments parser, which will be extended by + the class's default arguments. + use_argument_group: + By default, this is True, and uses ``add_argument_group`` to add + a new group. + If False, this will use old behavior. + + Returns: + If use_argument_group is True, returns ``parent_parser`` to keep old + workflows. If False, will return the new ArgumentParser object. + + Only arguments of the allowed types (str, float, int, bool) will + extend the ``parent_parser``. + + Examples: + + # Option 1: Default usage. + >>> import argparse + >>> from pytorch_lightning import Trainer + >>> parser = argparse.ArgumentParser() + >>> parser = Trainer.add_argparse_args(parser) + >>> args = parser.parse_args([]) + + # Option 2: Disable use_argument_group (old behavior). + >>> import argparse + >>> from pytorch_lightning import Trainer + >>> parser = argparse.ArgumentParser() + >>> parser = Trainer.add_argparse_args(parser, use_argument_group=False) + >>> args = parser.parse_args([]) + """ + if isinstance(parent_parser, _ArgumentGroup): + raise RuntimeError("Please only pass an ArgumentParser instance.") + if use_argument_group: + group_name = get_abbrev_qualified_cls_name(cls) + parser = parent_parser.add_argument_group(group_name) + else: + parser = ArgumentParser( + parents=[parent_parser], + add_help=False, + ) + + ignore_arg_names = ['self', 'args', 'kwargs'] + if hasattr(cls, "get_deprecated_arg_names"): + ignore_arg_names += cls.get_deprecated_arg_names() + + allowed_types = (str, int, float, bool) + + # Get symbols from cls or init function. + for symbol in (cls, cls.__init__): + args_and_types = get_init_arguments_and_types(symbol) + args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names] + if len(args_and_types) > 0: + break + + args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "") + + for arg, arg_types, arg_default in args_and_types: + arg_types = [at for at in allowed_types if at in arg_types] + if not arg_types: + # skip argument with not supported type + continue + arg_kwargs = {} + if bool in arg_types: + arg_kwargs.update(nargs="?", const=True) + # if the only arg type is bool + if len(arg_types) == 1: + use_type = str_to_bool + elif str in arg_types: + use_type = str_to_bool_or_str + else: + # filter out the bool as we need to use more general + use_type = [at for at in arg_types if at is not bool][0] + else: + use_type = arg_types[0] + + if arg == 'gpus' or arg == 'tpu_cores': + use_type = _gpus_allowed_type + arg_default = _gpus_arg_default + + # hack for types in (int, float) + if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types): + use_type = _int_or_float_type + + # hack for track_grad_norm + if arg == 'track_grad_norm': + use_type = float + + parser.add_argument( + f'--{arg}', + dest=arg, + default=arg_default, + type=use_type, + help=args_help.get(arg), + **arg_kwargs, + ) + + if use_argument_group: + return parent_parser + else: + return parser + + +def parse_args_from_docstring(docstring: str) -> Dict[str, str]: + arg_block_indent = None + current_arg = None + parsed = {} + for line in docstring.split("\n"): + stripped = line.lstrip() + if not stripped: + continue + line_indent = len(line) - len(stripped) + if stripped.startswith(('Args:', 'Arguments:', 'Parameters:')): + arg_block_indent = line_indent + 4 + elif arg_block_indent is None: + continue + elif line_indent < arg_block_indent: + break + elif line_indent == arg_block_indent: + current_arg, arg_description = stripped.split(':', maxsplit=1) + parsed[current_arg] = arg_description.lstrip() + elif line_indent > arg_block_indent: + parsed[current_arg] += f' {stripped}' + return parsed + + +def _gpus_allowed_type(x) -> Union[int, str]: + if ',' in x: + return str(x) + else: + return int(x) + + +def _gpus_arg_default(x) -> Union[int, str]: + return _gpus_allowed_type(x) + + +def _int_or_float_type(x) -> Union[int, float]: + if '.' in str(x): + return float(x) + else: + return int(x) diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py new file mode 100644 index 00000000000000..80db2429f7d2af --- /dev/null +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -0,0 +1,5 @@ +from pytorch_lightning.utilities import rank_zero_deprecation + +rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4") + +from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py new file mode 100644 index 00000000000000..e94934020107d4 --- /dev/null +++ b/pytorch_lightning/utilities/cloud_io.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from distutils.version import LooseVersion +from pathlib import Path +from typing import IO, Union + +import fsspec +import torch + + +def load(path_or_url: Union[str, IO, Path], map_location=None): + if not isinstance(path_or_url, (str, Path)): + # any sort of BytesIO or similiar + return torch.load(path_or_url, map_location=map_location) + if str(path_or_url).startswith("http"): + return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location) + fs = get_filesystem(path_or_url) + with fs.open(path_or_url, "rb") as f: + return torch.load(f, map_location=map_location) + + +def get_filesystem(path: Union[str, Path]): + path = str(path) + if "://" in path: + # use the fileystem from the protocol specified + return fsspec.filesystem(path.split(":", 1)[0]) + else: + # use local filesystem + return fsspec.filesystem("file") + + +def atomic_save(checkpoint, filepath: str): + """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. + + Args: + checkpoint: The object to save. + Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` + accepts. + filepath: The path to which the checkpoint will be saved. + This points to the file that the checkpoint will be stored in. + """ + + bytesbuffer = io.BytesIO() + # Can't use the new zipfile serialization for 1.6.0 because there's a bug in + # torch.hub.load_state_dict_from_url() that prevents it from loading the new files. + # More details can be found here: https://github.com/pytorch/pytorch/issues/42239 + if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]: + torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False) + else: + torch.save(checkpoint, bytesbuffer) + with fsspec.open(filepath, "wb") as f: + f.write(bytesbuffer.getvalue()) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py new file mode 100644 index 00000000000000..a73299e2af77b9 --- /dev/null +++ b/pytorch_lightning/utilities/data.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from torch.utils.data import DataLoader, IterableDataset + +from pytorch_lightning.utilities import rank_zero_warn + + +def has_iterable_dataset(dataloader: DataLoader): + return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset) + + +def has_len(dataloader: DataLoader) -> bool: + """ Checks if a given Dataloader has __len__ method implemented i.e. if + it is a finite dataloader or infinite dataloader. """ + + try: + # try getting the length + if len(dataloader) == 0: + raise ValueError('`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch') + has_len = True + except TypeError: + has_len = False + except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used + has_len = False + + if has_len and has_iterable_dataset(dataloader): + rank_zero_warn( + 'Your `IterableDataset` has `__len__` defined.' + ' In combination with multi-processing data loading (e.g. batch size > 1),' + ' this can lead to unintended side effects since the samples will be duplicated.' + ) + return has_len + + +def get_len(dataloader: DataLoader) -> Union[int, float]: + """ Return the length of the given DataLoader. If ``__len__`` method is not implemented, return float('inf'). """ + + if has_len(dataloader): + return len(dataloader) + + return float('inf') diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py new file mode 100644 index 00000000000000..56833fd03735ad --- /dev/null +++ b/pytorch_lightning/utilities/debugging.py @@ -0,0 +1,202 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +from collections import Counter +from functools import wraps +from typing import Any, Callable, Optional + + +def enabled_only(fn: Callable): + """Decorate a logger method to run it only on the process with rank 0. + + Args: + fn: Function to decorate + """ + + @wraps(fn) + def wrapped_fn(self, *args, **kwargs): + if self.enabled: + fn(self, *args, **kwargs) + + return wrapped_fn + + +class InternalDebugger(object): + + def __init__(self, trainer): + self.enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1' + self.trainer = trainer + self.logged_metrics = [] + self.pbar_added_metrics = [] + self.saved_train_losses = [] + self.saved_val_losses = [] + self.saved_test_losses = [] + self.early_stopping_history = [] + self.checkpoint_callback_history = [] + self.events = [] + self.saved_lr_scheduler_updates = [] + self.train_dataloader_calls = [] + self.val_dataloader_calls = [] + self.test_dataloader_calls = [] + self.dataloader_sequence_calls = [] + + def track_event( + self, + evt_type: str, + evt_value: Any = None, + global_rank: Optional[int] = None, + local_rank: Optional[int] = None, + comment: str = '' + ) -> None: + self.events.append({ + "timestamp": time.time(), + "event": evt_type, + "value": evt_value, + "global_rank": global_rank, + "local_rank": local_rank, + "comment": comment, + }) + + def count_events(self, evt_type: str, strict=False) -> int: + count = 0 + for evt in self.events: + if strict and evt["event"] == evt_type: + count += 1 + elif not strict and evt_type in evt["event"]: + count += 1 + return count + + @enabled_only + def track_load_dataloader_call(self, name, dataloaders): + loader_counts = len(dataloaders) + + lengths = [] + for dl in dataloaders: + try: + length = len(dl) + # todo: specify the possible exception + except Exception: + length = -1 + lengths.append(length) + + values = { + 'global_step': self.trainer.global_step, + 'epoch': self.trainer.current_epoch, + 'num_loaders': loader_counts, + 'lengths': lengths, + 'name': name + } + + # track the sequence in case we need to verify the sequence + self.dataloader_sequence_calls.append(values) + + if 'train' in name: + self.train_dataloader_calls.append(values) + elif 'val' in name: + self.val_dataloader_calls.append(values) + elif 'test' in name: + self.test_dataloader_calls.append(values) + + @enabled_only + def track_logged_metrics_history(self, scalar_metrics): + scalar_metrics['global_step'] = self.trainer.global_step + self.logged_metrics.append(scalar_metrics) + + @enabled_only + def track_train_loss_history(self, batch_idx, loss): + loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()} + self.saved_train_losses.append(loss_dict) + + @enabled_only + def track_lr_schedulers_update( + self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None + ): + loss_dict = { + 'batch_idx': batch_idx, + 'interval': interval, + 'scheduler_idx': scheduler_idx, + 'epoch': self.trainer.current_epoch, + 'monitor_key': monitor_key, + 'monitor_val': monitor_val, + 'old_lr': old_lr, + 'new_lr': new_lr + } + self.saved_lr_scheduler_updates.append(loss_dict) + + @enabled_only + def track_eval_loss_history(self, batch_idx, dataloader_idx, output): + loss_dict = { + 'sanity_check': self.trainer.sanity_checking, + 'dataloader_idx': dataloader_idx, + 'batch_idx': batch_idx, + 'epoch': self.trainer.current_epoch, + 'output': output + } + + if self.trainer.testing: + self.saved_test_losses.append(loss_dict) + else: + self.saved_val_losses.append(loss_dict) + + @enabled_only + def track_pbar_metrics_history(self, metrics): + metrics['debug_epoch'] = self.trainer.current_epoch + self.pbar_added_metrics.append(metrics) + + @enabled_only + def track_early_stopping_history(self, callback, current): + debug_dict = { + 'epoch': self.trainer.current_epoch, + 'global_step': self.trainer.global_step, + 'rank': self.trainer.global_rank, + 'current': current, + 'best': callback.best_score, + 'patience': callback.wait_count + } + self.early_stopping_history.append(debug_dict) + + @enabled_only + def track_checkpointing_history(self, filepath): + cb = self.trainer.checkpoint_callback + debug_dict = { + 'epoch': self.trainer.current_epoch, + 'global_step': self.trainer.global_step, + 'monitor': cb.monitor, + 'rank': self.trainer.global_rank, + 'filepath': filepath + } + self.checkpoint_callback_history.append(debug_dict) + + @property + def num_seen_sanity_check_batches(self): + count = len([x for x in self.saved_val_losses if x['sanity_check']]) + return count + + @property + def num_seen_val_check_batches(self): + counts = Counter() + for x in self.saved_val_losses: + if not x['sanity_check']: + counts.update({x['dataloader_idx']: 1}) + return counts + + @property + def num_seen_test_check_batches(self): + counts = Counter() + for x in self.saved_test_losses: + if not x['sanity_check']: + counts.update({x['dataloader_idx']: 1}) + return counts diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py new file mode 100644 index 00000000000000..3e3eccc93b368c --- /dev/null +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -0,0 +1,197 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +from torch.nn import Module + +from pytorch_lightning.core.decorators import parameter_validation + + +class DeviceDtypeModuleMixin(Module): + __jit_unused_properties__ = ['device', 'dtype'] + + def __init__(self): + super().__init__() + self._dtype = torch.get_default_dtype() + self._device = torch.device('cpu') + + @property + def dtype(self) -> Union[str, torch.dtype]: + return self._dtype + + @dtype.setter + def dtype(self, new_dtype: Union[str, torch.dtype]): + # necessary to avoid infinite recursion + raise RuntimeError('Cannot set the dtype explicitly. Please use module.to(new_dtype).') + + @property + def device(self) -> Union[str, torch.device]: + device = self._device + + # make this more explicit to always include the index + if device.type == 'cuda' and device.index is None: + return torch.device(f'cuda:{torch.cuda.current_device()}') + + return device + + @device.setter + def device(self, new_device: Union[str, torch.device]): + # Necessary to avoid infinite recursion + raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).') + + @parameter_validation + def to(self, *args, **kwargs) -> Module: + """Moves and/or casts the parameters and buffers. + + This can be called as + .. function:: to(device=None, dtype=None, non_blocking=False) + .. function:: to(dtype, non_blocking=False) + .. function:: to(tensor, non_blocking=False) + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point desired :attr:`dtype` s. In addition, this method will + only cast the floating point parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + See below for examples. + + Note: + This method modifies the module in-place. + + Args: + device: the desired device of the parameters + and buffers in this module + dtype: the desired floating point type of + the floating point parameters and buffers in this module + tensor: Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + + Returns: + Module: self + + Example:: + >>> class ExampleModule(DeviceDtypeModuleMixin): + ... def __init__(self, weight: torch.Tensor): + ... super().__init__() + ... self.register_buffer('weight', weight) + ... + ... def on_post_move_to_device(self): + ... pass + >>> _ = torch.manual_seed(0) + >>> module = ExampleModule(torch.rand(3, 4)) + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]]) + >>> module.to(torch.double) + ExampleModule() + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float64) + >>> cpu = torch.device('cpu') + >>> module.to(cpu, dtype=torch.half, non_blocking=True) + ExampleModule() + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + >>> module.to(cpu) + ExampleModule() + >>> module.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + >>> module.device + device(type='cpu') + >>> module.dtype + torch.float16 + """ + # there is diff nb vars in PT 1.5 + out = torch._C._nn._parse_to(*args, **kwargs) + self.__update_properties(device=out[0], dtype=out[1]) + return super().to(*args, **kwargs) + + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module: + """Moves all model parameters and buffers to the GPU. + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + Arguments: + device: if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device) + self.__update_properties(device=property_device) + return super().cuda(device=device) + + def cpu(self) -> Module: + """Moves all model parameters and buffers to the CPU. + + Returns: + Module: self + """ + self.__update_properties(device=torch.device('cpu')) + return super().cpu() + + def type(self, dst_type: Union[str, torch.dtype]) -> Module: + """Casts all parameters and buffers to :attr:`dst_type`. + + Arguments: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + self.__update_properties(dtype=dst_type) + return super().type(dst_type=dst_type) + + def float(self) -> Module: + """Casts all floating point parameters and buffers to float datatype. + + Returns: + Module: self + """ + self.__update_properties(dtype=torch.float) + return super().float() + + def double(self) -> Module: + """Casts all floating point parameters and buffers to ``double`` datatype. + + Returns: + Module: self + """ + self.__update_properties(dtype=torch.double) + return super().double() + + def half(self) -> Module: + """Casts all floating point parameters and buffers to ``half`` datatype. + + Returns: + Module: self + """ + self.__update_properties(dtype=torch.half) + return super().half() + + def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): + + def apply_fn(module): + if not isinstance(module, DeviceDtypeModuleMixin): + return + if device is not None: + module._device = device + if dtype is not None: + module._dtype = dtype + + self.apply(apply_fn) diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py new file mode 100644 index 00000000000000..f20b978ebd8b60 --- /dev/null +++ b/pytorch_lightning/utilities/device_parser.py @@ -0,0 +1,204 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, MutableSequence, Optional, Tuple, Union + +import torch + +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def determine_root_gpu_device(gpus: List[int]) -> Optional[int]: + """ + Args: + gpus: non-empty list of ints representing which gpus to use + + Returns: + designated root GPU device id + """ + if gpus is None: + return None + + if not isinstance(gpus, list): + raise TypeError("gpus should be a list") + + assert len(gpus) > 0, "gpus should be a non empty list" + + # set root gpu + root_gpu = gpus[0] + + return root_gpu + + +def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]: + """ + Parses the GPU ids given in the format as accepted by the + :class:`~pytorch_lightning.trainer.Trainer`. + + Args: + gpus: An int -1 or string '-1' indicate that all available GPUs should be used. + A list of ints or a string containing list of comma separated integers + indicates specific GPUs to use. + An int 0 means that no GPUs should be used. + Any int N > 0 indicates that GPUs [0..N) should be used. + + Returns: + a list of gpus to be used or ``None`` if no GPUs were requested + + If no GPUs are available but the value of gpus variable indicates request for GPUs + then a MisconfigurationException is raised. + """ + + # nothing was passed into the GPUs argument + if callable(gpus): + return None + + # Check that gpus param is None, Int, String or List + _check_data_type(gpus) + + # Handle the case when no gpus are requested + if gpus is None or isinstance(gpus, int) and gpus == 0: + return None + + # We know user requested GPUs therefore if some of the + # requested GPUs are not available an exception is thrown. + + gpus = _normalize_parse_gpu_string_input(gpus) + gpus = _normalize_parse_gpu_input_to_list(gpus) + if not gpus: + raise MisconfigurationException("GPUs requested but none are available.") + gpus = _sanitize_gpu_ids(gpus) + + return gpus + + +def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int], int]]: + """ + Parses the tpu_cores given in the format as accepted by the + :class:`~pytorch_lightning.trainer.Trainer`. + + Args: + tpu_cores: An int 1 or string '1' indicate that 1 core with multi-processing should be used + An int 8 or string '8' indicate that all 8 cores with multi-processing should be used + A list of int or a string containing list of comma separated integer + indicates specific TPU core to use. + + Returns: + a list of tpu_cores to be used or ``None`` if no TPU cores were requested + """ + + if callable(tpu_cores): + return None + + _check_data_type(tpu_cores) + + if isinstance(tpu_cores, str): + tpu_cores = _parse_tpu_cores_str(tpu_cores.strip()) + + if not _tpu_cores_valid(tpu_cores): + raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]") + + if tpu_cores is not None and not _TPU_AVAILABLE: + raise MisconfigurationException('No TPU devices were found.') + + return tpu_cores + + +def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: + if isinstance(s, str): + if s == '-1': + return -1 + else: + return [int(x.strip()) for x in s.split(',') if len(x) > 0] + else: + return s + + +def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: + """ + Checks that each of the GPUs in the list is actually available. + Raises a MisconfigurationException if any of the GPUs is not available. + + Args: + gpus: list of ints corresponding to GPU indices + + Returns: + unmodified gpus variable + """ + all_available_gpus = _get_all_available_gpus() + for gpu in gpus: + if gpu not in all_available_gpus: + raise MisconfigurationException( + f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}" + ) + return gpus + + +def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int], Tuple[int, ...]]) -> Optional[List[int]]: + assert gpus is not None + if isinstance(gpus, (MutableSequence, tuple)): + return list(gpus) + + # must be an int + if not gpus: # gpus==0 + return None + if gpus == -1: + return _get_all_available_gpus() + + return list(range(gpus)) + + +def _get_all_available_gpus() -> List[int]: + """ + Returns: + a list of all available gpus + """ + return list(range(torch.cuda.device_count())) + + +def _check_data_type(device_ids: Any) -> None: + """ + Checks that the device_ids argument is one of: None, Int, String or List. + Raises a MisconfigurationException otherwise. + + Args: + device_ids: gpus/tpu_cores parameter as passed to the Trainer + """ + if device_ids is not None and \ + (not isinstance(device_ids, (int, str, MutableSequence, tuple)) or isinstance(device_ids, bool)): + raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.") + + +def _tpu_cores_valid(tpu_cores): + # allow 1 or 8 cores + if tpu_cores in (1, 8, None): + return True + + # allow picking 1 of 8 indexes + if isinstance(tpu_cores, (list, tuple, set)): + has_1_tpu_idx = len(tpu_cores) == 1 + is_valid_tpu_idx = tpu_cores[0] in range(1, 9) + + is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx + return is_valid_tpu_core_choice + + return False + + +def _parse_tpu_cores_str(tpu_cores): + if tpu_cores in ('1', '8'): + tpu_cores = int(tpu_cores) + else: + tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0] + return tpu_cores diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index f4c942b9beaaac..bf7a199fc08dcf 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -1,5 +1,37 @@ -from functools import wraps +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os import warnings +from functools import partial, wraps +from typing import Any, Optional, Union + +import torch + +log = logging.getLogger(__name__) + +if torch.distributed.is_available(): + from torch.distributed import group, ReduceOp + +else: + + class ReduceOp: + SUM = None + + class group: + WORLD = None def rank_zero_only(fn): @@ -12,15 +44,157 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn -try: - # add the attribute to the function but don't overwrite in case Trainer has already set it - getattr(rank_zero_only, 'rank') -except AttributeError: - rank_zero_only.rank = 0 +# add the attribute to the function but don't overwrite in case Trainer has already set it +rank_zero_only.rank = getattr(rank_zero_only, 'rank', int(os.environ.get('LOCAL_RANK', 0))) def _warn(*args, **kwargs): warnings.warn(*args, **kwargs) +def _info(*args, **kwargs): + log.info(*args, **kwargs) + + +def _debug(*args, **kwargs): + log.debug(*args, **kwargs) + + +rank_zero_debug = rank_zero_only(_debug) +rank_zero_info = rank_zero_only(_info) rank_zero_warn = rank_zero_only(_warn) +rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) + + +def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): + """ + Function to gather all tensors from several ddp processes onto a list that + is broadcasted to all processes + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + gathered_result: list with size equal to the process group where + gathered_result[i] corresponds to result tensor from process i + """ + if group is None: + group = torch.distributed.group.WORLD + + # convert tensors to contiguous format + result = result.contiguous() + + world_size = torch.distributed.get_world_size(group) + + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + + # sync and broadcast all + torch.distributed.barrier(group=group) + torch.distributed.all_gather(gathered_result, result, group) + + return gathered_result + + +def sync_ddp_if_available( + result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None +) -> torch.Tensor: + """ + Function to reduce a tensor across worker processes during distributed training + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + + Return: + reduced value + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + +def sync_ddp( + result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None +) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + + Return: + reduced value + """ + divide_by_world_size = False + + if group is None: + group = torch.distributed.group.WORLD + + op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + divide_by_world_size = True + + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=op, group=group, async_op=False) + + if divide_by_world_size: + result = result / torch.distributed.get_world_size(group) + + return result + + +class AllGatherGrad(torch.autograd.Function): + + @staticmethod + def forward(ctx, tensor, group=group.WORLD): + ctx.group = group + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor, group=group) + gathered_tensor = torch.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx, *grad_output): + grad_output = torch.cat(grad_output) + + torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) + + return grad_output[torch.distributed.get_rank()], None + + +def all_gather_ddp_if_available( + tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False +) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + group = group if group is not None else torch.distributed.group.WORLD + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if sync_grads: + return AllGatherGrad.apply(tensor, group) + else: + with torch.no_grad(): + return AllGatherGrad.apply(tensor, group) + return tensor diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py new file mode 100644 index 00000000000000..169481fa63e67b --- /dev/null +++ b/pytorch_lightning/utilities/enums.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Enumerated utilities""" +from enum import Enum +from typing import List, Optional, Union + + +class LightningEnum(str, Enum): + """ Type of any enumerator with allowed comparison to string invariant to cases. """ + + @classmethod + def from_str(cls, value: str) -> Optional['LightningEnum']: + statuses = [status for status in dir(cls) if not status.startswith('_')] + for st in statuses: + if st.lower() == value.lower(): + return getattr(cls, st) + return None + + def __eq__(self, other: Union[str, Enum]) -> bool: + other = other.value if isinstance(other, Enum) else str(other) + return self.value.lower() == other.lower() + + def __hash__(self) -> int: + # re-enable hashtable so it can be used as a dict key or in a set + # example: set(LightningEnum) + return hash(self.name) + + +class AMPType(LightningEnum): + """Type of Automatic Mixed Precission used for training. + + >>> # you can math the type with string + >>> AMPType.APEX == 'apex' + True + """ + APEX = 'apex' + NATIVE = 'native' + + +class DistributedType(LightningEnum): + """ Define type of ditributed computing. + + >>> # you can math the type with string + >>> DistributedType.DDP == 'ddp' + True + >>> # which is case invariant + >>> DistributedType.DDP2 in ('ddp2', ) + True + """ + + @staticmethod + def interactive_compatible_types() -> List['DistributedType']: + """Returns a list containing interactive compatible DistributeTypes""" + return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN] + + def is_interactive_compatible(self) -> bool: + """Returns whether self is interactive compatible""" + return self in DistributedType.interactive_compatible_types() + + DP = 'dp' + DDP = 'ddp' + DDP2 = 'ddp2' + DDP_SPAWN = 'ddp_spawn' + DEEPSPEED = 'deepspeed' + HOROVOD = 'horovod' + DDP_SHARDED = 'ddp_sharded' + DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' + RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' + + +class DeviceType(LightningEnum): + """ Define Device type byt its nature - acceleatrors. + + >>> DeviceType.CPU == DeviceType.from_str('cpu') + True + >>> # you can math the type with string + >>> DeviceType.GPU == 'GPU' + True + >>> # which is case invariant + >>> DeviceType.TPU in ('tpu', 'CPU') + True + """ + CPU = 'CPU' + GPU = 'GPU' + TPU = 'TPU' diff --git a/pytorch_lightning/utilities/exceptions.py b/pytorch_lightning/utilities/exceptions.py index b7f92bca2679a5..01b1e8c053950c 100644 --- a/pytorch_lightning/utilities/exceptions.py +++ b/pytorch_lightning/utilities/exceptions.py @@ -1,2 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + class MisconfigurationException(Exception): pass diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py new file mode 100644 index 00000000000000..baeac9be572184 --- /dev/null +++ b/pytorch_lightning/utilities/imports.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""General utilities""" +import importlib +import operator +import platform +import sys +from distutils.version import LooseVersion +from importlib.util import find_spec + +import torch +from pkg_resources import DistributionNotFound + + +def _module_available(module_path: str) -> bool: + """ + Check if a path is available in your environment + + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + try: + return find_spec(module_path) is not None + except AttributeError: + # Python 3.6 + return False + except ModuleNotFoundError: + # Python 3.7+ + return False + + +def _compare_version(package: str, op, version) -> bool: + """ + Compare package version with some requirements + + >>> _compare_version("torch", operator.ge, "0.1") + True + """ + try: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: + return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) + + +_IS_WINDOWS = platform.system() == "Windows" +_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765 +_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0") +_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") +_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") +_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") + +_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False +_APEX_AVAILABLE = _module_available("apex.amp") +_BOLTS_AVAILABLE = _module_available('pl_bolts') +_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') +_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") +_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') +_HOROVOD_AVAILABLE = _module_available("horovod.torch") +_HYDRA_AVAILABLE = _module_available("hydra") +_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") +_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") +_OMEGACONF_AVAILABLE = _module_available("omegaconf") +_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc') +_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none']) +_TORCHTEXT_AVAILABLE = _module_available("torchtext") +_TORCHVISION_AVAILABLE = _module_available('torchvision') +_XLA_AVAILABLE = _module_available("torch_xla") diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 148d28ca143ff6..d67739c3b3fc2d 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -1,4 +1,23 @@ -def recursive_detach(in_dict: dict) -> dict: +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import torch + + +def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries @@ -6,17 +25,55 @@ def recursive_detach(in_dict: dict) -> dict: not affected by this utility function. Args: - in_dict: + in_dict: Dictionary with tensors to detach + to_cpu: Whether to move tensor to cpu Return: - out_dict: + out_dict: Dictionary with detached tensors """ out_dict = {} for k, v in in_dict.items(): if isinstance(v, dict): - out_dict.update({k: recursive_detach(v)}) + v = recursive_detach(v, to_cpu=to_cpu) elif callable(getattr(v, 'detach', None)): - out_dict.update({k: v.detach()}) - else: - out_dict.update({k: v}) + v = v.detach() + if to_cpu: + v = v.cpu() + out_dict[k] = v return out_dict + + +def is_oom_error(exception): + return is_cuda_out_of_memory(exception) \ + or is_cudnn_snafu(exception) \ + or is_out_of_cpu_memory(exception) + + +# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py +def is_cuda_out_of_memory(exception): + return isinstance(exception, RuntimeError) \ + and len(exception.args) == 1 \ + and "CUDA out of memory." in exception.args[0] + + +# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py +def is_cudnn_snafu(exception): + # For/because of https://github.com/pytorch/pytorch/issues/4107 + return isinstance(exception, RuntimeError) \ + and len(exception.args) == 1 \ + and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0] + + +# based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py +def is_out_of_cpu_memory(exception): + return isinstance(exception, RuntimeError) \ + and len(exception.args) == 1 \ + and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] + + +# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py +def garbage_collection_cuda(): + """Garbage collection Torch (CUDA) memory.""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py new file mode 100644 index 00000000000000..87bd9e6c4545d7 --- /dev/null +++ b/pytorch_lightning/utilities/model_helpers.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.lightning import LightningModule + + +def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool: + # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super + # TODO - refector this function to accept model_name, instance, parent so it makes more sense + super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule + + if not hasattr(model, method_name) or not hasattr(super_object, method_name): + # in case of calling deprecated method + return False + + instance_attr = getattr(model, method_name) + if not instance_attr: + return False + super_attr = getattr(super_object, method_name) + + # when code pointers are different, it was implemented + if hasattr(instance_attr, 'patch_loader_code'): + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overridden = instance_attr.__code__ is not super_attr.__code__ + return is_overridden diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py new file mode 100644 index 00000000000000..728f73f4f0d327 --- /dev/null +++ b/pytorch_lightning/utilities/model_utils.py @@ -0,0 +1,7 @@ +from pytorch_lightning.utilities import rank_zero_deprecation + +rank_zero_deprecation( + "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4" +) + +from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 2549e485d55ef1..dd12f34cfe9269 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -1,49 +1,289 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import pickle from argparse import Namespace +from typing import Dict, Tuple, Union + +from pytorch_lightning.utilities import rank_zero_warn + + +def str_to_bool_or_str(val: str) -> Union[str, bool]: + """Possibly convert a string representation of truth to bool. + Returns the input otherwise. + Based on the python implementation distutils.utils.strtobool + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. + """ + lower = val.lower() + if lower in ('y', 'yes', 't', 'true', 'on', '1'): + return True + elif lower in ('n', 'no', 'f', 'false', 'off', '0'): + return False + else: + return val -def strtobool(val): - """Convert a string representation of truth to true (1) or false (0). - Copied from the python implementation distutils.utils.strtobool +def str_to_bool(val: str) -> bool: + """Convert a string representation of truth to bool. True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if 'val' is anything else. - >>> strtobool('YES') - 1 - >>> strtobool('FALSE') - 0 - """ - val = val.lower() - if val in ('y', 'yes', 't', 'true', 'on', '1'): - return 1 - elif val in ('n', 'no', 'f', 'false', 'off', '0'): - return 0 - else: - raise ValueError(f'invalid truth value {val}') + >>> str_to_bool('YES') + True + >>> str_to_bool('FALSE') + False + """ + val = str_to_bool_or_str(val) + if isinstance(val, bool): + return val + raise ValueError(f'invalid truth value {val}') + + +def is_picklable(obj: object) -> bool: + """Tests if an object can be pickled""" + + try: + pickle.dumps(obj) + return True + except (pickle.PicklingError, AttributeError): + return False def clean_namespace(hparams): + """Removes all unpicklable entries from hparams""" + + hparams_dict = hparams + if isinstance(hparams, Namespace): + hparams_dict = hparams.__dict__ + + del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)] + + for k in del_attrs: + rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning) + del hparams_dict[k] + + +def parse_class_init_keys(cls) -> Tuple[str, str, str]: + """Parse key words for standard self, *args and **kwargs + + >>> class Model(): + ... def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): + ... pass + >>> parse_class_init_keys(Model) + ('self', 'my_args', 'my_kwargs') + """ + init_parameters = inspect.signature(cls.__init__).parameters + # docs claims the params are always ordered + # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + init_params = list(init_parameters.values()) + # self is always first + n_self = init_params[0].name + + def _get_first_if_any(params, param_type): + for p in params: + if p.kind == param_type: + return p.name + + n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL) + n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD) + + return n_self, n_args, n_kwargs + + +def get_init_args(frame) -> dict: + _, _, _, local_vars = inspect.getargvalues(frame) + if '__class__' not in local_vars: + return {} + cls = local_vars['__class__'] + init_parameters = inspect.signature(cls.__init__).parameters + self_var, args_var, kwargs_var = parse_class_init_keys(cls) + filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n] + exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') + # only collect variables that appear in the signature + local_args = {k: local_vars[k] for k in init_parameters.keys()} + local_args.update(local_args.get(kwargs_var, {})) + local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} + return local_args + + +def collect_init_args(frame, path_args: list, inside: bool = False) -> list: """ - Removes all functions from hparams so we can pickle - :param hparams: - :return: + Recursively collects the arguments passed to the child constructors in the inheritance tree. + + Args: + frame: the current stack frame + path_args: a list of dictionaries containing the constructor args in all parent classes + inside: track if we are inside inheritance path, avoid terminating too soon + + Return: + A list of dictionaries where each dictionary contains the arguments passed to the + constructor at that level. The last entry corresponds to the constructor call of the + most specific class in the hierarchy. """ + _, _, _, local_vars = inspect.getargvalues(frame) + if '__class__' in local_vars: + local_args = get_init_args(frame) + # recursive update + path_args.append(local_args) + return collect_init_args(frame.f_back, path_args, inside=True) + elif not inside: + return collect_init_args(frame.f_back, path_args, inside) + else: + return path_args - if isinstance(hparams, Namespace): - del_attrs = [] - for k in hparams.__dict__: - if callable(getattr(hparams, k)): - del_attrs.append(k) - - for k in del_attrs: - delattr(hparams, k) - - elif isinstance(hparams, dict): - del_attrs = [] - for k, v in hparams.items(): - if callable(v): - del_attrs.append(k) - - for k in del_attrs: - del hparams[k] + +def flatten_dict(source, result=None): + if result is None: + result = {} + + for k, v in source.items(): + if isinstance(v, dict): + _ = flatten_dict(v, result) + else: + result[k] = v + + return result + + +class AttributeDict(Dict): + """Extended dictionary accesisable with dot notation. + + >>> ad = AttributeDict({'key1': 1, 'key2': 'abc'}) + >>> ad.key1 + 1 + >>> ad.update({'my-key': 3.14}) + >>> ad.update(mew_key=42) + >>> ad.key1 = 2 + >>> ad + "key1": 2 + "key2": abc + "mew_key": 42 + "my-key": 3.14 + """ + + def __getattr__(self, key): + try: + return self[key] + except KeyError as exp: + raise AttributeError(f'Missing attribute "{key}"') from exp + + def __setattr__(self, key, val): + self[key] = val + + def __repr__(self): + if not len(self): + return "" + max_key_length = max([len(str(k)) for k in self]) + tmp_name = '{:' + str(max_key_length + 3) + 's} {}' + rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())] + out = '\n'.join(rows) + return out + + +def _lightning_get_all_attr_holders(model, attribute): + """ + Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. + Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. + """ + trainer = getattr(model, 'trainer', None) + + holders = [] + + # Check if attribute in model + if hasattr(model, attribute): + holders.append(model) + + # Check if attribute in model.hparams, either namespace or dict + if hasattr(model, 'hparams'): + if attribute in model.hparams: + holders.append(model.hparams) + + # Check if the attribute in datamodule (datamodule gets registered in Trainer) + if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): + holders.append(trainer.datamodule) + + return holders + + +def _lightning_get_first_attr_holder(model, attribute): + """ + Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. + Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, + returns the last one that has it. + """ + holders = _lightning_get_all_attr_holders(model, attribute) + if len(holders) == 0: + return None + # using the last holder to preserve backwards compatibility + return holders[-1] + + +def lightning_hasattr(model, attribute): + """ + Special hasattr for Lightning. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule. + """ + return _lightning_get_first_attr_holder(model, attribute) is not None + + +def lightning_getattr(model, attribute): + """ + Special getattr for Lightning. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule. + + Raises: + AttributeError: + If ``model`` doesn't have ``attribute`` in any of + model namespace, the hparams namespace/dict, and the datamodule. + """ + holder = _lightning_get_first_attr_holder(model, attribute) + if holder is None: + raise AttributeError( + f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.' + ) + + if isinstance(holder, dict): + return holder[attribute] + return getattr(holder, attribute) + + +def lightning_setattr(model, attribute, value): + """ + Special setattr for Lightning. Checks for attribute in model namespace + and the old hparams namespace/dict. + Will also set the attribute on datamodule, if it exists. + + Raises: + AttributeError: + If ``model`` doesn't have ``attribute`` in any of + model namespace, the hparams namespace/dict, and the datamodule. + """ + holders = _lightning_get_all_attr_holders(model, attribute) + if len(holders) == 0: + raise AttributeError( + f'{attribute} is neither stored in the model namespace' + ' nor the `hparams` namespace/dict, nor the datamodule.' + ) + + for holder in holders: + if isinstance(holder, dict): + holder[attribute] = value + else: + setattr(holder, attribute, value) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py new file mode 100644 index 00000000000000..8129075f99f4da --- /dev/null +++ b/pytorch_lightning/utilities/seed.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions to help with reproducibility of models. """ + +import logging +import os +import random +from typing import Optional + +import numpy as np +import torch + +from pytorch_lightning.utilities import rank_zero_warn + +log = logging.getLogger(__name__) + + +def seed_everything(seed: Optional[int] = None) -> int: + """ + Function that sets seed for pseudo-random number generators in: + pytorch, numpy, python.random + In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to + spawned subprocesses (e.g. ddp_spawn backend). + + Args: + seed: the integer value seed for global random state in Lightning. + If `None`, will read seed from `PL_GLOBAL_SEED` env variable + or select it randomly. + """ + max_seed_value = np.iinfo(np.uint32).max + min_seed_value = np.iinfo(np.uint32).min + + try: + if seed is None: + seed = os.environ.get("PL_GLOBAL_SEED") + seed = int(seed) + except (TypeError, ValueError): + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No correct seed found, seed set to {seed}") + + if not (min_seed_value <= seed <= max_seed_value): + rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") + seed = _select_seed_randomly(min_seed_value, max_seed_value) + + log.info(f"Global seed set to {seed}") + os.environ["PL_GLOBAL_SEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + return seed + + +def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: + return random.randint(min_seed_value, max_seed_value) diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py new file mode 100644 index 00000000000000..546d8e845ecb11 --- /dev/null +++ b/pytorch_lightning/utilities/signature_utils.py @@ -0,0 +1,22 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Callable + + +def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: + hook_params = list(inspect.signature(hook_fx).parameters) + if "args" in hook_params or param in hook_params: + return True + return False diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py new file mode 100644 index 00000000000000..4896845f102630 --- /dev/null +++ b/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +from shutil import copyfile + +import torch + +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint + +KEYS_MAPPING = { + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), + "early_stop_callback_wait": (EarlyStopping, "wait_count"), + "early_stop_callback_patience": (EarlyStopping, "patience"), +} + +log = logging.getLogger(__name__) + + +def upgrade_checkpoint(filepath): + checkpoint = torch.load(filepath) + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} + + for key, new_path in KEYS_MAPPING.items(): + if key in checkpoint: + value = checkpoint[key] + callback_type, callback_key = new_path + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} + checkpoint["callbacks"][callback_type][callback_key] = value + del checkpoint[key] + + torch.save(checkpoint, filepath) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Upgrade an old checkpoint to the current schema. \ + This will also save a backup of the original file." + ) + parser.add_argument("--file", help="filepath for a checkpoint to upgrade") + + args = parser.parse_args() + + log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") + copyfile(args.file, args.file + ".bak") + upgrade_checkpoint(args.file) diff --git a/pytorch_lightning/utilities/warning_utils.py b/pytorch_lightning/utilities/warning_utils.py new file mode 100644 index 00000000000000..0668bababa6090 --- /dev/null +++ b/pytorch_lightning/utilities/warning_utils.py @@ -0,0 +1,5 @@ +from pytorch_lightning.utilities import rank_zero_deprecation + +rank_zero_deprecation("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4") + +from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py new file mode 100644 index 00000000000000..a3dde95fa928f5 --- /dev/null +++ b/pytorch_lightning/utilities/warnings.py @@ -0,0 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.utilities.distributed import rank_zero_warn + + +class WarningCache: + + def __init__(self): + self.warnings = set() + + def warn(self, m, *args, **kwargs): + if m not in self.warnings: + self.warnings.add(m) + rank_zero_warn(m, *args, **kwargs) + + def clear(self): + self.warnings.clear() diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py new file mode 100644 index 00000000000000..294d3d2c5ec40b --- /dev/null +++ b/pytorch_lightning/utilities/xla_device.py @@ -0,0 +1,113 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import os +import queue as q +import traceback +from multiprocessing import Process, Queue + +import torch.multiprocessing as mp + +from pytorch_lightning.utilities.imports import _XLA_AVAILABLE + +if _XLA_AVAILABLE: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.xla_multiprocessing as xmp + +#: define waiting time got checking TPU available in sec +TPU_CHECK_TIMEOUT = 25 + + +def inner_f(queue, func, *args, **kwargs): # pragma: no cover + try: + queue.put(func(*args, **kwargs)) + # todo: specify the possible exception + except Exception: + traceback.print_exc() + queue.put(None) + + +def pl_multi_process(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + queue = Queue() + proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) + proc.start() + proc.join(TPU_CHECK_TIMEOUT) + try: + return queue.get_nowait() + except q.Empty: + traceback.print_exc() + return False + + return wrapper + + +class XLADeviceUtils: + """Used to detect the type of XLA device""" + + _TPU_AVAILABLE = False + + @staticmethod + @pl_multi_process + def _is_device_tpu() -> bool: + """ + Check if device is TPU + + Return: + A boolean value indicating if the xla device is a TPU device or not + """ + + def _fn(_: int, mp_queue): + try: + device = xm.xla_device() + mp_queue.put(device.type == 'xla') + except Exception: + mp_queue.put(False) + + smp = mp.get_context("spawn") + queue = smp.SimpleQueue() + xmp.spawn(_fn, args=(queue, ), nprocs=1) + return queue.get() + + @staticmethod + def xla_available() -> bool: + """ + Check if XLA library is installed + + Return: + A boolean value indicating if a XLA is installed + """ + return _XLA_AVAILABLE + + @staticmethod + def tpu_device_exists() -> bool: + """ + Runs XLA device check within a separate process + + Return: + A boolean value indicating if a TPU device exists on the system + """ + if os.getenv("PL_TPU_AVAILABLE", '0') == "1": + XLADeviceUtils._TPU_AVAILABLE = True + + if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE: + + XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu() + + if XLADeviceUtils._TPU_AVAILABLE: + os.environ["PL_TPU_AVAILABLE"] = '1' + + return XLADeviceUtils._TPU_AVAILABLE diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py new file mode 100644 index 00000000000000..f028222e3930bb --- /dev/null +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.utilities import rank_zero_deprecation + +rank_zero_deprecation( + "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4" +) + +from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401 diff --git a/requirements-extra.txt b/requirements-extra.txt deleted file mode 100644 index bab1f3dd1fa13b..00000000000000 --- a/requirements-extra.txt +++ /dev/null @@ -1,11 +0,0 @@ -# extended list of package dependencies to reach full functionality - -neptune-client>=0.4.109 -comet-ml>=1.0.56 -mlflow>=1.0.0 -test_tube>=0.7.5 -wandb>=0.8.21 -trains>=0.14.1 -matplotlib>=3.1.1 -# no need to install with [pytorch] as pytorch is already installed and torchvision is required only for Horovod examples -horovod>=0.19.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 81441b367e36c8..4649983b79d787 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,12 @@ # the default package dependencies -tqdm>=4.41.0 -numpy>=1.16.4 -torch>=1.1 -tensorboard>=1.14 +numpy>=1.16.6 +torch>=1.4 future>=0.17.1 # required for builtins in setup.py - +# pyyaml>=3.13 +PyYAML>=5.1, !=5.4.* # OmegaConf requirement >=5.1 +tqdm>=4.41.0 +fsspec[http]>=0.8.1 +tensorboard>=2.2.0 +torchmetrics>=0.2.0 +pyDeprecate==0.1.1 \ No newline at end of file diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py new file mode 100644 index 00000000000000..d0dfbc59e23523 --- /dev/null +++ b/requirements/adjust_versions.py @@ -0,0 +1,50 @@ +import os +import re +import sys +from typing import Any, Dict + +VERSIONS_LUT: Dict[str, Dict[str, Any]] = { + "1.4.0": dict(torchvision="0.5.0", torchtext="0.5"), + "1.5.0": dict(torchvision="0.6.0", torchtext="0.6"), + "1.5.1": dict(torchvision="0.6.1", torchtext="0.6"), + "1.6.0": dict(torchvision="0.7.0", torchtext="0.7"), + "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"), + "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"), + "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"), + "1.8.1": dict(torchvision="0.9.0", torchtext="0.9"), +} + + +def find_latest(ver: str, versions_all: list) -> str: + # drop all except semantic version + ver = re.search(r'([\.\d]+)', ver).groups()[0] + # find candidates, by starting version pattern + options = [v for v in versions_all if v.startswith(ver)] + assert options, f"missing {ver} among {versions_all}" + # take the last one... + return sorted(options)[-1] + + +def main(path_req: str, torch_version: str = None) -> None: + with open(path_req, "r") as fp: + req = fp.read() + + if not torch_version: + import torch + torch_version = torch.__version__ + assert torch_version, f"invalid/missing Torch: {torch_version}" + + torch_version = find_latest(torch_version, list(VERSIONS_LUT.keys())) + dep_versions = VERSIONS_LUT[torch_version] + dep_versions["torch"] = torch_version + for lib in dep_versions: + version = dep_versions[lib] + replace = f"{lib}=={version}\n" + req = re.sub(rf"{lib}[>=]*[\d\.]*{os.linesep}", replace, req) + + with open(path_req, "w") as fp: + fp.write(req) + + +if __name__ == "__main__": + main(*sys.argv[1:]) diff --git a/requirements/devel.txt b/requirements/devel.txt new file mode 100644 index 00000000000000..dcf66495ee46fc --- /dev/null +++ b/requirements/devel.txt @@ -0,0 +1,11 @@ +# install all mandatory dependencies +-r ../requirements.txt + +# install all extra dependencies for full package testing +-r ./extra.txt + +# extended list of dependencies for development and run lint and tests +-r ./test.txt + +# install all extra dependencies for running examples +-r ./examples.txt diff --git a/requirements/docs.txt b/requirements/docs.txt new file mode 100644 index 00000000000000..ceabc12d69c36b --- /dev/null +++ b/requirements/docs.txt @@ -0,0 +1,13 @@ +sphinx>=3.0, !=3.5 # fails with sphinx.ext.viewcode +recommonmark # fails with badges +m2r # fails with multi-line text +nbsphinx>=0.8 +pandoc>=1.0 +docutils>=0.16 +sphinxcontrib-fulltoc>=1.0 +sphinxcontrib-mockautodoc +https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#egg=pt-lightning-sphinx-theme +sphinx-autodoc-typehints>=1.0 +sphinx-paramlinks>=0.4.0 +sphinx-togglebutton>=0.2 +sphinx-copybutton>=0.3 \ No newline at end of file diff --git a/requirements/examples.txt b/requirements/examples.txt new file mode 100644 index 00000000000000..83ceafe3c2934c --- /dev/null +++ b/requirements/examples.txt @@ -0,0 +1,2 @@ +torchvision>=0.5 +gym>=0.17.0 diff --git a/requirements/extra.txt b/requirements/extra.txt new file mode 100644 index 00000000000000..715916c4e36acb --- /dev/null +++ b/requirements/extra.txt @@ -0,0 +1,11 @@ +# extended list of package dependencies to reach full functionality + +matplotlib>3.1 +horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed +omegaconf>=2.0.1 +torchtext>=0.5 +# onnx>=1.7.0 +onnxruntime>=1.3.0 +hydra-core>=1.0 +# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs +https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip diff --git a/tests/install_AMP.sh b/requirements/install_Apex.sh similarity index 61% rename from tests/install_AMP.sh rename to requirements/install_Apex.sh index 2c56bb25b742b8..0c70e0bc348708 100644 --- a/tests/install_AMP.sh +++ b/requirements/install_Apex.sh @@ -4,6 +4,7 @@ ROOT=$PWD git clone https://github.com/NVIDIA/apex cd apex pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ -pip install -v --no-cache-dir ./ +# If build with extensions fails, you can run this line to build without extensions +# pip install -v --no-cache-dir ./ cd $ROOT rm -rf apex diff --git a/requirements/install_ONNX.sh b/requirements/install_ONNX.sh new file mode 100644 index 00000000000000..d6784fa373d6fd --- /dev/null +++ b/requirements/install_ONNX.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +ROOT=$PWD + +# python -m pip install protobuf +# git clone --recursive https://github.com/onnx/onnx.git +# cd onnx +# python setup.py bdist_wheel +# pip install --upgrade dist/*.whl +# cd $ROOT +# rm -rf onnx + + +# https://github.com/microsoft/onnxruntime/blob/master/BUILD.md +git clone --recursive https://github.com/Microsoft/onnxruntime +cd onnxruntime +export ONNX_ML=1 +pip install setuptools wheel numpy + +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + ./build.sh --config RelWithDebInfo --build_shared_lib --build_wheel --parallel +elif [[ "$OSTYPE" == "darwin"* ]]; then + # Mac OSX + ./build.sh --config RelWithDebInfo --build_shared_lib --build_wheel --parallel --use_xcode +elif [[ "$OSTYPE" == "cygwin" ]]; then + # POSIX compatibility layer and Linux environment emulation for Windows + ./build.sh --config RelWithDebInfo --build_shared_lib --build_wheel --parallel +elif [[ "$OSTYPE" == "msys" ]]; then + # Lightweight shell and GNU utilities compiled for Windows (part of MinGW) + .\build.bat --config RelWithDebInfo --build_shared_lib --build_wheel --parallel +elif [[ "$OSTYPE" == "win32" ]]; then + .\build.bat --config RelWithDebInfo --build_shared_lib --build_wheel --parallel +else + echo $OSTYPE # Unknown. +fi + +find . -name "*.whl" +pip install --upgrade $(find . -name "*.whl") + +cd $ROOT +rm -rf onnxruntime diff --git a/requirements/loggers.txt b/requirements/loggers.txt new file mode 100644 index 00000000000000..001210855871dd --- /dev/null +++ b/requirements/loggers.txt @@ -0,0 +1,6 @@ +# all supported loggers +neptune-client>=0.4.109 +comet-ml>=3.1.12 +mlflow>=1.0.0 +test_tube>=0.7.5 +wandb>=0.8.21 diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 00000000000000..3c81479e148c25 --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,17 @@ +coverage>5.2.0 +codecov>=2.1 +pytest>=6.0 +#pytest-cov>2.10 +#pytest-xdist +flake8>=3.6 +check-manifest +twine==3.2 +isort>=5.6.4 +mypy>=0.720, <0.800 +pre-commit>=1.0 + +cloudpickle>=1.3 +scikit-learn>0.22.1 +scikit-image>0.17.1 +nltk>=3.3 +pandas # needed in benchmarks diff --git a/setup.cfg b/setup.cfg index aab7a580c77b91..6365482e32aa8b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + [tool:pytest] norecursedirs = .git @@ -6,12 +20,17 @@ norecursedirs = python_files = test_*.py # doctest_plus = disabled -addopts = --strict +addopts = + --strict + --doctest-modules + --color=yes markers = slow remote_data filterwarnings gpus_param_tests +junit_duration_report = call + [coverage:report] exclude_lines = @@ -19,21 +38,35 @@ exclude_lines = warnings pass rank_zero_warn + raise NotImplementedError +# TODO: figure out how to get codecov to pick up the test results on these backends +# The actual coverage for each is 90%+ +# *metrics (94%+) are temporarily removed from testing while tests speed up +omit = + pytorch_lightning/cluster_environments/*.py + pytorch_lightning/utilities/xla_device_utils.py + pytorch_lightning/utilities/distributed.py + pytorch_lightning/tuner/auto_gpu_select.py + [flake8] # TODO: this should be 88 or 100 according PEP8 max-line-length = 120 -exclude = .tox,*.egg,build,temp +exclude = + .tox, + *.egg + build + temp + select = E,W,F doctests = True verbose = 2 # https://pep8.readthedocs.io/en/latest/intro.html#error-codes format = pylint ignore = - E731 - W504 - F401 - F841 + E731 # Ignore "Do not assign a lambda expression, use a def" + W503 # Ignore "Line break occurred before a binary operator" + # setup.cfg or tox.ini [check-manifest] @@ -43,14 +76,114 @@ ignore = .github/* .circleci + [metadata] license_file = LICENSE # long_description = file:README.md # long_description_content_type = text/markdown + [pydocstyle] convention = pep257 # D104, D107: Ignore missing docstrings in __init__ files and methods. # D202: Ignore a blank line after docstring (collision with Python Black in decorators) add-ignore = D104,D107,D202 max-line-length = 120 + + +[yapf] +based_on_style = pep8 +spaces_before_comment = 2 +split_before_logical_operator = true +split_before_arithmetic_operator = true +COLUMN_LIMIT = 120 +COALESCE_BRACKETS = true +DEDENT_CLOSING_BRACKETS = true +ALLOW_SPLIT_BEFORE_DICT_VALUE = false +BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true +NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false + + +[mypy] +# Typing tests is low priority, but enabling type checking on the +# untyped test functions (using `--check-untyped-defs`) is still +# high-value because it helps test the typing. +files = pytorch_lightning, pl_examples, benchmarks, tests +disallow_untyped_defs = True +ignore_missing_imports = True +show_error_codes = True +warn_redundant_casts = True +warn_unused_configs = True +warn_unused_ignores = True + +# todo: this is magically failing, need to be revisited +[mypy-pytorch_lightning.accelerators.tpu.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.callbacks.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.core.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.loggers.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.metrics.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.overrides.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.plugins.environments.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.plugins.training_type.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.profiler.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.pt_overrides.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.root_module.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.trainer.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.distributed.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.tuner.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pytorch_lightning.utilities.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-pl_examples.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-benchmarks.*] +ignore_errors = True + +# todo: add proper typing to this module... +[mypy-tests.*] +ignore_errors = True diff --git a/setup.py b/setup.py index dfcaf29a5990e9..0979de8ce90650 100755 --- a/setup.py +++ b/setup.py @@ -1,47 +1,65 @@ #!/usr/bin/env python +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os -from io import open + # Always prefer setuptools over distutils -from setuptools import setup, find_packages +import sys + +from setuptools import find_packages, setup try: - import builtins + from pytorch_lightning import info, setup_tools except ImportError: - import __builtin__ as builtins + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append("pytorch_lightning") + import info + import setup_tools # https://packaging.python.org/guides/single-sourcing-package-version/ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ +_PATH_ROOT = os.path.dirname(__file__) +_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') -PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTNING_SETUP__ = True - -import pytorch_lightning # noqa: E402 - - -def load_requirements(path_dir=PATH_ROOT, comment_char='#'): - with open(os.path.join(path_dir, 'requirements.txt'), 'r') as file: - lines = [ln.strip() for ln in file.readlines()] - reqs = [] - for ln in lines: - # filer all comments - if comment_char in ln: - ln = ln[:ln.index(comment_char)] - if ln: # if requirement is not empty - reqs.append(ln) - return reqs - +# https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras +# Define package extras. These are only installed if you specify them. +# From remote, use like `pip install pytorch-lightning[dev, docs]` +# From local copy of repo, use like `pip install ".[dev, docs]"` +extras = { + # 'docs': load_requirements(file_name='docs.txt'), + 'examples': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='examples.txt'), + 'loggers': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='loggers.txt'), + 'extra': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='extra.txt'), + 'test': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='test.txt') +} +extras['dev'] = extras['extra'] + extras['loggers'] + extras['test'] +extras['all'] = extras['dev'] + extras['examples'] # + extras['docs'] -def load_long_describtion(): - # https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png - url = os.path.join(pytorch_lightning.__homepage__, 'raw', pytorch_lightning.__version__, 'docs') - text = open('README.md', encoding='utf-8').read() - # replace relative repository path to absolute link to the release - text = text.replace('](docs', f']({url}') - # SVG images are not readable on PyPI, so replace them with PNG - text = text.replace('.svg', '.png') - return text +# These packages shall be installed only on GPU machines +PACKAGES_GPU_ONLY = ['horovod'] +# create a version for CPU machines +for ex in ('cpu', 'cpu-extra'): + kw = ex.split('-')[1] if '-' in ex else 'all' + # filter cpu only packages + extras[ex] = [pkg for pkg in extras[kw] if not any(pgpu.lower() in pkg.lower() for pgpu in PACKAGES_GPU_ONLY)] +long_description = setup_tools._load_readme_description( + _PATH_ROOT, + homepage=info.__homepage__, + version=info.__version__, +) # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious @@ -49,32 +67,29 @@ def load_long_describtion(): # the goal of the project is simplicity for researchers, don't want to add too much # engineer specific practices setup( - name='pytorch-lightning', - version=pytorch_lightning.__version__, - description=pytorch_lightning.__docs__, - author=pytorch_lightning.__author__, - author_email=pytorch_lightning.__author_email__, - url=pytorch_lightning.__homepage__, + name="pytorch-lightning", + version=info.__version__, + description=info.__docs__, + author=info.__author__, + author_email=info.__author_email__, + url=info.__homepage__, download_url='https://github.com/PyTorchLightning/pytorch-lightning', - license=pytorch_lightning.__license__, - packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks']), - - long_description=load_long_describtion(), + license=info.__license__, + packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']), + long_description=long_description, long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, - keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=load_requirements(PATH_ROOT), - + install_requires=setup_tools._load_requirements(_PATH_ROOT), + extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", "Documentation": "https://pytorch-lightning.rtfd.io/en/latest/", "Source Code": "https://github.com/PyTorchLightning/pytorch-lightning", }, - classifiers=[ 'Environment :: Console', 'Natural Language :: English', @@ -87,12 +102,14 @@ def load_long_describtion(): 'Topic :: Scientific/Engineering :: Image Recognition', 'Topic :: Scientific/Engineering :: Information Analysis', # Pick your license as you wish - 'License :: OSI Approved :: BSD License', + 'License :: OSI Approved :: Apache Software License', 'Operating System :: OS Independent', # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', ], ) diff --git a/tests/Dockerfile b/tests/Dockerfile deleted file mode 100644 index 876d6fbf542c3f..00000000000000 --- a/tests/Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -ARG TORCH_VERSION=1.4 -ARG CUDA_VERSION=10.1 - -FROM pytorch/pytorch:${TORCH_VERSION}-cuda${CUDA_VERSION}-cudnn7-devel - -ENV HOROVOD_GPU_ALLREDUCE: NCCL -ENV HOROVOD_GPU_BROADCAST: NCCL -ENV HOROVOD_WITH_PYTORCH: 1 -ENV HOROVOD_WITHOUT_TENSORFLOW: 1 -ENV HOROVOD_WITHOUT_MXNET: 1 -ENV HOROVOD_WITH_GLOO: 1 -ENV HOROVOD_WITHOUT_MPI: 1 -ENV PATH: "$PATH:/root/.local/bin" -ENV MAKEFLAGS: "-j$(nproc)" - -COPY ./tests/install_AMP.sh install_AMP.sh -COPY ./requirements.txt requirements.txt -COPY ./requirements-extra.txt requirements-extra.txt -COPY ./tests/requirements.txt requirements-tests.txt - -# Install AMP -RUN apt-get update && apt-get install -y cmake && \ - bash install_AMP.sh && \ - pip install -r requirements.txt --user && \ - pip install -r requirements-extra.txt --user && \ - pip install -r requirements-tests.txt --user && \ - pip list diff --git a/tests/README.md b/tests/README.md index c931a2a02e935f..0b0563a3ae540f 100644 --- a/tests/README.md +++ b/tests/README.md @@ -9,15 +9,18 @@ run on a 2-GPU machine to validate the full test-suite. To run all tests do the following: + +Install [Open MPI](https://www.open-mpi.org/) or another MPI implementation. Learn how to install Open MPI [on this page](https://www.open-mpi.org/faq/?category=building#easy-build>). + ```bash git clone https://github.com/PyTorchLightning/pytorch-lightning cd pytorch-lightning # install AMP support -bash tests/install_AMP.sh +bash requirements/install_Apex.sh # install dev deps -pip install -r tests/requirements-devel.txt +pip install -r requirements/devel.txt # run tests py.test -v @@ -27,17 +30,17 @@ To test models that require GPU make sure to run the above command on a GPU mach The GPU machine must have: 1. At least 2 GPUs. 2. [NVIDIA-apex](https://github.com/NVIDIA/apex#linux) installed. -3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL pip install horovod` +3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_OPERATIONS=NCCL pip install horovod` -## Running Coverage -Make sure to run coverage on a GPU machine with at least 2 GPUs and NVIDIA apex installed. +## Running Coverage +Make sure to run coverage on a GPU machine with at least 2 GPUs and NVIDIA apex installed. ```bash cd pytorch-lightning -# generate coverage (coverage is also installed as part of dev dependencies under tests/requirements-devel.txt) -coverage run --source pytorch_lightning -m py.test pytorch_lightning tests examples -v --doctest-modules +# generate coverage (coverage is also installed as part of dev dependencies under requirements/devel.txt) +coverage run --source pytorch_lightning -m py.test pytorch_lightning tests examples -v # print coverage stats coverage report -m @@ -51,11 +54,11 @@ coverage xml You can build it on your own, note it takes lots of time, be prepared. ```bash git clone -docker image build -t pytorch_lightning:devel-pt_1_4 -f tests/Dockerfile --build-arg TORCH_VERSION=1.4 . +docker image build -t pytorch_lightning:devel-torch1.4 -f dockers/cuda-extras/Dockerfile --build-arg TORCH_VERSION=1.4 . ``` To build other versions, select different Dockerfile. ```bash docker image list -docker run --rm -it pytorch_lightning:devel-pt_1_4 bash -docker image rm pytorch_lightning:devel-pt_1_4 +docker run --rm -it pytorch_lightning:devel-torch1.4 bash +docker image rm pytorch_lightning:devel-torch1.4 ``` diff --git a/tests/__init__.py b/tests/__init__.py index acc27596f9c443..fc634e6b73fec3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,18 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging import os import numpy as np -import torch -TEST_ROOT = os.path.dirname(__file__) -PACKAGE_ROOT = os.path.dirname(TEST_ROOT) -TEMP_PATH = os.path.join(PACKAGE_ROOT, 'test_temp') +_TEST_ROOT = os.path.dirname(__file__) +_PROJECT_ROOT = os.path.dirname(_TEST_ROOT) +_TEMP_PATH = os.path.join(_PROJECT_ROOT, 'test_temp') +PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') +PATH_LEGACY = os.path.join(_PROJECT_ROOT, 'legacy') + +# todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages +if _PROJECT_ROOT not in os.getenv('PYTHONPATH', ""): + splitter = ":" if os.environ.get("PYTHONPATH", "") else "" + os.environ['PYTHONPATH'] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}' # generate a list of random seeds for each test RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000)) -ROOT_SEED = 1234 -torch.manual_seed(ROOT_SEED) -np.random.seed(ROOT_SEED) -RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000)) -if not os.path.isdir(TEMP_PATH): - os.mkdir(TEMP_PATH) +if not os.path.isdir(_TEMP_PATH): + os.mkdir(_TEMP_PATH) + +logging.basicConfig(level=logging.ERROR) diff --git a/tests/accelerators/__init__.py b/tests/accelerators/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/accelerators/ddp_model.py b/tests/accelerators/ddp_model.py new file mode 100644 index 00000000000000..78d1306665c59d --- /dev/null +++ b/tests/accelerators/ddp_model.py @@ -0,0 +1,64 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Runs either `.fit()` or `.test()` on a single node across multiple gpus. +""" +import os +from argparse import ArgumentParser + +import torch + +from pytorch_lightning import seed_everything, Trainer +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.simple_models import ClassificationModel + + +def main(): + seed_everything(1234) + + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parser) + parser.add_argument('--trainer_method', default='fit') + parser.add_argument('--tmpdir') + parser.add_argument('--workdir') + parser.set_defaults(gpus=2) + parser.set_defaults(accelerator="ddp") + args = parser.parse_args() + + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer.from_argparse_args(args) + + if args.trainer_method == 'fit': + trainer.fit(model, datamodule=dm) + result = None + elif args.trainer_method == 'test': + result = trainer.test(model, datamodule=dm) + elif args.trainer_method == 'fit_test': + trainer.fit(model, datamodule=dm) + result = trainer.test(model, datamodule=dm) + else: + raise ValueError(f'Unsupported: {args.trainer_method}') + + result_ext = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': result, + } + file_path = os.path.join(args.tmpdir, 'ddp.result') + torch.save(result_ext, file_path) + + +if __name__ == '__main__': + main() diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py new file mode 100644 index 00000000000000..79a17df074e35b --- /dev/null +++ b/tests/accelerators/test_accelerator_connector.py @@ -0,0 +1,452 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import os +from typing import Optional +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.cpu import CPUAccelerator +from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins import ( + DDP2Plugin, + DDPPlugin, + DDPShardedPlugin, + DDPSpawnPlugin, + DDPSpawnShardedPlugin, + DeepSpeedPlugin, + ParallelPlugin, + PrecisionPlugin, + SingleDevicePlugin, +) +from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +def test_accelerator_choice_cpu(tmpdir): + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + assert isinstance(trainer.accelerator, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, SingleDevicePlugin) + + +def test_accelerator_choice_ddp_cpu(tmpdir): + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_cpu', + ) + assert isinstance(trainer.accelerator, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_accelerator_choice_ddp(cuda_available_mock, device_count_mock): + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp', + gpus=1, + ) + assert isinstance(trainer.accelerator, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock): + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_spawn', + gpus=1, + ) + assert isinstance(trainer.accelerator, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + + +@RunIf(min_gpus=2) +@mock.patch.dict( + os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "SLURM_LOCALID": "10" + } +) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_slurm(setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 + assert trainer.training_type_plugin.task_idx == 10 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp', + gpus=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(min_gpus=1) +@mock.patch.dict( + os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "10" + } +) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp2 + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDP2Plugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 + assert trainer.training_type_plugin.task_idx == 10 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp2', + gpus=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(min_gpus=1) +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp + assert isinstance(trainer.accelerator, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 + assert trainer.training_type_plugin.task_idx == 10 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp', + gpus=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(min_gpus=1) +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp2 + assert isinstance(trainer.accelerator, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDP2Plugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 + assert trainer.training_type_plugin.task_idx == 10 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp2', + gpus=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict(os.environ, { + "WORLD_SIZE": "1", + "LOCAL_RANK": "10", + "NODE_RANK": "0", +}) +@mock.patch('torch.cuda.device_count', return_value=0) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp + assert isinstance(trainer.accelerator, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 + assert trainer.training_type_plugin.task_idx == 10 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_cpu', + num_processes=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict( + os.environ, { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" + } +) +@mock.patch('torch.cuda.device_count', return_value=0) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.task_idx == 0 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_cpu', + num_processes=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict( + os.environ, { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" + } +) +@mock.patch('torch.cuda.device_count', return_value=0) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock): + """ + Test that we choose the custom cluster even when SLURM or TE flags are around + """ + + class CustomCluster(LightningEnvironment): + + def master_address(self): + return 'asdf' + + def creates_children(self) -> bool: + return True + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert trainer.use_ddp + assert isinstance(trainer.accelerator, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, CustomCluster) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + plugins=[CustomCluster()], + fast_dev_run=True, + accelerator='ddp_cpu', + num_processes=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch.dict( + os.environ, { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" + } +) +@mock.patch('torch.cuda.device_count', return_value=0) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_custom_accelerator(device_count_mock, setup_distributed_mock): + + class Accel(Accelerator): + pass + + class Prec(PrecisionPlugin): + pass + + class TrainTypePlugin(SingleDevicePlugin): + pass + + accelerator = Accel( + training_type_plugin=TrainTypePlugin(device=torch.device("cpu")), + precision_plugin=Prec(), + ) + trainer = Trainer( + accelerator=accelerator, + fast_dev_run=True, + num_processes=2, + ) + assert isinstance(trainer.accelerator, Accel) + assert isinstance(trainer.training_type_plugin, TrainTypePlugin) + assert isinstance(trainer.precision_plugin, Prec) + + +@mock.patch.dict( + os.environ, { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" + } +) +@mock.patch('torch.cuda.device_count', return_value=0) +@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True) +def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert trainer.training_type_plugin.task_idx == 0 + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_cpu', + num_processes=2, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_ipython_incompatible_backend_error(*_): + with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"): + Trainer(accelerator="ddp", gpus=2) + + with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"): + Trainer(accelerator="ddp_cpu", num_processes=2) + + with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"): + Trainer(accelerator="ddp2", gpus=2) + + +@pytest.mark.parametrize( + ["accelerator", "plugin"], + [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')], +) +def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str): + """Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.""" + trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2) + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + + trainer = Trainer(plugins=plugin, num_processes=2) + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + + +@pytest.mark.parametrize(["accelerator", "plugin"], [ + ('ddp', DDPPlugin), + ('ddp_spawn', DDPSpawnPlugin), + ('ddp_sharded', DDPShardedPlugin), + ('ddp_sharded_spawn', DDPSpawnShardedPlugin), + pytest.param('deepspeed', DeepSpeedPlugin, marks=RunIf(deepspeed=True)), +]) +@mock.patch('torch.cuda.is_available', return_value=True) +@mock.patch('torch.cuda.device_count', return_value=2) +def test_accelerator_choice_multi_node_gpu( + mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin +): + trainer = Trainer( + accelerator=accelerator, + default_root_dir=tmpdir, + num_nodes=2, + gpus=2, + ) + assert isinstance(trainer.training_type_plugin, plugin) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py new file mode 100644 index 00000000000000..2ad151d75e76c3 --- /dev/null +++ b/tests/accelerators/test_common.py @@ -0,0 +1,151 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import SingleDevicePlugin +from tests.accelerators.test_dp import CustomClassificationModelDP +from tests.helpers.boring_model import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize( + "trainer_kwargs", ( + pytest.param(dict(gpus=1), marks=RunIf(min_gpus=1)), + pytest.param(dict(accelerator="dp", gpus=2), marks=RunIf(min_gpus=2)), + pytest.param(dict(accelerator="ddp_spawn", gpus=2), marks=RunIf(min_gpus=2)), + ) +) +def test_evaluate(tmpdir, trainer_kwargs): + tutils.set_random_master_port() + + dm = ClassifDataModule() + model = CustomClassificationModelDP() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + deterministic=True, + **trainer_kwargs + ) + + result = trainer.fit(model, datamodule=dm) + assert result + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + old_weights = model.layer_0.weight.clone().detach().cpu() + + result = trainer.validate(datamodule=dm) + assert result[0]['val_acc'] > 0.55 + + result = trainer.test(datamodule=dm) + assert result[0]['test_acc'] > 0.55 + + # make sure weights didn't change + new_weights = model.layer_0.weight.clone().detach().cpu() + torch.testing.assert_allclose(old_weights, new_weights) + + +def test_model_parallel_setup_called(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.configure_sharded_model_called = False + self.layer = None + + def configure_sharded_model(self): + self.configure_sharded_model_called = True + self.layer = torch.nn.Linear(32, 2) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + assert model.configure_sharded_model_called + + +class DummyModel(BoringModel): + + def __init__(self): + super().__init__() + self.configure_sharded_model_called = False + + def configure_sharded_model(self): + self.configure_sharded_model_called = True + + +def test_configure_sharded_model_false(tmpdir): + """Ensure ``configure_sharded_model`` is not called, when turned off""" + + class CustomPlugin(SingleDevicePlugin): + + @property + def call_configure_sharded_model_hook(self) -> bool: + return False + + model = DummyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + plugins=CustomPlugin(device=torch.device("cpu")) + ) + trainer.fit(model) + + assert not model.configure_sharded_model_called + + +def test_accelerator_configure_sharded_model_called_once(tmpdir): + """Ensure that the configure sharded model hook is called, and set to False after to ensure not called again.""" + + model = DummyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + assert trainer.accelerator.call_configure_sharded_model_hook is True + trainer.fit(model) + assert trainer.accelerator.call_configure_sharded_model_hook is False + + +def test_configure_sharded_model_called_once(tmpdir): + """Ensure ``configure_sharded_model`` is only called once""" + + model = DummyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + assert model.configure_sharded_model_called + model.configure_sharded_model_called = False + + assert not model.configure_sharded_model_called diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py new file mode 100644 index 00000000000000..46379a9d10c14f --- /dev/null +++ b/tests/accelerators/test_cpu.py @@ -0,0 +1,52 @@ +from unittest.mock import Mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators import CPUAccelerator +from pytorch_lightning.plugins import SingleDevicePlugin +from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel + + +def test_unsupported_precision_plugins(): + """ Test error messages are raised for unsupported precision plugins with CPU. """ + trainer = Mock() + model = Mock() + accelerator = CPUAccelerator( + training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin() + ) + with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): + accelerator.setup(trainer=trainer, model=model) + + +@pytest.mark.parametrize("delay_dispatch", [True, False]) +def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): + """ + Test when using a custom training type plugin that delays setup optimizers, + we do not call setup optimizers till ``pre_dispatch``. + """ + + class TestModel(BoringModel): + + def on_fit_start(self): + if delay_dispatch: + # Ensure we haven't setup optimizers if we've delayed dispatch + assert len(self.trainer.optimizers) == 0 + else: + assert len(self.trainer.optimizers) > 0 + + def on_fit_end(self): + assert len(self.trainer.optimizers) > 0 + + class CustomPlugin(SingleDevicePlugin): + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return delay_dispatch + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) + trainer.fit(model) diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py new file mode 100644 index 00000000000000..06aed2c1020fff --- /dev/null +++ b/tests/accelerators/test_ddp.py @@ -0,0 +1,119 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional +from unittest import mock +from unittest.mock import patch + +import pytest +import torch + +from pytorch_lightning import Trainer +from tests.accelerators import ddp_model +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf +from tests.utilities.distributed import call_training_script + +CLI_ARGS = '--max_epochs 1 --gpus 2 --accelerator ddp' + + +@RunIf(min_gpus=2) +def test_multi_gpu_model_ddp_fit_only(tmpdir): + # call the script + call_training_script(ddp_model, CLI_ARGS, 'fit', tmpdir, timeout=120) + + # load the results of the script + result_path = os.path.join(tmpdir, 'ddp.result') + result = torch.load(result_path) + + # verify the file wrote the expected outputs + assert result['status'] == 'complete' + + +@RunIf(min_gpus=2) +def test_multi_gpu_model_ddp_test_only(tmpdir): + # call the script + call_training_script(ddp_model, CLI_ARGS, 'test', tmpdir) + + # load the results of the script + result_path = os.path.join(tmpdir, 'ddp.result') + result = torch.load(result_path) + + # verify the file wrote the expected outputs + assert result['status'] == 'complete' + + +@RunIf(min_gpus=2) +def test_multi_gpu_model_ddp_fit_test(tmpdir): + # call the script + call_training_script(ddp_model, CLI_ARGS, 'fit_test', tmpdir, timeout=20) + + # load the results of the script + result_path = os.path.join(tmpdir, 'ddp.result') + result = torch.load(result_path) + + # verify the file wrote the expected outputs + assert result['status'] == 'complete' + + model_outs = result['result'] + for out in model_outs: + assert out['test_acc'] > 0.7 + + +@RunIf(skip_windows=True) +@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") +def test_torch_distributed_backend_env_variables(tmpdir): + """ + This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError. + """ + _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} + with patch.dict(os.environ, _environ), \ + patch('torch.cuda.device_count', return_value=2): + with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="ddp", + gpus=2, + logger=False, + ) + trainer.fit(model) + + +@RunIf(skip_windows=True) +@mock.patch('torch.cuda.device_count', return_value=1) +@mock.patch('torch.cuda.is_available', return_value=True) +@mock.patch('torch.cuda.set_device') +@mock.patch.dict(os.environ, {'PL_TORCH_DISTRIBUTED_BACKEND': 'gloo'}, clear=True) +def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir): + """ + Test to ensure torch distributed is available within the setup hook using ddp + """ + + class TestModel(BoringModel): + + def setup(self, stage: Optional[str] = None) -> None: + assert torch.distributed.is_initialized() + raise SystemExit() + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="ddp", + gpus=1, + ) + with pytest.raises(SystemExit): + trainer.fit(model) diff --git a/tests/accelerators/test_ddp_spawn.py b/tests/accelerators/test_ddp_spawn.py new file mode 100644 index 00000000000000..2bbcaa2e97cf39 --- /dev/null +++ b/tests/accelerators/test_ddp_spawn.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.core import memory +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel + + +@RunIf(min_gpus=2) +def test_multi_gpu_early_stop_ddp_spawn(tmpdir): + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + callbacks=[EarlyStopping(monitor='train_acc')], + max_epochs=50, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + accelerator='ddp_spawn', + ) + + dm = ClassifDataModule() + model = ClassificationModel() + tpipes.run_model_test(trainer_options, model, dm) + + +@RunIf(min_gpus=2) +def test_multi_gpu_model_ddp_spawn(tmpdir): + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + accelerator='ddp_spawn', + progress_bar_refresh_rate=0, + ) + + model = BoringModel() + + tpipes.run_model_test(trainer_options, model) + + # test memory helper functions + memory.get_memory_profile('min_max') + + +@RunIf(min_gpus=2) +def test_ddp_all_dataloaders_passed_to_fit(tmpdir): + """Make sure DDP works with dataloaders passed to fit()""" + tutils.set_random_master_port() + + model = BoringModel() + fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader()) + + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.2, + limit_val_batches=0.2, + gpus=[0, 1], + accelerator='ddp_spawn', + ) + trainer.fit(model, **fit_options) + assert trainer.state == TrainerState.FINISHED, "DDP doesn't work with dataloaders passed to fit()." diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py new file mode 100644 index 00000000000000..ab46aba3119fb7 --- /dev/null +++ b/tests/accelerators/test_dp.py @@ -0,0 +1,205 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +import pytorch_lightning as pl +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.core import memory +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel + + +class CustomClassificationModelDP(ClassificationModel): + + def _step(self, batch, batch_idx): + x, y = batch + logits = self(x) + return {'logits': logits, 'y': y} + + def training_step(self, batch, batch_idx): + out = self._step(batch, batch_idx) + loss = F.cross_entropy(out['logits'], out['y']) + return loss + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def validation_step_end(self, outputs): + self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y'])) + + def test_step_end(self, outputs): + self.log('test_acc', self.test_acc(outputs['logits'], outputs['y'])) + + +@RunIf(min_gpus=2) +def test_multi_gpu_early_stop_dp(tmpdir): + """Make sure DDP works. with early stopping""" + tutils.set_random_master_port() + + dm = ClassifDataModule() + model = CustomClassificationModelDP() + + trainer_options = dict( + default_root_dir=tmpdir, + callbacks=[EarlyStopping(monitor='val_acc')], + max_epochs=50, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + accelerator='dp', + ) + + tpipes.run_model_test(trainer_options, model, dm) + + +@RunIf(min_gpus=2) +def test_multi_gpu_model_dp(tmpdir): + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + accelerator='dp', + progress_bar_refresh_rate=0, + ) + + model = BoringModel() + + tpipes.run_model_test(trainer_options, model) + + # test memory helper functions + memory.get_memory_profile('min_max') + + +class ReductionTestModel(BoringModel): + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def add_outputs(self, output, device): + output.update({ + "reduce_int": torch.tensor(device.index, dtype=torch.int, device=device), + "reduce_float": torch.tensor(device.index, dtype=torch.float, device=device), + }) + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + self.add_outputs(output, batch.device) + return output + + def validation_step(self, batch, batch_idx): + output = super().validation_step(batch, batch_idx) + self.add_outputs(output, batch.device) + return output + + def test_step(self, batch, batch_idx): + output = super().test_step(batch, batch_idx) + self.add_outputs(output, batch.device) + return output + + def training_epoch_end(self, outputs): + assert outputs[0]["loss"].shape == torch.Size([]) + assert outputs[0]["reduce_int"].item() == 0 # mean([0, 1]) = 0 + assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5 + + +def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch): + """ + Test that an exception is raised when overriding batch_transfer_hooks in DP model. + """ + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + + class CustomModel(BoringModel): + + def transfer_batch_to_device(self, batch, device): + batch = batch.to(device) + return batch + + trainer_options = dict( + default_root_dir=tmpdir, + max_steps=7, + gpus=[0, 1], + accelerator='dp', + ) + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_before_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_after_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'): + trainer.fit(model) + + +@RunIf(min_gpus=2) +def test_dp_training_step_dict(tmpdir): + """ This test verifies that dp properly reduces dictionaries """ + model = ReductionTestModel() + model.training_step_end = None + model.validation_step_end = None + model.test_step_end = None + + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + gpus=2, + accelerator='dp', + ) + trainer.fit(model) diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py new file mode 100644 index 00000000000000..c086150a605280 --- /dev/null +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -0,0 +1,137 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from unittest import mock + +import pytest +import torch + +from tests.helpers.runif import RunIf + +ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") +sys.path.insert(0, ROOT) +DIR_PATH = os.path.dirname(os.path.realpath(__file__)) + +from pytorch_lightning import LightningModule # noqa: E402 +from pytorch_lightning import Trainer # noqa: E402 +from tests.helpers.boring_model import BoringModel # noqa: E402 + + +# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml) +# use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)` +@pytest.mark.skip("Multi-node testing is currently disabled") +@RunIf(special=True) +def test_logging_sync_dist_true_ddp(tmpdir): + """ + Tests to ensure that the sync_dist flag works with CPU (should just return the original value) + """ + fake_result = 1 + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True) + return acc + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True) + return {"x": loss} + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=2, + weights_summary=None, + accelerator="ddp", + gpus=1, + num_nodes=2, + ) + trainer.fit(model) + + assert trainer.logged_metrics['foo'] == fake_result + assert trainer.logged_metrics['bar'] == fake_result + + +# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml) +# use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)` +@pytest.mark.skip("Multi-node testing is currently disabled") +@RunIf(special=True) +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__validation_step__log(tmpdir): + """ + Tests that validation_step can log + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch) + acc = acc + batch_idx + self.log('a', acc, on_step=True, on_epoch=True) + self.log('a2', 2) + + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + acc = self.step(batch) + acc = acc + batch_idx + self.log('b', acc, on_step=True, on_epoch=True) + self.training_step_called = True + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + accelerator="ddp", + gpus=1, + num_nodes=2, + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + expected_logged_metrics = { + 'a2', + 'a_step', + 'a_epoch', + 'b_step/epoch_0', + 'b_step/epoch_1', + 'b_epoch', + 'epoch', + } + logged_metrics = set(trainer.logged_metrics.keys()) + assert expected_logged_metrics == logged_metrics + + # we don't want to enable val metrics during steps because it is not something that users should do + # on purpose DO NOT allow step_b... it's silly to monitor val step metrics + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} + assert expected_cb_metrics == callback_metrics diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py new file mode 100644 index 00000000000000..2104387643b33b --- /dev/null +++ b/tests/accelerators/test_tpu_backend.py @@ -0,0 +1,122 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest +import torch +from torch import nn + +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf +from tests.helpers.utils import pl_multi_process_test + + +class WeightSharingModule(BoringModel): + + def __init__(self): + super().__init__() + self.layer_1 = nn.Linear(32, 10, bias=False) + self.layer_2 = nn.Linear(10, 32, bias=False) + self.layer_3 = nn.Linear(32, 10, bias=False) + self.layer_3.weight = self.layer_1.weight + + def forward(self, x): + x = self.layer_1(x) + x = self.layer_2(x) + x = self.layer_3(x) + return x + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_resume_training_on_cpu(tmpdir): + """ Checks if training can be resumed from a saved checkpoint on CPU""" + # Train a model on TPU + model = BoringModel() + trainer = Trainer( + checkpoint_callback=True, + max_epochs=1, + tpu_cores=8, + ) + trainer.fit(model) + + model_path = trainer.checkpoint_callback.best_model_path + + # Verify saved Tensors are on CPU + ckpt = torch.load(model_path) + weight_tensor = list(ckpt["state_dict"].values())[0] + assert weight_tensor.device == torch.device("cpu") + + # Verify that training is resumed on CPU + trainer = Trainer( + resume_from_checkpoint=model_path, + checkpoint_callback=True, + max_epochs=1, + default_root_dir=tmpdir, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_if_test_works_after_train(tmpdir): + """ Ensure that .test() works after .fit() """ + + # Train a model on TPU + model = BoringModel() + trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + assert len(trainer.test(model)) == 1 + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_weight_tying_warning(tmpdir, capsys=None): + """ + Ensure a warning is thrown if model parameter lengths do not match + post moving to device. + """ + + model = WeightSharingModule() + trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) + + with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'): + result = trainer.fit(model) + assert result + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_if_weights_tied(tmpdir, capsys=None): + """ + Test if weights are properly tied on `on_post_move_to_device`. + Ensure no warning for parameter mismatch is thrown. + """ + + class Model(WeightSharingModule): + + def on_post_move_to_device(self): + self.layer_3.weight = self.layer_1.weight + + model = Model() + trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) + + with pytest.warns(UserWarning) as warnings: + result = trainer.fit(model) + assert result + + assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list)) + assert len(trainer.test(model)) == 1 diff --git a/tests/base/__init__.py b/tests/base/__init__.py index fce8a8fa662077..e099cd110ed43f 100644 --- a/tests/base/__init__.py +++ b/tests/base/__init__.py @@ -1,61 +1,3 @@ """Models for testing.""" -import torch - -from tests.base.eval_model_template import EvalModelTemplate -from tests.base.mixins import ( - LightEmptyTestStep, - LightValidationStepMixin, - LightValidationMixin, - LightValidationStepMultipleDataloadersMixin, - LightValidationMultipleDataloadersMixin, - LightTestStepMixin, - LightTestMixin, - LightTestStepMultipleDataloadersMixin, - LightTestMultipleDataloadersMixin, - LightTestFitSingleTestDataloadersMixin, - LightTestFitMultipleTestDataloadersMixin, - LightValStepFitSingleDataloaderMixin, - LightValStepFitMultipleDataloadersMixin, - LightTrainDataloader, - LightValidationDataloader, - LightTestDataloader, - LightInfTrainDataloader, - LightInfValDataloader, - LightInfTestDataloader, - LightTestOptimizerWithSchedulingMixin, - LightTestMultipleOptimizersWithSchedulingMixin, - LightTestOptimizersWithMixedSchedulingMixin, - LightTestReduceLROnPlateauMixin, - LightTestNoneOptimizerMixin, - LightZeroLenDataloader -) -from tests.base.models import TestModelBase, DictHparamsModel - - -class LightningTestModel(LightTrainDataloader, - LightValidationMixin, - LightTestMixin, - TestModelBase): - """Most common test case. Validation and test dataloaders.""" - - def on_training_metrics(self, logs): - logs['some_tensor_to_test'] = torch.rand(1) - - -class LightningTestModelWithoutHyperparametersArg(LightningTestModel): - """Without hparams argument in constructor """ - - def __init__(self): - import tests.base.utils as tutils - - # the user loads the hparams in some other way - hparams = tutils.get_default_hparams() - super().__init__(hparams) - - -class LightningTestModelWithUnusedHyperparametersArg(LightningTestModelWithoutHyperparametersArg): - """It has hparams argument in constructor but is not used.""" - - def __init__(self, hparams): - super().__init__() +from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate # noqa: F401 diff --git a/tests/base/datasets.py b/tests/base/datasets.py deleted file mode 100644 index af6cb062dc80e7..00000000000000 --- a/tests/base/datasets.py +++ /dev/null @@ -1,186 +0,0 @@ -import logging -import os -import urllib.request -from typing import Tuple, Optional, Sequence - -import torch -from torch import Tensor -from torch.utils.data import Dataset - -from tests import PACKAGE_ROOT - -#: local path to test datasets -PATH_DATASETS = os.path.join(PACKAGE_ROOT, 'Datasets') - - -class MNIST(Dataset): - """ - Customized `MNIST `_ dataset for testing Pytorch Lightning - without the torchvision dependency. - - Part of the code was copied from - https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py - - Args: - root: Root directory of dataset where ``MNIST/processed/training.pt`` - and ``MNIST/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - normalize: mean and std deviation of the MNIST dataset. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - - Examples: - >>> dataset = MNIST(download=True) - >>> len(dataset) - 60000 - >>> torch.bincount(dataset.targets) - tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]) - """ - - RESOURCES = ( - "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt", - "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt", - ) - - TRAIN_FILE_NAME = 'training.pt' - TEST_FILE_NAME = 'test.pt' - cache_folder_name = 'complete' - - def __init__(self, root: str = PATH_DATASETS, train: bool = True, - normalize: tuple = (0.5, 1.0), download: bool = True): - super().__init__() - self.root = root - self.train = train # training set or test set - self.normalize = normalize - - self.prepare_data(download) - - if not self._check_exists(self.cached_folder_path): - raise RuntimeError('Dataset not found.') - - data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME - self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file)) - - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: - img = self.data[idx].float().unsqueeze(0) - target = int(self.targets[idx]) - - if self.normalize is not None: - img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1]) - - return img, target - - def __len__(self) -> int: - return len(self.data) - - @property - def cached_folder_path(self) -> str: - return os.path.join(self.root, 'MNIST', self.cache_folder_name) - - def _check_exists(self, data_folder: str) -> bool: - existing = True - for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): - existing = existing and os.path.isfile(os.path.join(data_folder, fname)) - return existing - - def prepare_data(self, download: bool): - if download: - self._download(self.cached_folder_path) - - def _download(self, data_folder: str) -> None: - """Download the MNIST data if it doesn't exist in cached_folder_path already.""" - - if self._check_exists(data_folder): - return - - os.makedirs(data_folder, exist_ok=True) - - for url in self.RESOURCES: - logging.info(f'Downloading {url}') - fpath = os.path.join(data_folder, os.path.basename(url)) - urllib.request.urlretrieve(url, fpath) - - -def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: - tensor = tensor.clone() - mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) - std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) - tensor.sub_(mean).div_(std) - return tensor - - -class TrialMNIST(MNIST): - """Constrain image dataset - - Args: - root: Root directory of dataset where ``MNIST/processed/training.pt`` - and ``MNIST/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - normalize: mean and std deviation of the MNIST dataset. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - num_samples: number of examples per selected class/digit - digits: list selected MNIST digits/classes - - Examples: - >>> dataset = TrialMNIST(download=True) - >>> len(dataset) - 300 - >>> sorted(set([d.item() for d in dataset.targets])) - [0, 1, 2] - >>> torch.bincount(dataset.targets) - tensor([100, 100, 100]) - """ - - def __init__(self, root: str = PATH_DATASETS, train: bool = True, - normalize: tuple = (0.5, 1.0), download: bool = False, - num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2)): - - # number of examples per class - self.num_samples = num_samples - # take just a subset of MNIST dataset - self.digits = digits if digits else list(range(10)) - - self.cache_folder_name = 'digits-' + '-'.join(str(d) for d in sorted(self.digits)) \ - + f'_nb-{self.num_samples}' - - super().__init__( - root, - train=train, - normalize=normalize, - download=download - ) - - @staticmethod - def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, - num_samples: int, digits: Sequence): - classes = {d: 0 for d in digits} - indexes = [] - for idx, target in enumerate(full_targets): - label = target.item() - if classes.get(label, float('inf')) >= num_samples: - continue - indexes.append(idx) - classes[label] += 1 - if all(classes[k] >= num_samples for k in classes): - break - data = full_data[indexes] - targets = full_targets[indexes] - return data, targets - - def prepare_data(self, download: bool) -> None: - if self._check_exists(self.cached_folder_path): - return - if download: - self._download(super().cached_folder_path) - - for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): - path_fname = os.path.join(super().cached_folder_path, fname) - assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname - data, targets = torch.load(path_fname) - data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits) - torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) diff --git a/tests/base/debug.py b/tests/base/debug.py deleted file mode 100644 index 0c3b120c9366da..00000000000000 --- a/tests/base/debug.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from torch.nn import functional as F -from torch.utils.data import DataLoader - -import pytorch_lightning as pl -from tests.base.datasets import TrialMNIST - - -# from test_models import assert_ok_test_acc, load_model, \ -# clear_save_dir, get_default_logger, get_default_hparams, init_save_dir, \ -# init_checkpoint_callback, reset_seed, set_random_master_port - - -class CoolModel(pl.LightningModule): - - def __init(self): - super().__init__() - # not the best model... - self.l1 = torch.nn.Linear(28 * 28, 10) - - def forward(self, x): - return torch.relu(self.l1(x)) - - def my_loss(self, y_hat, y): - return F.cross_entropy(y_hat, y) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - return {'training_loss': self.my_loss(y_hat, y)} - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - return {'val_loss': self.my_loss(y_hat, y)} - - def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x for x in outputs['val_loss']]).mean() - return avg_loss - - def configure_optimizers(self): - return [torch.optim.Adam(self.parameters(), lr=0.02)] - - def train_dataloader(self): - return DataLoader(TrialMNIST(train=True, num_samples=100), batch_size=16) - - def val_dataloader(self): - return DataLoader(TrialMNIST(train=False, num_samples=50), batch_size=16) - - def test_dataloader(self): - return DataLoader(TrialMNIST(train=False, num_samples=50), batch_size=16) diff --git a/tests/base/eval_model_optimizers.py b/tests/base/eval_model_optimizers.py deleted file mode 100644 index 2fd9b104a06d9a..00000000000000 --- a/tests/base/eval_model_optimizers.py +++ /dev/null @@ -1,61 +0,0 @@ -from abc import ABC - -from torch import optim - - -class ConfigureOptimizersPool(ABC): - def configure_optimizers(self): - """ - return whatever optimizers we want here. - :return: list of optimizers - """ - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - return optimizer - - def configure_optimizers__empty(self): - return None - - def configure_optimizers__lbfgs(self): - """ - return whatever optimizers we want here. - :return: list of optimizers - """ - optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - return optimizer - - def configure_optimizers__multiple_optimizers(self): - """ - return whatever optimizers we want here. - :return: list of optimizers - """ - # try no scheduler for this model (testing purposes) - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - return optimizer1, optimizer2 - - def configure_optimizers__single_scheduler(self): - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) - return [optimizer], [lr_scheduler] - - def configure_optimizers__multiple_schedulers(self): - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) - lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - - return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] - - def configure_optimizers__mixed_scheduling(self): - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1) - lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - - return [optimizer1, optimizer2], \ - [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2] - - def configure_optimizers__reduce_lr_on_plateau(self): - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) - return [optimizer], [lr_scheduler] diff --git a/tests/base/eval_model_template.py b/tests/base/eval_model_template.py deleted file mode 100644 index d97e8a925fc6dc..00000000000000 --- a/tests/base/eval_model_template.py +++ /dev/null @@ -1,83 +0,0 @@ -from argparse import Namespace - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from pytorch_lightning.core.lightning import LightningModule -from tests.base.datasets import TrialMNIST -from tests.base.eval_model_optimizers import ConfigureOptimizersPool -from tests.base.eval_model_test_dataloaders import TestDataloaderVariations -from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations -from tests.base.eval_model_test_steps import TestStepVariations -from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations -from tests.base.eval_model_train_steps import TrainingStepVariations -from tests.base.eval_model_utils import ModelTemplateUtils, ModelTemplateData -from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations -from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations -from tests.base.eval_model_valid_steps import ValidationStepVariations - - -class EvalModelTemplate( - ModelTemplateData, - ModelTemplateUtils, - TrainingStepVariations, - ValidationStepVariations, - ValidationEpochEndVariations, - TestStepVariations, - TestEpochEndVariations, - TrainDataloaderVariations, - ValDataloaderVariations, - TestDataloaderVariations, - ConfigureOptimizersPool, - LightningModule -): - """ - This template houses all combinations of model configurations we want to test - """ - def __init__(self, hparams: object) -> object: - """Pass in parsed HyperOptArgumentParser to the model.""" - # init superclass - super().__init__() - self.hparams = Namespace(**hparams) if isinstance(hparams, dict) else hparams - - # if you specify an example input, the summary will show input/output for each layer - self.example_input_array = torch.rand(5, 28 * 28) - - # build model - self.__build_model() - - def __build_model(self): - """ - Simple model for testing - :return: - """ - self.c_d1 = nn.Linear( - in_features=self.hparams.in_features, - out_features=self.hparams.hidden_dim - ) - self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) - self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) - - self.c_d2 = nn.Linear( - in_features=self.hparams.hidden_dim, - out_features=self.hparams.out_features - ) - - def forward(self, x): - x = self.c_d1(x) - x = torch.tanh(x) - x = self.c_d1_bn(x) - x = self.c_d1_drop(x) - - x = self.c_d2(x) - logits = F.log_softmax(x, dim=1) - - return logits - - def loss(self, labels, logits): - nll = F.nll_loss(logits, labels) - return nll - - def prepare_data(self): - _ = TrialMNIST(root=self.hparams.data_root, train=True, download=True) diff --git a/tests/base/eval_model_test_dataloaders.py b/tests/base/eval_model_test_dataloaders.py deleted file mode 100644 index fdab56994ab9e5..00000000000000 --- a/tests/base/eval_model_test_dataloaders.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod - -from tests.base.eval_model_utils import CustomInfDataloader - - -class TestDataloaderVariations(ABC): - - @abstractmethod - def dataloader(self, train: bool): - """placeholder""" - - def test_dataloader(self): - return self.dataloader(train=False) - - def test_dataloader__infinite(self): - return CustomInfDataloader(self.dataloader(train=False)) - - def test_dataloader__empty(self): - return None - - def test_dataloader__multiple(self): - return [self.dataloader(train=False), self.dataloader(train=False)] diff --git a/tests/base/eval_model_train_dataloaders.py b/tests/base/eval_model_train_dataloaders.py deleted file mode 100644 index ded46de3d6e41b..00000000000000 --- a/tests/base/eval_model_train_dataloaders.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod - -from tests.base.eval_model_utils import CustomInfDataloader - - -class TrainDataloaderVariations(ABC): - - @abstractmethod - def dataloader(self, train: bool): - """placeholder""" - - def train_dataloader(self): - return self.dataloader(train=True) - - def train_dataloader__infinite(self): - return CustomInfDataloader(self.dataloader(train=True)) - - def train_dataloader__zero_length(self): - dataloader = self.dataloader(train=True) - dataloader.dataset.data = dataloader.dataset.data[:0] - dataloader.dataset.targets = dataloader.dataset.targets[:0] - return dataloader diff --git a/tests/base/eval_model_train_steps.py b/tests/base/eval_model_train_steps.py deleted file mode 100644 index 8a4307555dccb3..00000000000000 --- a/tests/base/eval_model_train_steps.py +++ /dev/null @@ -1,44 +0,0 @@ -import math -from abc import ABC -from collections import OrderedDict - -import torch - - -class TrainingStepVariations(ABC): - """ - Houses all variations of training steps - """ - test_step_inf_loss = float('inf') - - def training_step(self, batch, batch_idx, optimizer_idx=None): - """Lightning calls this inside the training loop""" - # forward pass - x, y = batch - x = x.view(x.size(0), -1) - - y_hat = self(x) - - # calculate loss - loss_val = self.loss(y, y_hat) - - # alternate possible outputs to test - if self.trainer.batch_idx % 1 == 0: - output = OrderedDict({ - 'loss': loss_val, - 'progress_bar': {'some_val': loss_val * loss_val}, - 'log': {'train_some_val': loss_val * loss_val}, - }) - return output - - if self.trainer.batch_idx % 2 == 0: - return loss_val - - def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None): - output = self.training_step(batch, batch_idx, optimizer_idx) - if batch_idx == self.test_step_inf_loss: - if isinstance(output, dict): - output['loss'] *= torch.tensor(math.inf) # make loss infinite - else: - output /= 0 - return output diff --git a/tests/base/eval_model_utils.py b/tests/base/eval_model_utils.py deleted file mode 100644 index d3eed3cb8dc5be..00000000000000 --- a/tests/base/eval_model_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -from torch.utils.data import DataLoader - -from tests.base.datasets import TrialMNIST - - -class ModelTemplateData: - hparams: ... - - def dataloader(self, train): - dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True) - - loader = DataLoader( - dataset=dataset, - batch_size=self.hparams.batch_size, - # test and valid shall not be shuffled - shuffle=train, - ) - return loader - - -class ModelTemplateUtils: - - def get_output_metric(self, output, name): - if isinstance(output, dict): - val = output[name] - else: # if it is 2level deep -> per dataloader and per batch - val = sum(out[name] for out in output) / len(output) - return val - - -class CustomInfDataloader: - - def __init__(self, dataloader): - self.dataloader = dataloader - self.iter = iter(dataloader) - self.count = 0 - - def __iter__(self): - self.count = 0 - return self - - def __next__(self): - if self.count >= 50: - raise StopIteration - self.count = self.count + 1 - try: - return next(self.iter) - except StopIteration: - self.iter = iter(self.dataloader) - return next(self.iter) diff --git a/tests/base/eval_model_valid_dataloaders.py b/tests/base/eval_model_valid_dataloaders.py deleted file mode 100644 index 2b760e13086fd4..00000000000000 --- a/tests/base/eval_model_valid_dataloaders.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import ABC, abstractmethod - -from tests.base.eval_model_utils import CustomInfDataloader - - -class ValDataloaderVariations(ABC): - - @abstractmethod - def dataloader(self, train: bool): - """placeholder""" - - def val_dataloader(self): - return self.dataloader(train=False) - - def val_dataloader__multiple(self): - return [self.dataloader(train=False), - self.dataloader(train=False)] - - def val_dataloader__infinite(self): - return CustomInfDataloader(self.dataloader(train=False)) diff --git a/tests/base/eval_model_valid_epoch_ends.py b/tests/base/eval_model_valid_epoch_ends.py deleted file mode 100644 index 73866451023f59..00000000000000 --- a/tests/base/eval_model_valid_epoch_ends.py +++ /dev/null @@ -1,47 +0,0 @@ -from abc import ABC - -import torch - - -class ValidationEpochEndVariations(ABC): - """ - Houses all variations of validation_epoch_end steps - """ - def validation_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs - - Args: - outputs: list of individual outputs of each validation step - """ - # if returned a scalar from validation_step, outputs is a list of tensor scalars - # we return just the average in this case (if we want) - def _mean(res, key): - # recursive mean for multilevel dicts - return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean() - - # return torch.stack(outputs).mean() - val_loss_mean = _mean(outputs, 'val_loss') - val_acc_mean = _mean(outputs, 'val_acc') - for output in outputs: - val_loss = self.get_output_metric(output, 'val_loss') - - # reduce manually when using dp - if self.trainer.use_dp or self.trainer.use_ddp2: - val_loss = torch.mean(val_loss) - val_loss_mean += val_loss - - # reduce manually when using dp - val_acc = self.get_output_metric(output, 'val_acc') - if self.trainer.use_dp or self.trainer.use_ddp2: - val_acc = torch.mean(val_acc) - - val_acc_mean += val_acc - - if outputs: # skip zero divisions - val_loss_mean /= len(outputs) - val_acc_mean /= len(outputs) - - metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - results = {'progress_bar': metrics_dict, 'log': metrics_dict} - return results diff --git a/tests/base/eval_model_valid_steps.py b/tests/base/eval_model_valid_steps.py deleted file mode 100644 index d6c9a847920541..00000000000000 --- a/tests/base/eval_model_valid_steps.py +++ /dev/null @@ -1,101 +0,0 @@ -from abc import ABC -from collections import OrderedDict - -import torch - - -class ValidationStepVariations(ABC): - """ - Houses all variations of validation steps - """ - def validation_step(self, batch, batch_idx, *args, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_val = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - }) - return output - if batch_idx % 2 == 0: - return val_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - 'test_dic': {'val_loss_a': loss_val} - }) - return output - - def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_val = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - }) - return output - if batch_idx % 2 == 0: - return val_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - 'test_dic': {'val_loss_a': loss_val} - }) - return output - if batch_idx % 5 == 0: - output = OrderedDict({ - f'val_loss_{dataloader_idx}': loss_val, - f'val_acc_{dataloader_idx}': val_acc, - }) - return output diff --git a/tests/base/mixins.py b/tests/base/mixins.py deleted file mode 100644 index fcfd93a8a445f9..00000000000000 --- a/tests/base/mixins.py +++ /dev/null @@ -1,718 +0,0 @@ -from collections import OrderedDict - -import torch -from torch import optim - - -class LightValidationStepMixin: - """ - Add val_dataloader and validation_step methods for the case - when val_dataloader returns a single dataloader - """ - - def val_dataloader(self): - return self._dataloader(train=False) - - def validation_step(self, batch, batch_idx, *args, **kwargs): - """Lightning calls this inside the validation loop.""" - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_val = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - }) - return output - if batch_idx % 2 == 0: - return val_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - 'test_dic': {'val_loss_a': loss_val} - }) - return output - - -class LightValidationMixin(LightValidationStepMixin): - """ - Add val_dataloader, validation_step, and validation_end methods for the case - when val_dataloader returns a single dataloader - """ - - def validation_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs - - Args: - outputs: list of individual outputs of each validation step - """ - # if returned a scalar from validation_step, outputs is a list of tensor scalars - # we return just the average in this case (if we want) - # return torch.stack(outputs).mean() - val_loss_mean = 0 - val_acc_mean = 0 - for output in outputs: - val_loss = _get_output_metric(output, 'val_loss') - - # reduce manually when using dp - if self.trainer.use_dp or self.trainer.use_ddp2: - val_loss = torch.mean(val_loss) - val_loss_mean += val_loss - - # reduce manually when using dp - val_acc = _get_output_metric(output, 'val_acc') - if self.trainer.use_dp or self.trainer.use_ddp2: - val_acc = torch.mean(val_acc) - - val_acc_mean += val_acc - - val_loss_mean /= len(outputs) - val_acc_mean /= len(outputs) - - metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - results = {'progress_bar': metrics_dict, 'log': metrics_dict} - return results - - -class LightValidationStepMultipleDataloadersMixin: - """ - Add val_dataloader and validation_step methods for the case - when val_dataloader returns multiple dataloaders - """ - - def val_dataloader(self): - return [self._dataloader(train=False), self._dataloader(train=False)] - - def validation_step(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_val = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - }) - return output - if batch_idx % 2 == 0: - return val_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - 'test_dic': {'val_loss_a': loss_val} - }) - return output - if batch_idx % 5 == 0: - output = OrderedDict({ - f'val_loss_{dataloader_idx}': loss_val, - f'val_acc_{dataloader_idx}': val_acc, - }) - return output - - -class LightValidationMultipleDataloadersMixin(LightValidationStepMultipleDataloadersMixin): - """ - Add val_dataloader, validation_step, and validation_end methods for the case - when val_dataloader returns multiple dataloaders - """ - - def validation_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs - :param outputs: list of individual outputs of each validation step - :return: - """ - # if returned a scalar from validation_step, outputs is a list of tensor scalars - # we return just the average in this case (if we want) - # return torch.stack(outputs).mean() - val_loss_mean = 0 - val_acc_mean = 0 - i = 0 - for dl_output in outputs: - for output in dl_output: - val_loss = output['val_loss'] - - # reduce manually when using dp - if self.trainer.use_dp: - val_loss = torch.mean(val_loss) - val_loss_mean += val_loss - - # reduce manually when using dp - val_acc = output['val_acc'] - if self.trainer.use_dp: - val_acc = torch.mean(val_acc) - - val_acc_mean += val_acc - i += 1 - - val_loss_mean /= i - val_acc_mean /= i - - tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - result = {'progress_bar': tqdm_dict} - return result - - -class LightTrainDataloader: - """Simple train dataloader.""" - - def train_dataloader(self): - return self._dataloader(train=True) - - -class LightValidationDataloader: - """Simple validation dataloader.""" - - def val_dataloader(self): - return self._dataloader(train=False) - - -class LightTestDataloader: - """Simple test dataloader.""" - - def test_dataloader(self): - return self._dataloader(train=False) - - -class CustomInfDataloader: - def __init__(self, dataloader): - self.dataloader = dataloader - self.iter = iter(dataloader) - self.count = 0 - - def __iter__(self): - self.count = 0 - return self - - def __next__(self): - if self.count >= 50: - raise StopIteration - self.count = self.count + 1 - try: - return next(self.iter) - except StopIteration: - self.iter = iter(self.dataloader) - return next(self.iter) - - -class LightInfTrainDataloader: - """Simple test dataloader.""" - - def train_dataloader(self): - return CustomInfDataloader(self._dataloader(train=True)) - - -class LightInfValDataloader: - """Simple test dataloader.""" - - def val_dataloader(self): - return CustomInfDataloader(self._dataloader(train=False)) - - -class LightInfTestDataloader: - """Simple test dataloader.""" - - def test_dataloader(self): - return CustomInfDataloader(self._dataloader(train=False)) - - -class LightZeroLenDataloader: - """ Simple dataloader that has zero length. """ - - def train_dataloader(self): - dataloader = self._dataloader(train=True) - dataloader.dataset.data = dataloader.dataset.data[:0] - dataloader.dataset.targets = dataloader.dataset.targets[:0] - return dataloader - - -class LightEmptyTestStep: - """Empty test step.""" - - def test_step(self, *args, **kwargs): - return dict() - - -class LightTestStepMixin(LightTestDataloader): - """Test step mixin.""" - - def test_step(self, batch, batch_idx, *args, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_test = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - test_acc = torch.tensor(test_acc) - - if self.on_gpu: - test_acc = test_acc.cuda(loss_test.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_test = loss_test.unsqueeze(0) - test_acc = test_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) - return output - if batch_idx % 2 == 0: - return test_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} - }) - return output - - -class LightTestMixin(LightTestStepMixin): - """Ritch test mixin.""" - - def test_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs - :param outputs: list of individual outputs of each validation step - :return: - """ - # if returned a scalar from test_step, outputs is a list of tensor scalars - # we return just the average in this case (if we want) - # return torch.stack(outputs).mean() - test_loss_mean = 0 - test_acc_mean = 0 - for output in outputs: - test_loss = _get_output_metric(output, 'test_loss') - - # reduce manually when using dp - if self.trainer.use_dp: - test_loss = torch.mean(test_loss) - test_loss_mean += test_loss - - # reduce manually when using dp - test_acc = _get_output_metric(output, 'test_acc') - if self.trainer.use_dp: - test_acc = torch.mean(test_acc) - - test_acc_mean += test_acc - - test_loss_mean /= len(outputs) - test_acc_mean /= len(outputs) - - metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} - result = {'progress_bar': metrics_dict, 'log': metrics_dict} - return result - - -class LightTestStepMultipleDataloadersMixin: - """Test step multiple dataloaders mixin.""" - - def test_dataloader(self): - return [self._dataloader(train=False), self._dataloader(train=False)] - - def test_step(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_test = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - test_acc = torch.tensor(test_acc) - - if self.on_gpu: - test_acc = test_acc.cuda(loss_test.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_test = loss_test.unsqueeze(0) - test_acc = test_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) - return output - if batch_idx % 2 == 0: - return test_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} - }) - return output - if batch_idx % 5 == 0: - output = OrderedDict({ - f'test_loss_{dataloader_idx}': loss_test, - f'test_acc_{dataloader_idx}': test_acc, - }) - return output - - -class LightTestFitSingleTestDataloadersMixin: - """Test fit single test dataloaders mixin.""" - - def test_dataloader(self): - return self._dataloader(train=False) - - def test_step(self, batch, batch_idx, *args, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_test = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - test_acc = torch.tensor(test_acc) - - if self.on_gpu: - test_acc = test_acc.cuda(loss_test.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_test = loss_test.unsqueeze(0) - test_acc = test_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) - return output - if batch_idx % 2 == 0: - return test_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} - }) - return output - - -class LightTestFitMultipleTestDataloadersMixin: - """Test fit multiple test dataloaders mixin.""" - - def test_step(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_test = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - test_acc = torch.tensor(test_acc) - - if self.on_gpu: - test_acc = test_acc.cuda(loss_test.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_test = loss_test.unsqueeze(0) - test_acc = test_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) - return output - if batch_idx % 2 == 0: - return test_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} - }) - return output - if batch_idx % 5 == 0: - output = OrderedDict({ - f'test_loss_{dataloader_idx}': loss_test, - f'test_acc_{dataloader_idx}': test_acc, - }) - return output - - -class LightValStepFitSingleDataloaderMixin: - - def validation_step(self, batch, batch_idx, *args, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_val = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - }) - return output - if batch_idx % 2 == 0: - return val_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - 'test_dic': {'val_loss_a': loss_val} - }) - return output - - -class LightValStepFitMultipleDataloadersMixin: - - def validation_step(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Lightning calls this inside the validation loop - :param batch: - :return: - """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self(x) - - loss_val = self.loss(y, y_hat) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) - - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - - # alternate possible outputs to test - if batch_idx % 1 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - }) - return output - if batch_idx % 2 == 0: - return val_acc - - if batch_idx % 3 == 0: - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc': val_acc, - 'test_dic': {'val_loss_a': loss_val} - }) - return output - if batch_idx % 5 == 0: - output = OrderedDict({ - f'val_loss_{dataloader_idx}': loss_val, - f'val_acc_{dataloader_idx}': val_acc, - }) - return output - - -class LightTestMultipleDataloadersMixin(LightTestStepMultipleDataloadersMixin): - - def test_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs - :param outputs: list of individual outputs of each validation step - :return: - """ - # if returned a scalar from test_step, outputs is a list of tensor scalars - # we return just the average in this case (if we want) - # return torch.stack(outputs).mean() - test_loss_mean = 0 - test_acc_mean = 0 - i = 0 - for dl_output in outputs: - for output in dl_output: - test_loss = output['test_loss'] - - # reduce manually when using dp - if self.trainer.use_dp: - test_loss = torch.mean(test_loss) - test_loss_mean += test_loss - - # reduce manually when using dp - test_acc = output['test_acc'] - if self.trainer.use_dp: - test_acc = torch.mean(test_acc) - - test_acc_mean += test_acc - i += 1 - - test_loss_mean /= i - test_acc_mean /= i - - tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} - result = {'progress_bar': tqdm_dict} - return result - - -class LightTestOptimizerWithSchedulingMixin: - def configure_optimizers(self): - if self.hparams.optimizer_name == 'lbfgs': - optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - else: - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) - return [optimizer], [lr_scheduler] - - -class LightTestMultipleOptimizersWithSchedulingMixin: - def configure_optimizers(self): - if self.hparams.optimizer_name == 'lbfgs': - optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - else: - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) - lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - - return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] - - -class LightTestOptimizersWithMixedSchedulingMixin: - def configure_optimizers(self): - if self.hparams.optimizer_name == 'lbfgs': - optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - else: - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1) - lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - - return [optimizer1, optimizer2], \ - [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2] - - -class LightTestReduceLROnPlateauMixin: - def configure_optimizers(self): - if self.hparams.optimizer_name == 'lbfgs': - optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - else: - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) - return [optimizer], [lr_scheduler] - - -class LightTestNoneOptimizerMixin: - def configure_optimizers(self): - return None - - -def _get_output_metric(output, name): - if isinstance(output, dict): - val = output[name] - else: # if it is 2level deep -> per dataloader and per batch - val = sum(out[name] for out in output) / len(output) - return val diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py new file mode 100644 index 00000000000000..39e67748f0a900 --- /dev/null +++ b/tests/base/model_optimizers.py @@ -0,0 +1,80 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC + +from torch import optim + + +class ConfigureOptimizersPool(ABC): + + def configure_optimizers(self): + """ + return whatever optimizers we want here. + :return: list of optimizers + """ + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + def configure_optimizers__empty(self): + return None + + def configure_optimizers__lbfgs(self): + """ + return whatever optimizers we want here. + :return: list of optimizers + """ + optimizer = optim.LBFGS(self.parameters(), lr=self.learning_rate) + return optimizer + + def configure_optimizers__adagrad(self): + optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate) + return optimizer + + def configure_optimizers__multiple_optimizers_frequency(self): + optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) + optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) + return [ + dict(optimizer=optimizer1, frequency=1), + dict(optimizer=optimizer2, frequency=5), + ] + + def configure_optimizers__single_scheduler(self): + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) + return [optimizer], [lr_scheduler] + + def configure_optimizers__multiple_schedulers(self): + optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) + optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + def configure_optimizers__param_groups(self): + param_groups = [{ + 'params': list(self.parameters())[:2], + 'lr': self.learning_rate * 0.1 + }, { + 'params': list(self.parameters())[2:], + 'lr': self.learning_rate + }] + + optimizer = optim.Adam(param_groups) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) + return [optimizer], [lr_scheduler] + + def configure_optimizers__lr_from_hparams(self): + optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + return optimizer diff --git a/tests/base/model_template.py b/tests/base/model_template.py new file mode 100644 index 00000000000000..86578fef4c699f --- /dev/null +++ b/tests/base/model_template.py @@ -0,0 +1,183 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Generic, TypeVar + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS +from tests.base.model_optimizers import ConfigureOptimizersPool +from tests.base.model_test_dataloaders import TestDataloaderVariations +from tests.base.model_test_epoch_ends import TestEpochEndVariations +from tests.base.model_test_steps import TestStepVariations +from tests.base.model_train_dataloaders import TrainDataloaderVariations +from tests.base.model_train_steps import TrainingStepVariations +from tests.base.model_utilities import ModelTemplateData, ModelTemplateUtils +from tests.base.model_valid_dataloaders import ValDataloaderVariations +from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations +from tests.base.model_valid_steps import ValidationStepVariations +from tests.helpers.datasets import TrialMNIST + + +class EvalModelTemplate( + ModelTemplateData, + ModelTemplateUtils, + TrainingStepVariations, + ValidationStepVariations, + ValidationEpochEndVariations, + TestStepVariations, + TestEpochEndVariations, + TrainDataloaderVariations, + ValDataloaderVariations, + TestDataloaderVariations, + ConfigureOptimizersPool, + LightningModule, +): + """ + This template houses all combinations of model configurations we want to test + + >>> model = EvalModelTemplate() + """ + + def __init__( + self, + drop_prob: float = 0.2, + batch_size: int = 32, + in_features: int = 28 * 28, + learning_rate: float = 0.001 * 8, + optimizer_name: str = 'adam', + data_root: str = PATH_DATASETS, + out_features: int = 10, + hidden_dim: int = 1000, + b1: float = 0.5, + b2: float = 0.999, + ): + # init superclass + super().__init__() + self.save_hyperparameters() + + self.drop_prob = drop_prob + self.batch_size = batch_size + self.in_features = in_features + self.learning_rate = learning_rate + self.optimizer_name = optimizer_name + self.data_root = data_root + self.out_features = out_features + self.hidden_dim = hidden_dim + self.b1 = b1 + self.b2 = b2 + self.training_step_called = False + self.training_step_end_called = False + self.training_epoch_end_called = False + self.validation_step_called = False + self.validation_step_end_called = False + self.validation_epoch_end_called = False + self.test_step_called = False + self.test_step_end_called = False + self.test_epoch_end_called = False + + self.example_input_array = torch.rand(5, 28 * 28) + + # build model + self.__build_model() + + def __build_model(self): + """ + Simple model for testing + :return: + """ + self.c_d1 = nn.Linear(in_features=self.in_features, out_features=self.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim) + self.c_d1_drop = nn.Dropout(self.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x): + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + + x = self.c_d2(x) + logits = F.softmax(x, dim=1) + + return logits + + def loss(self, labels, logits): + nll = F.nll_loss(logits, labels) + return nll + + def prepare_data(self): + TrialMNIST(root=self.data_root, train=True, download=True) + + @staticmethod + def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict: + args = dict( + drop_prob=0.2, + batch_size=32, + in_features=28 * 28, + learning_rate=0.001 * 8, + optimizer_name='adam', + data_root=PATH_DATASETS, + out_features=10, + hidden_dim=1000, + b1=0.5, + b2=0.999, + ) + + if continue_training: + args.update( + test_tube_do_checkpoint_load=True, + hpc_exp_number=hpc_exp_number, + ) + + return args + + +T = TypeVar('T') + + +class GenericParentEvalModelTemplate(Generic[T], EvalModelTemplate): + + def __init__( + self, + drop_prob: float, + batch_size: int, + in_features: int, + learning_rate: float, + optimizer_name: str, + data_root: str, + out_features: int, + hidden_dim: int, + b1: float, + b2: float, + ): + super().__init__( + drop_prob=drop_prob, + batch_size=batch_size, + in_features=in_features, + learning_rate=learning_rate, + optimizer_name=optimizer_name, + data_root=data_root, + out_features=out_features, + hidden_dim=hidden_dim, + b1=b1, + b2=b2, + ) + + +class GenericEvalModelTemplate(GenericParentEvalModelTemplate[int]): + pass diff --git a/tests/base/model_test_dataloaders.py b/tests/base/model_test_dataloaders.py new file mode 100644 index 00000000000000..a22d46f35933e2 --- /dev/null +++ b/tests/base/model_test_dataloaders.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader + + +class TestDataloaderVariations(ABC): + + @abstractmethod + def dataloader(self, *args, **kwargs): + """placeholder""" + + def test_dataloader(self): + return self.dataloader(train=False) + + def test_dataloader__infinite(self): + return CustomInfDataloader(self.dataloader(train=False)) + + def test_dataloader__not_implemented_error(self): + return CustomNotImplementedErrorDataloader(self.dataloader(train=False)) + + def test_dataloader__multiple_mixed_length(self): + lengths = [50, 30, 40] + dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths] + return dataloaders diff --git a/tests/base/eval_model_test_epoch_ends.py b/tests/base/model_test_epoch_ends.py similarity index 63% rename from tests/base/eval_model_test_epoch_ends.py rename to tests/base/model_test_epoch_ends.py index fa3c3f7f4a90e0..90084298b31879 100644 --- a/tests/base/eval_model_test_epoch_ends.py +++ b/tests/base/model_test_epoch_ends.py @@ -1,13 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC import torch +from pytorch_lightning.utilities import DistributedType + class TestEpochEndVariations(ABC): def test_epoch_end(self, outputs): """ - Called at the end of validation to aggregate outputs + Called at the end of test epoch to aggregate outputs :param outputs: list of individual outputs of each validation step :return: """ @@ -20,13 +35,13 @@ def test_epoch_end(self, outputs): test_loss = self.get_output_metric(output, 'test_loss') # reduce manually when using dp - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = self.get_output_metric(output, 'test_acc') - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc @@ -34,13 +49,13 @@ def test_epoch_end(self, outputs): test_loss_mean /= len(outputs) test_acc_mean /= len(outputs) - metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} + metrics_dict = {'test_loss': test_loss_mean, 'test_acc': test_acc_mean} result = {'progress_bar': metrics_dict, 'log': metrics_dict} return result def test_epoch_end__multiple_dataloaders(self, outputs): """ - Called at the end of validation to aggregate outputs + Called at the end of test epoch to aggregate outputs :param outputs: list of individual outputs of each validation step :return: """ @@ -55,13 +70,13 @@ def test_epoch_end__multiple_dataloaders(self, outputs): test_loss = output['test_loss'] # reduce manually when using dp - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = output['test_acc'] - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc @@ -70,6 +85,6 @@ def test_epoch_end__multiple_dataloaders(self, outputs): test_loss_mean /= i test_acc_mean /= i - tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} + tqdm_dict = {'test_loss': test_loss_mean, 'test_acc': test_acc_mean} result = {'progress_bar': tqdm_dict} return result diff --git a/tests/base/eval_model_test_steps.py b/tests/base/model_test_steps.py similarity index 65% rename from tests/base/eval_model_test_steps.py rename to tests/base/model_test_steps.py index bf57c2815bc89c..0b81143ee57f25 100644 --- a/tests/base/eval_model_test_steps.py +++ b/tests/base/model_test_steps.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC from collections import OrderedDict @@ -15,6 +28,8 @@ def test_step(self, batch, batch_idx, *args, **kwargs): :param batch: :return: """ + self.test_step_called = True + x, y = batch x = x.view(x.size(0), -1) y_hat = self(x) @@ -30,10 +45,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs): # alternate possible outputs to test if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) + output = OrderedDict({'test_loss': loss_test, 'test_acc': test_acc}) return output if batch_idx % 2 == 0: return test_acc @@ -42,7 +54,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs): output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} + 'test_dic': dict(test_loss_a=loss_test), }) return output @@ -67,10 +79,7 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw # alternate possible outputs to test if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) + output = OrderedDict({'test_loss': loss_test, 'test_acc': test_acc}) return output if batch_idx % 2 == 0: return test_acc @@ -79,15 +88,9 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} + 'test_dic': dict(test_loss_a=loss_test), }) return output if batch_idx % 5 == 0: - output = OrderedDict({ - f'test_loss_{dataloader_idx}': loss_test, - f'test_acc_{dataloader_idx}': test_acc, - }) + output = OrderedDict({f'test_loss_{dataloader_idx}': loss_test, f'test_acc_{dataloader_idx}': test_acc}) return output - - def test_step__empty(self, batch, batch_idx, *args, **kwargs): - return {} diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py new file mode 100644 index 00000000000000..50c85ddc3f79d1 --- /dev/null +++ b/tests/base/model_train_dataloaders.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader + + +class TrainDataloaderVariations(ABC): + + @abstractmethod + def dataloader(self, train: bool, *args, **kwargs): + """placeholder""" + + def train_dataloader(self): + return self.dataloader(train=True) + + def train_dataloader__infinite(self): + return CustomInfDataloader(self.dataloader(train=True)) + + def train_dataloader__not_implemented_error(self): + return CustomNotImplementedErrorDataloader(self.dataloader(train=True)) + + def train_dataloader__zero_length(self): + dataloader = self.dataloader(train=True) + dataloader.dataset.data = dataloader.dataset.data[:0] + dataloader.dataset.targets = dataloader.dataset.targets[:0] + return dataloader + + def train_dataloader__multiple_mapping(self): + """Return a mapping loaders with different lengths""" + return { + 'a': self.dataloader(train=True, num_samples=100), + 'b': self.dataloader(train=True, num_samples=50), + } diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py new file mode 100644 index 00000000000000..2a4161a23e0538 --- /dev/null +++ b/tests/base/model_train_steps.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from abc import ABC +from collections import OrderedDict + +import torch + + +class TrainingStepVariations(ABC): + """ + Houses all variations of training steps + """ + + test_step_inf_loss = float('inf') + + def training_step(self, batch, batch_idx, optimizer_idx=None): + """Lightning calls this inside the training loop""" + self.training_step_called = True + + # forward pass + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x) + + # calculate loss + loss_train = self.loss(y, y_hat) + log_train = loss_train + + # alternate between tensors and scalars for "log" and "progress_bar" + if batch_idx % 2 == 0: + log_train = log_train.item() + + output = OrderedDict({ + 'loss': loss_train, + 'progress_bar': dict(some_val=log_train * log_train), + 'log': dict(train_some_val=log_train * log_train), + }) + return output + + def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None): + output = self.training_step(batch, batch_idx, optimizer_idx) + if batch_idx == self.test_step_inf_loss: + if isinstance(output, dict): + output['loss'] *= torch.tensor(math.inf) # make loss infinite + else: + output /= 0 + return output + + def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): + """Training step for multiple train loaders""" + + assert isinstance(batch, dict) + assert len(batch) == 2 + assert 'a' in batch and 'b' in batch + + # forward pass + x, y = batch['a'] + x = x.view(x.size(0), -1) + y_hat = self(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + log_val = loss_val + + # alternate between tensors and scalars for "log" and "progress_bar" + if batch_idx % 2 == 0: + log_val = log_val.item() + + output = OrderedDict({ + 'loss': loss_val, + 'progress_bar': { + 'some_val': log_val * log_val + }, + 'log': { + 'train_some_val': log_val * log_val + }, + }) + return output diff --git a/tests/base/model_utilities.py b/tests/base/model_utilities.py new file mode 100644 index 00000000000000..6c5da43b0611ef --- /dev/null +++ b/tests/base/model_utilities.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torch.utils.data import DataLoader + +from tests.helpers.datasets import TrialMNIST + + +class ModelTemplateData: + + def dataloader(self, train: bool, num_samples: int = 100): + dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True) + + loader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + num_workers=0, + shuffle=train, + ) + return loader + + +class ModelTemplateUtils: + + def get_output_metric(self, output, name): + if isinstance(output, dict): + val = output[name] + else: # if it is 2level deep -> per dataloader and per batch + val = sum(out[name] for out in output) / len(output) + return val diff --git a/tests/base/model_valid_dataloaders.py b/tests/base/model_valid_dataloaders.py new file mode 100644 index 00000000000000..ab91b25ba02a6c --- /dev/null +++ b/tests/base/model_valid_dataloaders.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from tests.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader + + +class ValDataloaderVariations(ABC): + + @abstractmethod + def dataloader(self, *args, **kwargs): + """placeholder""" + + def val_dataloader(self): + return self.dataloader(train=False) + + def val_dataloader__multiple_mixed_length(self): + lengths = [100, 30] + dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths] + return dataloaders + + def val_dataloader__multiple(self): + return [ + self.dataloader(train=False), + self.dataloader(train=False), + ] + + def val_dataloader__infinite(self): + return CustomInfDataloader(self.dataloader(train=False)) + + def val_dataloader__not_implemented_error(self): + return CustomNotImplementedErrorDataloader(self.dataloader(train=False)) diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py new file mode 100644 index 00000000000000..7b83670acacef3 --- /dev/null +++ b/tests/base/model_valid_epoch_ends.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC + +import torch + + +class ValidationEpochEndVariations(ABC): + """ + Houses all variations of validation_epoch_end steps + """ + + def validation_epoch_end(self, outputs): + """ + Called at the end of validation to aggregate outputs + + Args: + outputs: list of individual outputs of each validation step + """ + + # if returned a scalar from validation_step, outputs is a list of tensor scalars + # we return just the average in this case (if we want) + def _mean(res, key): + # recursive mean for multilevel dicts + return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean() + + val_loss_mean = _mean(outputs, 'val_loss') + val_acc_mean = _mean(outputs, 'val_acc') + + # alternate between tensor and scalar + if self.current_epoch % 2 == 0: + val_loss_mean = val_loss_mean.item() + val_acc_mean = val_acc_mean.item() + + self.log('early_stop_on', val_loss_mean, prog_bar=True) + self.log('val_acc', val_acc_mean, prog_bar=True) + + def validation_epoch_end__multiple_dataloaders(self, outputs): + """ + Called at the end of validation to aggregate outputs + + Args: + outputs: list of individual outputs of each validation step + """ + + # if returned a scalar from validation_step, outputs is a list of tensor scalars + # we return just the average in this case (if we want) + def _mean(res, key): + return torch.stack([x[key] for x in res]).mean() + + pbar = {} + logs = {} + for dl_output_list in outputs: + output_keys = dl_output_list[0].keys() + output_keys = [x for x in output_keys if 'val_' in x] + for key in output_keys: + metric_out = _mean(dl_output_list, key) + pbar[key] = metric_out + logs[key] = metric_out + + results = { + 'val_loss': torch.stack([v for k, v in pbar.items() if k.startswith('val_loss')]).mean(), + 'progress_bar': pbar, + 'log': logs, + } + return results diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py new file mode 100644 index 00000000000000..554d76253e4db9 --- /dev/null +++ b/tests/base/model_valid_steps.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC +from collections import OrderedDict + +import torch + + +class ValidationStepVariations(ABC): + """ + Houses all variations of validation steps + """ + + def validation_step(self, batch, batch_idx, *args, **kwargs): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + self.validation_step_called = True + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc).type_as(x) + + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 'test_dic': dict(val_loss_a=loss_val), + }) + return output + + def validation_step__dp(self, batch, batch_idx, *args, **kwargs): + self.validation_step_called = True + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x.to(self.device)) + + y = y.to(y_hat.device) + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc).type_as(x) + + self.log('val_loss', loss_val) + self.log('val_acc', val_acc) + return loss_val + + def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc).type_as(x) + + output = OrderedDict({ + f'val_loss_{dataloader_idx}': loss_val, + f'val_acc_{dataloader_idx}': val_acc, + }) + return output diff --git a/tests/base/models.py b/tests/base/models.py deleted file mode 100644 index 4d39c5150b0352..00000000000000 --- a/tests/base/models.py +++ /dev/null @@ -1,293 +0,0 @@ -from collections import OrderedDict -from typing import Dict - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import optim -from torch.utils.data import DataLoader - -from tests.base import EvalModelTemplate -from tests.base.datasets import TrialMNIST - -try: - from test_tube import HyperOptArgumentParser -except ImportError: - # TODO: this should be discussed and moved out of this package - raise ImportError('Missing test-tube package.') - -from pytorch_lightning.core.lightning import LightningModule - - -class DictHparamsModel(LightningModule): - - def __init__(self, hparams: Dict): - super().__init__() - self.hparams = hparams - self.l1 = torch.nn.Linear(hparams.get('in_features'), hparams['out_features']) - - def forward(self, x): - return torch.relu(self.l1(x.view(x.size(0), -1))) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - return {'loss': F.cross_entropy(y_hat, y)} - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.02) - - def train_dataloader(self): - return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) - - -class TestModelBase(LightningModule): - """Base LightningModule for testing. Implements only the required interface.""" - - def __init__(self, hparams, force_remove_distributed_sampler: bool = False): - """Pass in parsed HyperOptArgumentParser to the model.""" - # init superclass - super().__init__() - self.hparams = hparams - - self.batch_size = hparams.batch_size - - # if you specify an example input, the summary will show input/output for each layer - self.example_input_array = torch.rand(5, 28 * 28) - - # remove to test warning for dist sampler - self.force_remove_distributed_sampler = force_remove_distributed_sampler - - # build model - self.__build_model() - - # --------------------- - # MODEL SETUP - # --------------------- - def __build_model(self): - """Layout model.""" - self.c_d1 = nn.Linear(in_features=self.hparams.in_features, - out_features=self.hparams.hidden_dim) - self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) - self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) - - self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, - out_features=self.hparams.out_features) - - # --------------------- - # TRAINING - # --------------------- - def forward(self, x): - """No special modification required for lightning, define as you normally would.""" - x = self.c_d1(x) - x = torch.tanh(x) - x = self.c_d1_bn(x) - x = self.c_d1_drop(x) - - x = self.c_d2(x) - logits = F.log_softmax(x, dim=1) - - return logits - - def loss(self, labels, logits): - nll = F.nll_loss(logits, labels) - return nll - - def training_step(self, batch, batch_idx, optimizer_idx=None): - """Lightning calls this inside the training loop""" - # forward pass - x, y = batch - x = x.view(x.size(0), -1) - - y_hat = self(x) - - # calculate loss - loss_val = self.loss(y, y_hat) - - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp: - loss_val = loss_val.unsqueeze(0) - - # alternate possible outputs to test - if self.trainer.batch_idx % 1 == 0: - output = OrderedDict({ - 'loss': loss_val, - 'progress_bar': {'some_val': loss_val * loss_val}, - 'log': {'train_some_val': loss_val * loss_val}, - }) - - return output - if self.trainer.batch_idx % 2 == 0: - return loss_val - - # --------------------- - # TRAINING SETUP - # --------------------- - def configure_optimizers(self): - """ - return whatever optimizers we want here. - :return: list of optimizers - """ - # try no scheduler for this model (testing purposes) - if self.hparams.optimizer_name == 'lbfgs': - optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) - else: - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) - return [optimizer], [scheduler] - - def prepare_data(self): - _ = TrialMNIST(root=self.hparams.data_root, train=True, download=True) - - def _dataloader(self, train): - # init data generators - dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True) - - # when using multi-node we need to add the datasampler - batch_size = self.hparams.batch_size - - loader = DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=train - ) - - return loader - - -class Generator(nn.Module): - def __init__(self, latent_dim, img_shape): - super().__init__() - self.img_shape = img_shape - - def block(in_feat, out_feat, normalize=True): - layers = [nn.Linear(in_feat, out_feat)] - if normalize: - layers.append(nn.BatchNorm1d(out_feat, 0.8)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - return layers - - self.model = nn.Sequential( - *block(latent_dim, 128, normalize=False), - *block(128, 256), - *block(256, 512), - *block(512, 1024), - nn.Linear(1024, int(np.prod(img_shape))), - nn.Tanh() - ) - - def forward(self, z): - img = self.model(z) - img = img.view(img.size(0), *self.img_shape) - return img - - -class Discriminator(nn.Module): - def __init__(self, img_shape): - super().__init__() - - self.model = nn.Sequential( - nn.Linear(int(np.prod(img_shape)), 512), - nn.LeakyReLU(0.2, inplace=True), - nn.Linear(512, 256), - nn.LeakyReLU(0.2, inplace=True), - nn.Linear(256, 1), - nn.Sigmoid(), - ) - - def forward(self, img): - img_flat = img.view(img.size(0), -1) - validity = self.model(img_flat) - - return validity - - -class TestGAN(LightningModule): - """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" - - def __init__(self, hparams): - super().__init__() - self.hparams = hparams - - # networks - mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=hparams.hidden_dim, img_shape=mnist_shape) - self.discriminator = Discriminator(img_shape=mnist_shape) - - # cache for generated images - self.generated_imgs = None - self.last_imgs = None - - def forward(self, z): - return self.generator(z) - - def adversarial_loss(self, y_hat, y): - return F.binary_cross_entropy(y_hat, y) - - def training_step(self, batch, batch_idx, optimizer_idx=None): - imgs, _ = batch - self.last_imgs = imgs - - # train generator - if optimizer_idx == 0: - # sample noise - z = torch.randn(imgs.shape[0], self.hparams.hidden_dim) - z = z.type_as(imgs) - - # generate images - self.generated_imgs = self(z) - - # ground truth result (ie: all fake) - # put on GPU because we created this tensor inside training_loop - valid = torch.ones(imgs.size(0), 1) - valid = valid.type_as(imgs) - - # adversarial loss is binary cross-entropy - g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) - tqdm_dict = {'g_loss': g_loss} - output = OrderedDict({ - 'loss': g_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output - - # train discriminator - if optimizer_idx == 1: - # Measure discriminator's ability to classify real from generated samples - - # how well can it label as real? - valid = torch.ones(imgs.size(0), 1) - valid = valid.type_as(imgs) - - real_loss = self.adversarial_loss(self.discriminator(imgs), valid) - - # how well can it label as fake? - fake = torch.zeros(imgs.size(0), 1) - fake = fake.type_as(fake) - - fake_loss = self.adversarial_loss( - self.discriminator(self.generated_imgs.detach()), fake) - - # discriminator loss is the average of these - d_loss = (real_loss + fake_loss) / 2 - tqdm_dict = {'d_loss': d_loss} - output = OrderedDict({ - 'loss': d_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output - - def configure_optimizers(self): - lr = self.hparams.learning_rate - b1 = self.hparams.b1 - b2 = self.hparams.b2 - - opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) - opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) - return [opt_g, opt_d], [] - - def train_dataloader(self): - return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) diff --git a/tests/base/utils.py b/tests/base/utils.py deleted file mode 100644 index f27d0bbdcb39c9..00000000000000 --- a/tests/base/utils.py +++ /dev/null @@ -1,231 +0,0 @@ -import os -from argparse import Namespace - -import numpy as np -import torch - -# from pl_examples import LightningTemplateModel -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger -from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS -from tests.base import LightningTestModel, EvalModelTemplate -from tests.base.datasets import PATH_DATASETS - - -def assert_speed_parity(pl_times, pt_times, num_epochs): - - # assert speeds - max_diff_per_epoch = 0.65 - pl_times = np.asarray(pl_times) - pt_times = np.asarray(pt_times) - diffs = pl_times - pt_times - diffs = diffs / num_epochs - - assert np.alltrue(diffs < max_diff_per_epoch), \ - f"lightning was slower than PT (threshold {max_diff_per_epoch})" - - -def run_model_test_without_loggers(trainer_options, model, min_acc=0.50): - reset_seed() - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - # correct result and ok accuracy - assert result == 1, 'amp + ddp model failed to complete' - - # test model loading - pretrained_model = load_model(trainer.logger, - trainer.checkpoint_callback.dirpath, - path_expt=trainer_options.get('default_root_dir')) - - # test new model accuracy - test_loaders = model.test_dataloader() - if not isinstance(test_loaders, list): - test_loaders = [test_loaders] - - for dataloader in test_loaders: - run_prediction(dataloader, pretrained_model, min_acc=min_acc) - - if trainer.use_ddp: - # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() - - -def run_model_test(trainer_options, model, on_gpu=True, version=None, with_hpc=True): - reset_seed() - save_dir = trainer_options['default_root_dir'] - - # logger file to get meta - logger = get_default_logger(save_dir, version=version) - trainer_options.update(logger=logger) - - if 'checkpoint_callback' not in trainer_options: - # logger file to get weights - checkpoint = init_checkpoint_callback(logger) - trainer_options.update(checkpoint_callback=checkpoint) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - # correct result and ok accuracy - assert result == 1, 'amp + ddp model failed to complete' - - # test model loading - pretrained_model = load_model(logger, trainer.checkpoint_callback.dirpath) - - # test new model accuracy - test_loaders = model.test_dataloader() - if not isinstance(test_loaders, list): - test_loaders = [test_loaders] - - [run_prediction(dataloader, pretrained_model) for dataloader in test_loaders] - - if with_hpc: - if trainer.use_ddp or trainer.use_ddp2: - # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ - trainer.init_optimizers(pretrained_model) - - # test HPC loading / saving - trainer.hpc_save(save_dir, logger) - trainer.hpc_load(save_dir, on_gpu=on_gpu) - - -def get_default_hparams(continue_training=False, hpc_exp_number=0): - args = { - 'drop_prob': 0.2, - 'batch_size': 32, - 'in_features': 28 * 28, - 'learning_rate': 0.001 * 8, - 'optimizer_name': 'adam', - 'data_root': PATH_DATASETS, - 'out_features': 10, - 'hidden_dim': 1000, - 'b1': 0.5, - 'b2': 0.999, - } - - if continue_training: - args.update( - test_tube_do_checkpoint_load=True, - hpc_exp_number=hpc_exp_number, - ) - - hparams = Namespace(**args) - return hparams - - -def get_default_logger(save_dir, version=None): - # set up logger object without actually saving logs - logger = TensorBoardLogger(save_dir, name='lightning_logs', version=version) - return logger - - -def get_data_path(expt_logger, path_dir=None): - # some calls contain only experiment not complete logger - expt = expt_logger.experiment if hasattr(expt_logger, 'experiment') else expt_logger - # each logger has to have these attributes - name, version = expt_logger.name, expt_logger.version - # only the test-tube experiment has such attribute - if hasattr(expt, 'get_data_path'): - return expt.get_data_path(name, version) - # the other experiments... - if not path_dir: - if hasattr(expt_logger, 'save_dir') and expt_logger.save_dir: - path_dir = expt_logger.save_dir - else: - path_dir = TEMP_PATH - path_expt = os.path.join(path_dir, name, 'version_%s' % version) - # try if the new sub-folder exists, typical case for test-tube - if not os.path.isdir(path_expt): - path_expt = path_dir - return path_expt - - -def load_model(logger, root_weights_dir, module_class=LightningTestModel, path_expt=None): - # load trained model - path_expt_dir = get_data_path(logger, path_dir=path_expt) - tags_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_CSV_TAGS) - - checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x] - weights_dir = os.path.join(root_weights_dir, checkpoints[0]) - - trained_model = module_class.load_from_checkpoint( - checkpoint_path=weights_dir, - tags_csv=tags_path - ) - - assert trained_model is not None, 'loading model failed' - - return trained_model - - -def load_model_from_checkpoint(root_weights_dir, module_class=LightningTestModel): - # load trained model - checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x] - weights_dir = os.path.join(root_weights_dir, checkpoints[0]) - - trained_model = module_class.load_from_checkpoint( - checkpoint_path=weights_dir, - ) - - assert trained_model is not None, 'loading model failed' - - return trained_model - - -def run_prediction(dataloader, trained_model, dp=False, min_acc=0.5): - # run prediction on 1 batch - for batch in dataloader: - break - - x, y = batch - x = x.view(x.size(0), -1) - - if dp: - output = trained_model(batch, 0) - acc = output['val_acc'] - acc = torch.mean(acc).item() - - else: - y_hat = trained_model(x) - - # acc - labels_hat = torch.argmax(y_hat, dim=1) - acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - acc = torch.tensor(acc) - acc = acc.item() - - assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})" - - -def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): - # this model should get 0.80+ acc - acc = trainer.progress_bar_dict[key] - assert acc > thr, f"Model failed to get expected {thr} accuracy. {key} = {acc}" - - -def reset_seed(): - seed = RANDOM_SEEDS.pop() - torch.manual_seed(seed) - np.random.seed(seed) - - -def set_random_master_port(): - reset_seed() - port = RANDOM_PORTS.pop() - os.environ['MASTER_PORT'] = str(port) - - -def init_checkpoint_callback(logger, path_dir=None): - exp_path = get_data_path(logger, path_dir=path_dir) - ckpt_dir = os.path.join(exp_path, 'checkpoints') - os.mkdir(ckpt_dir) - checkpoint = ModelCheckpoint(ckpt_dir) - return checkpoint diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py new file mode 100644 index 00000000000000..df0eab31aac378 --- /dev/null +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -0,0 +1,136 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning import Callback, Trainer +from tests.helpers.boring_model import BoringModel + + +@pytest.mark.parametrize("single_cb", [False, True]) +def test_train_step_no_return(tmpdir, single_cb: bool): + """ + Tests that only training_step can be used + """ + + class CB(Callback): + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + d = outputs[0][0] + assert 'minimize' in d + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + assert 'x' in outputs + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + assert 'x' in outputs + + def on_train_epoch_end(self, trainer, pl_module, outputs): + d = outputs[0] + assert len(d) == trainer.num_training_batches + + class TestModel(BoringModel): + + def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: + d = outputs[0][0] + assert 'minimize' in d + + def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: + assert 'x' in outputs + + def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: + assert 'x' in outputs + + def on_train_epoch_end(self, outputs) -> None: + d = outputs[0] + assert len(d) == self.trainer.num_training_batches + + model = TestModel() + + trainer = Trainer( + callbacks=CB() if single_cb else [CB()], + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + assert any(isinstance(c, CB) for c in trainer.callbacks) + + results = trainer.fit(model) + assert results + + +def test_on_val_epoch_end_outputs(tmpdir): + + class CB(Callback): + + def on_validation_epoch_end(self, trainer, pl_module, outputs): + if trainer.running_sanity_check: + assert len(outputs[0]) == trainer.num_sanity_val_batches[0] + else: + assert len(outputs[0]) == trainer.num_val_batches[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_on_test_epoch_end_outputs(tmpdir): + + class CB(Callback): + + def on_test_epoch_end(self, trainer, pl_module, outputs): + assert len(outputs[0]) == trainer.num_test_batches[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + weights_summary=None, + ) + + trainer.test(model) + + +def test_free_memory_on_eval_outputs(tmpdir): + + class CB(Callback): + + def on_epoch_end(self, trainer, pl_module): + assert len(trainer.evaluation_loop.outputs) == 0 + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 884fc82e13e28c..a30b4fe0f609bf 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -1,332 +1,270 @@ -import pytest -import tests.base.utils as tutils -from pytorch_lightning import Callback -from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger -from tests.base import EvalModelTemplate -from pathlib import Path - - -def test_trainer_callback_system(tmpdir): - """Test the callback system.""" - - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - - def _check_args(trainer, pl_module): - assert isinstance(trainer, Trainer) - assert isinstance(pl_module, LightningModule) - - class TestCallback(Callback): - def __init__(self): - super().__init__() - self.on_init_start_called = False - self.on_init_end_called = False - self.on_sanity_check_start_called = False - self.on_sanity_check_end_called = False - self.on_epoch_start_called = False - self.on_epoch_end_called = False - self.on_batch_start_called = False - self.on_batch_end_called = False - self.on_validation_batch_start_called = False - self.on_validation_batch_end_called = False - self.on_test_batch_start_called = False - self.on_test_batch_end_called = False - self.on_train_start_called = False - self.on_train_end_called = False - self.on_validation_start_called = False - self.on_validation_end_called = False - self.on_test_start_called = False - self.on_test_end_called = False - - def on_init_start(self, trainer): - assert isinstance(trainer, Trainer) - self.on_init_start_called = True - - def on_init_end(self, trainer): - assert isinstance(trainer, Trainer) - self.on_init_end_called = True - - def on_sanity_check_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_sanity_check_start_called = True - - def on_sanity_check_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_sanity_check_end_called = True - - def on_epoch_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_epoch_start_called = True - - def on_epoch_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_epoch_end_called = True - - def on_batch_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_batch_start_called = True - - def on_batch_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_batch_end_called = True - - def on_validation_batch_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_validation_batch_start_called = True - - def on_validation_batch_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_validation_batch_end_called = True - - def on_test_batch_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_test_batch_start_called = True - - def on_test_batch_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_test_batch_end_called = True - - def on_train_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_train_start_called = True - - def on_train_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_train_end_called = True - - def on_validation_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_validation_start_called = True - - def on_validation_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_validation_end_called = True - - def on_test_start(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_test_start_called = True - - def on_test_end(self, trainer, pl_module): - _check_args(trainer, pl_module) - self.on_test_end_called = True - - test_callback = TestCallback() - - trainer_options = dict( - callbacks=[test_callback], +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock +from unittest.mock import ANY, call, MagicMock, Mock + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +@mock.patch("torch.save") # need to mock torch.save or we get pickle error +def test_trainer_callback_hook_system_fit(_, tmpdir): + """Test the callback hook system for fit.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2, + limit_val_batches=1, + limit_train_batches=3, progress_bar_refresh_rate=0, ) - assert not test_callback.on_init_start_called - assert not test_callback.on_init_end_called - assert not test_callback.on_sanity_check_start_called - assert not test_callback.on_sanity_check_end_called - assert not test_callback.on_epoch_start_called - assert not test_callback.on_epoch_start_called - assert not test_callback.on_batch_start_called - assert not test_callback.on_batch_end_called - assert not test_callback.on_validation_batch_start_called - assert not test_callback.on_validation_batch_end_called - assert not test_callback.on_test_batch_start_called - assert not test_callback.on_test_batch_end_called - assert not test_callback.on_train_start_called - assert not test_callback.on_train_end_called - assert not test_callback.on_validation_start_called - assert not test_callback.on_validation_end_called - assert not test_callback.on_test_start_called - assert not test_callback.on_test_end_called + # check that only the to calls exists + assert trainer.callbacks[0] == callback_mock + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + ] # fit model - trainer = Trainer(**trainer_options) - - assert trainer.callbacks[0] == test_callback - assert test_callback.on_init_start_called - assert test_callback.on_init_end_called - assert not test_callback.on_sanity_check_start_called - assert not test_callback.on_sanity_check_end_called - assert not test_callback.on_epoch_start_called - assert not test_callback.on_epoch_start_called - assert not test_callback.on_batch_start_called - assert not test_callback.on_batch_end_called - assert not test_callback.on_validation_batch_start_called - assert not test_callback.on_validation_batch_end_called - assert not test_callback.on_test_batch_start_called - assert not test_callback.on_test_batch_end_called - assert not test_callback.on_train_start_called - assert not test_callback.on_train_end_called - assert not test_callback.on_validation_start_called - assert not test_callback.on_validation_end_called - assert not test_callback.on_test_start_called - assert not test_callback.on_test_end_called - trainer.fit(model) - assert test_callback.on_init_start_called - assert test_callback.on_init_end_called - assert test_callback.on_sanity_check_start_called - assert test_callback.on_sanity_check_end_called - assert test_callback.on_epoch_start_called - assert test_callback.on_epoch_start_called - assert test_callback.on_batch_start_called - assert test_callback.on_batch_end_called - assert test_callback.on_validation_batch_start_called - assert test_callback.on_validation_batch_end_called - assert test_callback.on_train_start_called - assert test_callback.on_train_end_called - assert test_callback.on_validation_start_called - assert test_callback.on_validation_end_called - assert not test_callback.on_test_batch_start_called - assert not test_callback.on_test_batch_end_called - assert not test_callback.on_test_start_called - assert not test_callback.on_test_end_called - - test_callback = TestCallback() - trainer_options.update(callbacks=[test_callback]) - trainer = Trainer(**trainer_options) - trainer.test(model) - - assert test_callback.on_test_batch_start_called - assert test_callback.on_test_batch_end_called - assert test_callback.on_test_start_called - assert test_callback.on_test_end_called - assert not test_callback.on_validation_start_called - assert not test_callback.on_validation_end_called - assert not test_callback.on_validation_batch_end_called - assert not test_callback.on_validation_batch_start_called - - -def test_early_stopping_no_val_step(tmpdir): - """Test that early stopping callback falls back to training metrics when no validation defined.""" - - class CurrentModel(EvalModelTemplate): - def training_step(self, *args, **kwargs): - output = super().training_step(*args, **kwargs) - output.update({'my_train_metric': output['loss']}) # could be anything else - return output + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.on_before_accelerator_backend_setup(trainer, model), + call.setup(trainer, model, 'fit'), + call.on_configure_sharded_model(trainer, model), + call.on_fit_start(trainer, model), + call.on_pretrain_routine_start(trainer, model), + call.on_pretrain_routine_end(trainer, model), + call.on_sanity_check_start(trainer, model), + call.on_validation_start(trainer, model), + call.on_epoch_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_epoch_end(trainer, model, ANY), + call.on_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.on_sanity_check_end(trainer, model), + call.on_train_start(trainer, model), + call.on_epoch_start(trainer, model), + call.on_train_epoch_start(trainer, model), + call.on_batch_start(trainer, model), + call.on_train_batch_start(trainer, model, ANY, 0, 0), + call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), + call.on_after_backward(trainer, model), + call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_batch_end(trainer, model), + call.on_batch_start(trainer, model), + call.on_train_batch_start(trainer, model, ANY, 1, 0), + call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), + call.on_after_backward(trainer, model), + call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0), + call.on_batch_end(trainer, model), + call.on_batch_start(trainer, model), + call.on_train_batch_start(trainer, model, ANY, 2, 0), + call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), + call.on_after_backward(trainer, model), + call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0), + call.on_batch_end(trainer, model), + call.on_train_epoch_end(trainer, model, ANY), + call.on_epoch_end(trainer, model), + call.on_validation_start(trainer, model), + call.on_epoch_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_epoch_end(trainer, model, ANY), + call.on_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC + call.on_train_end(trainer, model), + call.on_fit_end(trainer, model), + call.teardown(trainer, model, 'fit'), + ] + + +def test_trainer_callback_hook_system_test(tmpdir): + """Test the callback hook system for test.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_test_batches=2, + progress_bar_refresh_rate=0, + ) - model = CurrentModel(tutils.get_default_hparams()) - model.validation_step = None - model.val_dataloader = None + trainer.test(model) - stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1) + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.on_before_accelerator_backend_setup(trainer, model), + call.setup(trainer, model, 'test'), + call.on_configure_sharded_model(trainer, model), + call.on_test_start(trainer, model), + call.on_epoch_start(trainer, model), + call.on_test_epoch_start(trainer, model), + call.on_test_batch_start(trainer, model, ANY, 0, 0), + call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_test_batch_start(trainer, model, ANY, 1, 0), + call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), + call.on_test_epoch_end(trainer, model, ANY), + call.on_epoch_end(trainer, model), + call.on_test_end(trainer, model), + call.teardown(trainer, model, 'test'), + ] + + +def test_trainer_callback_hook_system_validate(tmpdir): + """Test the callback hook system for validate.""" + + model = BoringModel() + callback_mock = MagicMock() trainer = Trainer( default_root_dir=tmpdir, - early_stop_callback=stopping, - overfit_pct=0.20, - max_epochs=5, + callbacks=[callback_mock], + max_epochs=1, + limit_val_batches=2, + progress_bar_refresh_rate=0, ) - result = trainer.fit(model) - assert result == 1, 'training failed to complete' - assert trainer.current_epoch < trainer.max_epochs + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.on_before_accelerator_backend_setup(trainer, model), + call.setup(trainer, model, 'validate'), + call.on_configure_sharded_model(trainer, model), + call.on_validation_start(trainer, model), + call.on_epoch_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_batch_start(trainer, model, ANY, 1, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), + call.on_validation_epoch_end(trainer, model, ANY), + call.on_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.teardown(trainer, model, 'validate'), + ] -def test_pickling(tmpdir): - import pickle - early_stopping = EarlyStopping() - ckpt = ModelCheckpoint(tmpdir) +# TODO: add callback tests for predict and tune - early_stopping_pickled = pickle.dumps(early_stopping) - ckpt_pickled = pickle.dumps(ckpt) - early_stopping_loaded = pickle.loads(early_stopping_pickled) - ckpt_loaded = pickle.loads(ckpt_pickled) +def test_callbacks_configured_in_model(tmpdir): + """ Test the callback system with callbacks added through the model hook. """ - assert vars(early_stopping) == vars(early_stopping_loaded) - assert vars(ckpt) == vars(ckpt_loaded) + model_callback_mock = Mock() + trainer_callback_mock = Mock() + class TestModel(BoringModel): -@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) -def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): - """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ - tutils.reset_seed() - model = EvalModelTemplate(tutils.get_default_hparams()) + def configure_callbacks(self): + return [model_callback_mock] + + model = TestModel() + trainer_options = dict( + default_root_dir=tmpdir, + checkpoint_callback=False, + fast_dev_run=True, + progress_bar_refresh_rate=0, + ) - checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k) + def assert_expected_calls(_trainer, model_callback, trainer_callback): + # some methods in callbacks configured through model won't get called + uncalled_methods = [ + call.on_init_start(_trainer), + call.on_init_end(_trainer), + ] + for uncalled in uncalled_methods: + assert uncalled not in model_callback.method_calls + + # assert that the rest of calls are the same as for trainer callbacks + expected_calls = [m for m in trainer_callback.method_calls if m not in uncalled_methods] + assert expected_calls + assert model_callback.method_calls == expected_calls + + # .fit() + trainer_options.update(callbacks=[trainer_callback_mock]) + trainer = Trainer(**trainer_options) - trainer = Trainer(default_root_dir=tmpdir, - checkpoint_callback=checkpoint, - overfit_pct=0.20, - max_epochs=5 - ) + assert trainer_callback_mock in trainer.callbacks + assert model_callback_mock not in trainer.callbacks trainer.fit(model) - # These should be different if the dirpath has be overridden - assert trainer.ckpt_path != trainer.default_root_dir + assert model_callback_mock in trainer.callbacks + assert trainer.callbacks[-1] == model_callback_mock + assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) + # .test() + for fn in ("test", "validate"): + model_callback_mock.reset_mock() + trainer_callback_mock.reset_mock() -@pytest.mark.parametrize( - 'logger_version,expected', - [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], -) -def test_model_checkpoint_path(tmpdir, logger_version, expected): - """Test that "version_" prefix is only added when logger's version is an integer""" - tutils.reset_seed() - model = EvalModelTemplate(tutils.get_default_hparams()) - logger = TensorBoardLogger(str(tmpdir), version=logger_version) + trainer_options.update(callbacks=[trainer_callback_mock]) + trainer = Trainer(**trainer_options) - trainer = Trainer( - default_root_dir=tmpdir, - overfit_pct=0.2, - max_epochs=5, - logger=logger - ) - trainer.fit(model) + trainer_fn = getattr(trainer, fn) + trainer_fn(model) - ckpt_version = Path(trainer.ckpt_path).parent.name - assert ckpt_version == expected + assert model_callback_mock in trainer.callbacks + assert trainer.callbacks[-1] == model_callback_mock + assert_expected_calls(trainer, model_callback_mock, trainer_callback_mock) -def test_lr_logger_single_lr(tmpdir): - """ Test that learning rates are extracted and logged for single lr scheduler""" - tutils.reset_seed() +def test_configure_callbacks_hook_multiple_calls(tmpdir): + """ Test that subsequent calls to `configure_callbacks` do not change the callbacks list. """ + model_callback_mock = Mock() - model = EvalModelTemplate(tutils.get_default_hparams()) - model.configure_optimizers = model.configure_optimizers__single_scheduler + class TestModel(BoringModel): - lr_logger = LearningRateLogger() + def configure_callbacks(self): + return [model_callback_mock] + + model = TestModel() trainer = Trainer( default_root_dir=tmpdir, - max_epochs=5, - val_percent_check=0.1, - train_percent_check=0.5, - callbacks=[lr_logger] + fast_dev_run=True, + checkpoint_callback=False, + progress_bar_refresh_rate=1, ) - results = trainer.fit(model) - assert results == 1 - assert lr_logger.lrs, 'No learning rates logged' - assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \ - 'Number of learning rates logged does not match number of lr schedulers' - assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \ - 'Names of learning rates not set correctly' + callbacks_before_fit = trainer.callbacks.copy() + assert callbacks_before_fit + trainer.fit(model) + callbacks_after_fit = trainer.callbacks.copy() + assert callbacks_after_fit == callbacks_before_fit + [model_callback_mock] -def test_lr_logger_multi_lrs(tmpdir): - """ Test that learning rates are extracted and logged for multi lr schedulers """ - tutils.reset_seed() + for fn in ("test", "validate"): + trainer_fn = getattr(trainer, fn) + trainer_fn(model) - model = EvalModelTemplate(tutils.get_default_hparams()) - model.configure_optimizers = model.configure_optimizers__multiple_schedulers + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit - lr_logger = LearningRateLogger() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.5, - callbacks=[lr_logger] - ) - results = trainer.fit(model) - - assert results == 1 - assert lr_logger.lrs, 'No learning rates logged' - assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \ - 'Number of learning rates logged does not match number of lr schedulers' - assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \ - 'Names of learning rates not set correctly' + trainer_fn(ckpt_path=None) + callbacks_after = trainer.callbacks.copy() + assert callbacks_after == callbacks_after_fit diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py new file mode 100644 index 00000000000000..cc619077ee1368 --- /dev/null +++ b/tests/callbacks/test_early_stopping.py @@ -0,0 +1,385 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import pickle +from typing import List, Optional +from unittest import mock + +import cloudpickle +import numpy as np +import pytest +import torch + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel + +_logger = logging.getLogger(__name__) + + +class EarlyStoppingTestRestore(EarlyStopping): + # this class has to be defined outside the test function, otherwise we get pickle error + def __init__(self, expected_state, *args, **kwargs): + super().__init__(*args, **kwargs) + self.expected_state = expected_state + # cache the state for each epoch + self.saved_states = [] + + def on_train_start(self, trainer, pl_module): + if self.expected_state: + assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy()) + + +def test_resume_early_stopping_from_checkpoint(tmpdir): + """ + Prevent regressions to bugs: + https://github.com/PyTorchLightning/pytorch-lightning/issues/1464 + https://github.com/PyTorchLightning/pytorch-lightning/issues/1463 + """ + seed_everything(42) + model = ClassificationModel() + dm = ClassifDataModule() + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1) + early_stop_callback = EarlyStoppingTestRestore(None, monitor='train_loss') + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback, checkpoint_callback], + num_sanity_val_steps=0, + max_epochs=4, + ) + trainer.fit(model, datamodule=dm) + + checkpoint_filepath = checkpoint_callback.kth_best_model_path + # ensure state is persisted properly + checkpoint = torch.load(checkpoint_filepath) + # the checkpoint saves "epoch + 1" + early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] + assert 4 == len(early_stop_callback.saved_states) + assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state + + # ensure state is reloaded properly (assertion in the callback) + early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss') + new_trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + resume_from_checkpoint=checkpoint_filepath, + callbacks=[early_stop_callback], + ) + + with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'): + new_trainer.fit(model) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_early_stopping_no_extraneous_invocations(tmpdir): + """Test to ensure that callback methods aren't being invoked outside of the callback handler.""" + model = ClassificationModel() + dm = ClassifDataModule() + early_stop_callback = EarlyStopping(monitor='train_loss') + expected_count = 4 + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + limit_train_batches=4, + limit_val_batches=4, + max_epochs=expected_count, + ) + trainer.fit(model, datamodule=dm) + + assert trainer.early_stopping_callback == early_stop_callback + assert trainer.early_stopping_callbacks == [early_stop_callback] + assert len(trainer.dev_debugger.early_stopping_history) == expected_count + + +@pytest.mark.parametrize( + "loss_values, patience, expected_stop_epoch", + [ + ([6, 5, 5, 5, 5, 5], 3, 4), + ([6, 5, 4, 4, 3, 3], 1, 3), + ([6, 5, 6, 5, 5, 5], 3, 4), + ], +) +def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expected_stop_epoch: int): + """Test to ensure that early stopping is not triggered before patience is exhausted.""" + + class ModelOverrideValidationReturn(BoringModel): + validation_return_values = torch.Tensor(loss_values) + + def validation_epoch_end(self, outputs): + loss = self.validation_return_values[self.current_epoch] + self.log("test_val_loss", loss) + + model = ModelOverrideValidationReturn() + early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + val_check_interval=1.0, + num_sanity_val_steps=0, + max_epochs=10, + ) + trainer.fit(model) + assert trainer.current_epoch == expected_stop_epoch + + +@pytest.mark.parametrize('validation_step_none', [True, False]) +@pytest.mark.parametrize( + "loss_values, patience, expected_stop_epoch", + [ + ([6, 5, 5, 5, 5, 5], 3, 4), + ([6, 5, 4, 4, 3, 3], 1, 3), + ([6, 5, 6, 5, 5, 5], 3, 4), + ], +) +def test_early_stopping_patience_train( + tmpdir, validation_step_none: bool, loss_values: list, patience: int, expected_stop_epoch: int +): + """Test to ensure that early stopping is not triggered before patience is exhausted.""" + + class ModelOverrideTrainReturn(BoringModel): + train_return_values = torch.Tensor(loss_values) + + def training_epoch_end(self, outputs): + loss = self.train_return_values[self.current_epoch] + self.log('train_loss', loss) + + model = ModelOverrideTrainReturn() + + if validation_step_none: + model.validation_step = None + + early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + num_sanity_val_steps=0, + max_epochs=10, + ) + trainer.fit(model) + assert trainer.current_epoch == expected_stop_epoch + + +def test_pickling(tmpdir): + early_stopping = EarlyStopping() + + early_stopping_pickled = pickle.dumps(early_stopping) + early_stopping_loaded = pickle.loads(early_stopping_pickled) + assert vars(early_stopping) == vars(early_stopping_loaded) + + early_stopping_pickled = cloudpickle.dumps(early_stopping) + early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) + assert vars(early_stopping) == vars(early_stopping_loaded) + + +def test_early_stopping_no_val_step(tmpdir): + """Test that early stopping callback falls back to training metrics when no validation defined.""" + + model = ClassificationModel() + dm = ClassifDataModule() + model.validation_step = None + model.val_dataloader = None + + stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[stopping], + overfit_batches=0.20, + max_epochs=10, + ) + trainer.fit(model, datamodule=dm) + + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.current_epoch < trainer.max_epochs - 1 + + +def test_early_stopping_functionality(tmpdir): + + class CurrentModel(BoringModel): + + def validation_epoch_end(self, outputs): + losses = [8, 4, 2, 3, 4, 5, 8, 10] + val_loss = losses[self.current_epoch] + self.log('abc', val_loss) + + model = CurrentModel() + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[EarlyStopping(monitor='abc')], + overfit_batches=0.20, + max_epochs=20, + ) + trainer.fit(model) + assert trainer.current_epoch == 5, 'early_stopping failed' + + +@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)]) +def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int): + """Excepted Behaviour: + IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop. + + IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` + when `early_stopping` is being triggered, + THEN the trainer should continue until reaching + `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop. + This test validate this expected behaviour + + IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` + when `early_stopping` is being triggered, + THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached. + + Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader) + + This test validate those expected behaviours + """ + + _logger.disabled = True + + original_loss_value = 10 + limit_train_batches = 3 + patience = 3 + + class Model(BoringModel): + + def __init__(self, step_freeze): + super(Model, self).__init__() + + self._step_freeze = step_freeze + + self._loss_value = 10.0 + self._eps = 1e-1 + self._count_decrease = 0 + self._values = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + return {"test_val_loss": self._loss_value} + + def validation_epoch_end(self, outputs): + _mean = np.mean([x['test_val_loss'] for x in outputs]) + if self.trainer.global_step <= self._step_freeze: + self._count_decrease += 1 + self._loss_value -= self._eps + self._values.append(_mean) + self.log('test_val_loss', _mean) + + model = Model(step_freeze) + model.training_step_end = None + model.test_dataloader = None + early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + limit_train_batches=limit_train_batches, + limit_val_batches=2, + min_steps=min_steps, + min_epochs=min_epochs + ) + trainer.fit(model) + + # Make sure loss was properly decreased + assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6 + + pos_diff = (np.diff(model._values) == 0).nonzero()[0][0] + + # Compute when the latest validation epoch end happened + latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches + if pos_diff % limit_train_batches == 0: + latest_validation_epoch_end += limit_train_batches + + # Compute early stopping latest step + by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience + + # Compute min_epochs latest step + by_min_epochs = min_epochs * limit_train_batches + + # Make sure the trainer stops for the max of all minimum requirements + assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \ + (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) + + _logger.disabled = False + + +def test_early_stopping_mode_options(): + with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"): + EarlyStopping(mode="unknown_option") + + +class EarlyStoppingModel(BoringModel): + + def __init__(self, expected_end_epoch): + super().__init__() + self.expected_end_epoch = expected_end_epoch + + def validation_epoch_end(self, outputs): + losses = [8, 4, 2, 3, 4, 5, 8, 10] + val_loss = losses[self.current_epoch] + self.log('abc', torch.tensor(val_loss)) + self.log('cba', torch.tensor(0)) + + def on_train_end(self) -> None: + assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed' + + +@pytest.mark.parametrize( + "callbacks, expected_stop_epoch, accelerator, num_processes", + [ + ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1), + ([EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], 3, None, 1), + pytest.param([EarlyStopping(monitor='abc'), + EarlyStopping(monitor='cba', patience=3)], + 3, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + pytest.param([EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], + 3, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + ], +) +def test_multiple_early_stopping_callbacks( + tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int +): + """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" + + model = EarlyStoppingModel(expected_stop_epoch) + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=callbacks, + overfit_batches=0.20, + max_epochs=20, + accelerator=accelerator, + num_processes=num_processes + ) + trainer.fit(model) diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py new file mode 100644 index 00000000000000..6aa32934ad771e --- /dev/null +++ b/tests/callbacks/test_finetuning_callback.py @@ -0,0 +1,246 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from torch import nn +from torch.optim import SGD +from torch.utils.data import DataLoader + +from pytorch_lightning import LightningModule, seed_everything, Trainer +from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning +from pytorch_lightning.callbacks.base import Callback +from tests.helpers import BoringModel, RandomDataset + + +def test_finetuning_callback(tmpdir): + """Test finetuning callbacks works as expected""" + + seed_everything(42) + + class FinetuningBoringModel(BoringModel): + + def __init__(self): + super().__init__() + self.backbone = nn.Sequential(nn.Linear(32, 32, bias=False), nn.BatchNorm1d(32), nn.ReLU()) + self.layer = torch.nn.Linear(32, 2) + self.backbone.has_been_used = False + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def forward(self, x): + self.backbone.has_been_used = True + x = self.backbone(x) + return self.layer(x) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + class TestCallback(BackboneFinetuning): + + def on_train_epoch_end(self, trainer, pl_module, outputs): + epoch = trainer.current_epoch + if self.unfreeze_backbone_at_epoch <= epoch: + optimizer = trainer.optimizers[0] + current_lr = optimizer.param_groups[0]['lr'] + backbone_lr = self.previous_backbone_lr + if epoch < 6: + assert backbone_lr <= current_lr + else: + assert backbone_lr == current_lr + + model = FinetuningBoringModel() + callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False) + + trainer = Trainer( + limit_train_batches=1, + default_root_dir=tmpdir, + callbacks=[callback], + max_epochs=8, + ) + trainer.fit(model) + + assert model.backbone.has_been_used + + +def test_finetuning_callback_warning(tmpdir): + """Test finetuning callbacks works as expected""" + + seed_everything(42) + + class FinetuningBoringModel(BoringModel): + + def __init__(self): + super().__init__() + self.backbone = nn.Linear(32, 2, bias=False) + self.layer = None + self.backbone.has_been_used = False + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def forward(self, x): + self.backbone.has_been_used = True + x = self.backbone(x) + return x + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.parameters(), lr=0.1) + return optimizer + + class TestCallback(BackboneFinetuning): + + def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): + """Called when the epoch begins.""" + + if epoch == 0: + self.unfreeze_and_add_param_group( + pl_module.backbone, optimizer, 0.1, train_bn=self.train_bn, initial_denom_lr=self.initial_denom_lr + ) + + model = FinetuningBoringModel() + model.validation_step = None + callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False) + + with pytest.warns(UserWarning, match="Did you init your optimizer in"): + trainer = Trainer( + limit_train_batches=1, + default_root_dir=tmpdir, + callbacks=[callback], + max_epochs=2, + ) + trainer.fit(model) + + assert model.backbone.has_been_used + + +def test_freeze_unfreeze_function(tmpdir): + """Test freeze properly sets requires_grad on the modules""" + + seed_everything(42) + + class FreezeModel(LightningModule): + + def __init__(self): + super().__init__() + self.backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2)) + + model = FreezeModel() + BaseFinetuning.freeze(model, train_bn=True) + assert not model.backbone[0].weight.requires_grad + assert model.backbone[1].weight.requires_grad + assert not model.backbone[3].weight.requires_grad + + BaseFinetuning.freeze(model, train_bn=False) + assert not model.backbone[0].weight.requires_grad + assert not model.backbone[1].weight.requires_grad + assert not model.backbone[3].weight.requires_grad + + BaseFinetuning.make_trainable(model) + assert model.backbone[0].weight.requires_grad + assert model.backbone[1].weight.requires_grad + assert model.backbone[3].weight.requires_grad + + BaseFinetuning.freeze(model.backbone[0], train_bn=False) + assert not model.backbone[0].weight.requires_grad + + BaseFinetuning.freeze(([(model.backbone[1]), [model.backbone[3]]]), train_bn=True) + assert model.backbone[1].weight.requires_grad + assert not model.backbone[3].weight.requires_grad + + +def test_unfreeze_and_add_param_group_function(tmpdir): + """Test unfreeze_and_add_param_group properly unfreeze parameters and add to the correct param_group""" + + seed_everything(42) + + class FreezeModel(LightningModule): + + def __init__(self): + super().__init__() + self.backbone = nn.Sequential( + nn.Linear(32, 32, bias=False), + nn.Linear(32, 32, bias=False), + nn.Linear(32, 32, bias=False), + nn.Linear(32, 32, bias=False), + nn.Linear(32, 32, bias=False), + nn.BatchNorm1d(32), + ) + + model = FreezeModel() + optimizer = SGD(model.backbone[0].parameters(), lr=0.01) + + with pytest.warns(UserWarning, match="The provided params to be freezed already"): + BaseFinetuning.unfreeze_and_add_param_group(model.backbone[0], optimizer=optimizer) + assert optimizer.param_groups[0]["lr"] == 0.01 + + model.backbone[1].weight.requires_grad = False + BaseFinetuning.unfreeze_and_add_param_group(model.backbone[1], optimizer=optimizer) + assert len(optimizer.param_groups) == 2 + assert optimizer.param_groups[1]["lr"] == 0.001 + assert torch.equal(optimizer.param_groups[1]["params"][0], model.backbone[1].weight) + assert model.backbone[1].weight.requires_grad + + with pytest.warns(UserWarning, match="The provided params to be freezed already"): + BaseFinetuning.unfreeze_and_add_param_group(model, optimizer=optimizer, lr=100, train_bn=False) + assert len(optimizer.param_groups) == 3 + assert optimizer.param_groups[2]["lr"] == 100 + assert len(optimizer.param_groups[2]["params"]) == 3 + for group_idx, group in enumerate(optimizer.param_groups): + if group_idx == 0: + assert torch.equal(optimizer.param_groups[0]["params"][0], model.backbone[0].weight) + if group_idx == 2: + assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight) + assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight) + assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight) + + +def test_on_before_accelerator_backend_setup(tmpdir): + """ + `on_before_accelerator_backend_setup` hook is used by finetuning callbacks to freeze the model before + before configure_optimizers function call. + """ + + class TestCallback(Callback): + + def on_before_accelerator_backend_setup(self, trainer, pl_module): + pl_module.on_before_accelerator_backend_setup_called = True + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_before_accelerator_backend_setup_called = False + + def configure_optimizers(self): + assert self.on_before_accelerator_backend_setup_called + return super().configure_optimizers() + + model = TestModel() + callback = TestCallback() + + trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True) + trainer.fit(model) diff --git a/tests/callbacks/test_gpu_stats_monitor.py b/tests/callbacks/test_gpu_stats_monitor.py new file mode 100644 index 00000000000000..c2c4c87c284b04 --- /dev/null +++ b/tests/callbacks/test_gpu_stats_monitor.py @@ -0,0 +1,122 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import GPUStatsMonitor +from pytorch_lightning.loggers import CSVLogger +from pytorch_lightning.loggers.csv_logs import ExperimentWriter +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +@RunIf(min_gpus=1) +def test_gpu_stats_monitor(tmpdir): + """ + Test GPU stats are logged using a logger. + """ + model = BoringModel() + gpu_stats = GPUStatsMonitor(intra_step_time=True) + logger = CSVLogger(tmpdir) + log_every_n_steps = 2 + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=7, + log_every_n_steps=log_every_n_steps, + gpus=1, + callbacks=[gpu_stats], + logger=logger + ) + + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE) + met_data = np.genfromtxt(path_csv, delimiter=',', names=True, deletechars='', replace_space=' ') + + batch_time_data = met_data['batch_time/intra_step (ms)'] + batch_time_data = batch_time_data[~np.isnan(batch_time_data)] + assert batch_time_data.shape[0] == trainer.global_step // log_every_n_steps + + fields = [ + 'utilization.gpu', + 'memory.used', + 'memory.free', + 'utilization.memory', + ] + + for f in fields: + assert any([f in h for h in met_data.dtype.names]) + + +@pytest.mark.skipif(torch.cuda.is_available(), reason="test requires CPU machine") +def test_gpu_stats_monitor_cpu_machine(tmpdir): + """ + Test GPUStatsMonitor on CPU machine. + """ + with pytest.raises(MisconfigurationException, match='NVIDIA driver is not installed'): + GPUStatsMonitor() + + +@RunIf(min_gpus=1) +def test_gpu_stats_monitor_no_logger(tmpdir): + """ + Test GPUStatsMonitor with no logger in Trainer. + """ + model = BoringModel() + gpu_stats = GPUStatsMonitor() + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[gpu_stats], + max_epochs=1, + gpus=1, + logger=False, + ) + + with pytest.raises(MisconfigurationException, match='Trainer that has no logger.'): + trainer.fit(model) + + +@RunIf(min_gpus=1) +def test_gpu_stats_monitor_no_gpu_warning(tmpdir): + """ + Test GPUStatsMonitor raises a warning when not training on GPU device. + """ + model = BoringModel() + gpu_stats = GPUStatsMonitor() + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[gpu_stats], + max_steps=1, + gpus=None, + ) + + with pytest.raises(MisconfigurationException, match='not running on GPU'): + trainer.fit(model) + + +def test_gpu_stats_monitor_parse_gpu_stats(): + logs = GPUStatsMonitor._parse_gpu_stats('1,2', [[3, 4, 5], [6, 7]], [('gpu', 'a'), ('memory', 'b')]) + expected = {'gpu_id: 1/gpu (a)': 3, 'gpu_id: 1/memory (b)': 4, 'gpu_id: 2/gpu (a)': 6, 'gpu_id: 2/memory (b)': 7} + assert logs == expected diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py new file mode 100644 index 00000000000000..c2edfb176f1649 --- /dev/null +++ b/tests/callbacks/test_lambda_function.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import Callback, LambdaCallback +from tests.helpers.boring_model import BoringModel + + +def test_lambda_call(tmpdir): + seed_everything(42) + + class CustomModel(BoringModel): + + def on_train_epoch_start(self): + if self.current_epoch > 1: + raise KeyboardInterrupt + + checker = set() + hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] + hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} + hooks_args["on_save_checkpoint"] = (lambda x: lambda *args: [checker.add(x)])("on_save_checkpoint") + + model = CustomModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + callbacks=[LambdaCallback(**hooks_args)], + ) + results = trainer.fit(model) + assert results + + model = CustomModel() + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + resume_from_checkpoint=ckpt_path, + callbacks=[LambdaCallback(**hooks_args)], + ) + results = trainer.fit(model) + trainer.test(model) + + assert results + assert checker == set(hooks) diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py new file mode 100644 index 00000000000000..3018055e0b7a06 --- /dev/null +++ b/tests/callbacks/test_lr_monitor.py @@ -0,0 +1,281 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from torch import optim + +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.simple_models import ClassificationModel + + +def test_lr_monitor_single_lr(tmpdir): + """ Test that learning rates are extracted and logged for single lr scheduler. """ + tutils.reset_seed() + + model = BoringModel() + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_val_batches=0.1, + limit_train_batches=0.5, + callbacks=[lr_monitor], + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert lr_monitor.lrs, 'No learning rates logged' + assert all(v is None for v in lr_monitor.last_momentum_values.values()), \ + 'Momentum should not be logged by default' + assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of lr schedulers' + assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['lr-SGD'], \ + 'Names of learning rates not set correctly' + + +@pytest.mark.parametrize('opt', ['SGD', 'Adam']) +def test_lr_monitor_single_lr_with_momentum(tmpdir, opt: str): + """Test that learning rates and momentum are extracted and logged for single lr scheduler.""" + + class LogMomentumModel(BoringModel): + + def __init__(self, opt): + super().__init__() + self.opt = opt + + def configure_optimizers(self): + if self.opt == 'SGD': + opt_kwargs = {'momentum': 0.9} + elif self.opt == 'Adam': + opt_kwargs = {'betas': (0.9, 0.999)} + + optimizer = getattr(optim, self.opt)(self.parameters(), lr=1e-2, **opt_kwargs) + lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, total_steps=10_000) + return [optimizer], [lr_scheduler] + + model = LogMomentumModel(opt=opt) + lr_monitor = LearningRateMonitor(log_momentum=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_val_batches=2, + limit_train_batches=5, + log_every_n_steps=1, + callbacks=[lr_monitor], + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert all(v is not None for v in lr_monitor.last_momentum_values.values()), \ + 'Expected momentum to be logged' + assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers), \ + 'Number of momentum values logged does not match number of lr schedulers' + assert all(k == f'lr-{opt}-momentum' for k in lr_monitor.last_momentum_values.keys()), \ + 'Names of momentum values not set correctly' + + +def test_log_momentum_no_momentum_optimizer(tmpdir): + """ + Test that if optimizer doesn't have momentum then a warning is raised with log_momentum=True. + """ + + class LogMomentumModel(BoringModel): + + def configure_optimizers(self): + optimizer = optim.ASGD(self.parameters(), lr=1e-2) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + model = LogMomentumModel() + lr_monitor = LearningRateMonitor(log_momentum=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=2, + limit_train_batches=5, + log_every_n_steps=1, + callbacks=[lr_monitor], + ) + with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."): + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), \ + 'Expected momentum to be logged' + assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers), \ + 'Number of momentum values logged does not match number of lr schedulers' + assert all(k == 'lr-ASGD-momentum' for k in lr_monitor.last_momentum_values.keys()), \ + 'Names of momentum values not set correctly' + + +def test_lr_monitor_no_lr_scheduler(tmpdir): + tutils.reset_seed() + + class CustomBoringModel(BoringModel): + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=0.1) + return optimizer + + model = CustomBoringModel() + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_val_batches=0.1, + limit_train_batches=0.5, + callbacks=[lr_monitor], + ) + + with pytest.warns(RuntimeWarning, match='have no learning rate schedulers'): + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +def test_lr_monitor_no_logger(tmpdir): + tutils.reset_seed() + + model = BoringModel() + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + callbacks=[lr_monitor], + logger=False, + ) + + with pytest.raises(MisconfigurationException, match='`Trainer` that has no logger'): + trainer.fit(model) + + +@pytest.mark.parametrize("logging_interval", ['step', 'epoch']) +def test_lr_monitor_multi_lrs(tmpdir, logging_interval: str): + """ Test that learning rates are extracted and logged for multi lr schedulers. """ + tutils.reset_seed() + + class CustomBoringModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=1e-2) + optimizer2 = optim.Adam(self.parameters(), lr=1e-2) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + model = CustomBoringModel() + model.training_epoch_end = None + + lr_monitor = LearningRateMonitor(logging_interval=logging_interval) + log_every_n_steps = 2 + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + log_every_n_steps=log_every_n_steps, + limit_train_batches=7, + limit_val_batches=0.1, + callbacks=[lr_monitor], + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert lr_monitor.lrs, 'No learning rates logged' + assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of lr schedulers' + assert lr_monitor.lr_sch_names == ['lr-Adam', 'lr-Adam-1'], \ + 'Names of learning rates not set correctly' + + if logging_interval == 'step': + expected_number_logged = trainer.global_step // log_every_n_steps + if logging_interval == 'epoch': + expected_number_logged = trainer.max_epochs + + assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()), \ + 'Length of logged learning rates do not match the expected number' + + +def test_lr_monitor_param_groups(tmpdir): + """ Test that learning rates are extracted and logged for single lr scheduler. """ + tutils.reset_seed() + + class CustomClassificationModel(ClassificationModel): + + def configure_optimizers(self): + param_groups = [{ + 'params': list(self.parameters())[:2], + 'lr': self.lr * 0.1 + }, { + 'params': list(self.parameters())[2:], + 'lr': self.lr + }] + + optimizer = optim.Adam(param_groups) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) + return [optimizer], [lr_scheduler] + + model = CustomClassificationModel() + dm = ClassifDataModule() + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_val_batches=0.1, + limit_train_batches=0.5, + callbacks=[lr_monitor], + ) + trainer.fit(model, datamodule=dm) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert lr_monitor.lrs, 'No learning rates logged' + assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of param groups' + assert lr_monitor.lr_sch_names == ['lr-Adam'] + assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2'], \ + 'Names of learning rates not set correctly' + + +def test_lr_monitor_custom_name(tmpdir): + + class TestModel(BoringModel): + + def configure_optimizers(self): + optimizer, [scheduler] = super().configure_optimizers() + lr_scheduler = {'scheduler': scheduler, 'name': 'my_logging_name'} + return optimizer, [lr_scheduler] + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_val_batches=0.1, + limit_train_batches=0.5, + callbacks=[lr_monitor], + progress_bar_refresh_rate=0, + weights_summary=None, + ) + trainer.fit(TestModel()) + assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name'] diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index ebd35fedfa13d6..76f1e4cb0570f7 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -1,27 +1,51 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from typing import Optional, Union +from unittest import mock +from unittest.mock import ANY, call, Mock + import pytest +import torch -import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.progress import tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate - - -@pytest.mark.parametrize('callbacks,refresh_rate', [ - ([], 1), - ([], 2), - ([ProgressBar(refresh_rate=1)], 0), - ([ProgressBar(refresh_rate=2)], 0), - ([ProgressBar(refresh_rate=2)], 1), -]) -def test_progress_bar_on(callbacks, refresh_rate): +from tests.helpers import BoringModel + + +@pytest.mark.parametrize( + 'callbacks,refresh_rate', [ + ([], None), + ([], 1), + ([], 2), + ([ProgressBar(refresh_rate=1)], 0), + ([ProgressBar(refresh_rate=2)], 0), + ([ProgressBar(refresh_rate=2)], 1), + ] +) +def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): """Test different ways the progress bar can be turned on.""" trainer = Trainer( + default_root_dir=tmpdir, callbacks=callbacks, progress_bar_refresh_rate=refresh_rate, max_epochs=1, - overfit_pct=0.2, + overfit_batches=5, ) progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] @@ -30,15 +54,18 @@ def test_progress_bar_on(callbacks, refresh_rate): assert progress_bars[0] is trainer.progress_bar_callback -@pytest.mark.parametrize('callbacks,refresh_rate', [ - ([], 0), - ([], False), - ([ModelCheckpoint('../trainer')], 0), -]) -def test_progress_bar_off(callbacks, refresh_rate): +@pytest.mark.parametrize( + 'callbacks,refresh_rate', [ + ([], 0), + ([], False), + ([ModelCheckpoint(dirpath='../trainer')], 0), + ] +) +def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int]): """Test different ways the progress bar can be turned off.""" trainer = Trainer( + default_root_dir=tmpdir, callbacks=callbacks, progress_bar_refresh_rate=refresh_rate, ) @@ -50,19 +77,19 @@ def test_progress_bar_off(callbacks, refresh_rate): def test_progress_bar_misconfiguration(): """Test that Trainer doesn't accept multiple progress bars.""" - callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint('../trainer')] + callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint(dirpath='../trainer')] with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): Trainer(callbacks=callbacks) -def test_progress_bar_totals(): +def test_progress_bar_totals(tmpdir): """Test that the progress finishes with the correct total steps processed.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = BoringModel() trainer = Trainer( + default_root_dir=tmpdir, progress_bar_refresh_rate=1, - val_percent_check=1.0, max_epochs=1, ) bar = trainer.progress_bar_callback @@ -94,6 +121,12 @@ def test_progress_bar_totals(): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + + assert bar.val_progress_bar.total == m + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + trainer.test(model) # check test progress bar total @@ -106,19 +139,20 @@ def test_progress_bar_totals(): assert bar.test_batch_idx == k -def test_progress_bar_fast_dev_run(): - model = EvalModelTemplate(tutils.get_default_hparams()) +def test_progress_bar_fast_dev_run(tmpdir): + model = BoringModel() trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, ) + trainer.fit(model) + progress_bar = trainer.progress_bar_callback assert 1 == progress_bar.total_train_batches # total val batches are known only after val dataloaders have reloaded - trainer.fit(model) - assert 1 == progress_bar.total_val_batches assert 1 == progress_bar.train_batch_idx assert 1 == progress_bar.val_batch_idx @@ -128,6 +162,13 @@ def test_progress_bar_fast_dev_run(): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -137,10 +178,10 @@ def test_progress_bar_fast_dev_run(): @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) -def test_progress_bar_progress_refresh(refresh_rate): +def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int): """Test that the three progress bars get correctly updated when using different refresh rates.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = BoringModel() class CurrentProgressBar(ProgressBar): @@ -148,42 +189,296 @@ class CurrentProgressBar(ProgressBar): val_batches_seen = 0 test_batches_seen = 0 - def on_batch_start(self, trainer, pl_module): - super().on_batch_start(trainer, pl_module) + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) assert self.train_batch_idx == trainer.batch_idx - def on_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) assert self.train_batch_idx == trainer.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx self.train_batches_seen += 1 - def on_validation_batch_end(self, trainer, pl_module): - super().on_validation_batch_end(trainer, pl_module) + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if not self.is_disabled and self.val_batch_idx % self.refresh_rate == 0: assert self.val_progress_bar.n == self.val_batch_idx self.val_batches_seen += 1 - def on_test_batch_end(self, trainer, pl_module): - super().on_test_batch_end(trainer, pl_module) + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if not self.is_disabled and self.test_batch_idx % self.refresh_rate == 0: assert self.test_progress_bar.n == self.test_batch_idx self.test_batches_seen += 1 progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) trainer = Trainer( + default_root_dir=tmpdir, callbacks=[progress_bar], progress_bar_refresh_rate=101, # should not matter if custom callback provided - train_percent_check=1.0, + limit_train_batches=1.0, num_sanity_val_steps=2, max_epochs=3, ) - assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate + assert trainer.progress_bar_callback.refresh_rate == refresh_rate trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == progress_bar.total_test_batches + + +@pytest.mark.parametrize('limit_val_batches', (0, 5)) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): + """ + Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. + """ + + class CurrentProgressBar(ProgressBar): + val_pbar_total = 0 + sanity_pbar_total = 0 + + def on_sanity_check_end(self, *args): + self.sanity_pbar_total = self.val_progress_bar.total + super().on_sanity_check_end(*args) + + def on_validation_epoch_end(self, *args): + self.val_pbar_total = self.val_progress_bar.total + super().on_validation_epoch_end(*args) + + model = BoringModel() + progress_bar = CurrentProgressBar() + num_sanity_val_steps = 2 + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + num_sanity_val_steps=num_sanity_val_steps, + limit_train_batches=1, + limit_val_batches=limit_val_batches, + callbacks=[progress_bar], + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + + assert progress_bar.sanity_pbar_total == min(num_sanity_val_steps, limit_val_batches) + assert progress_bar.val_pbar_total == limit_val_batches + + +def test_progress_bar_default_value(tmpdir): + """ Test that a value of None defaults to refresh rate 1. """ + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.progress_bar_callback.refresh_rate == 1 + + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) + assert trainer.progress_bar_callback.refresh_rate == 1 + + +@mock.patch.dict(os.environ, {'COLAB_GPU': '1'}) +def test_progress_bar_value_on_colab(tmpdir): + """ Test that Trainer will override the default in Google COLAB. """ + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.progress_bar_callback.refresh_rate == 20 + + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) + assert trainer.progress_bar_callback.refresh_rate == 20 + + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19) + assert trainer.progress_bar_callback.refresh_rate == 19 + + +class MockedUpdateProgressBars(ProgressBar): + """ Mocks the update method once bars get initializied. """ + + def _mock_bar_update(self, bar): + bar.update = Mock(wraps=bar.update) + return bar + + def init_train_tqdm(self): + bar = super().init_train_tqdm() + return self._mock_bar_update(bar) + + def init_validation_tqdm(self): + bar = super().init_validation_tqdm() + return self._mock_bar_update(bar) + + def init_test_tqdm(self): + bar = super().init_test_tqdm() + return self._mock_bar_update(bar) + + +@pytest.mark.parametrize( + "train_batches,val_batches,refresh_rate,train_deltas,val_deltas", [ + [2, 3, 1, [1, 1, 1, 1, 1], [1, 1, 1]], + [0, 0, 3, [], []], + [1, 0, 3, [1], []], + [1, 1, 3, [2], [1]], + [5, 0, 3, [3, 2], []], + [5, 2, 3, [3, 3, 1], [2]], + [5, 2, 6, [6, 1], [2]], + ] +) +def test_main_progress_bar_update_amount( + tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_deltas: list, val_deltas: list +): + """ + Test that the main progress updates with the correct amount together with the val progress. At the end of + the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate. + """ + model = BoringModel() + progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=train_batches, + limit_val_batches=val_batches, + callbacks=[progress_bar], + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas]) + if val_batches > 0: + progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas]) + + +@pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [ + [1, 3, [1]], + [3, 1, [1, 1, 1]], + [5, 3, [3, 2]], +]) +def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, test_deltas: list): + """ + Test that test progress updates with the correct amount. + """ + model = BoringModel() + progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=test_batches, + callbacks=[progress_bar], + logger=False, + checkpoint_callback=False, + ) + trainer.test(model) + progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas]) + + +def test_tensor_to_float_conversion(tmpdir): + """Check tensor gets converted to float""" + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log('foo', torch.tensor(0.123), prog_bar=True) + self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True) + return super().training_step(batch, batch_idx) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + logger=False, + checkpoint_callback=False, + ) + trainer.fit(TestModel()) + + pbar = trainer.progress_bar_callback.main_progress_bar + actual = str(pbar.postfix) + assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + + +@pytest.mark.parametrize( + "input_num, expected", [ + [1, '1'], + [1.0, '1.000'], + [0.1, '0.100'], + [1e-3, '0.001'], + [1e-5, '1e-5'], + ['1.0', '1.000'], + ['10000', '10000'], + ['abc', 'abc'], + ] +) +def test_tqdm_format_num(input_num: Union[str, int, float], expected: str): + """ Check that the specialized tqdm.format_num appends 0 to floats and strings """ + assert tqdm.format_num(input_num) == expected + + +class PrintModel(BoringModel): + + def training_step(self, *args, **kwargs): + self.print("training_step", end="") + return super().training_step(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + self.print("validation_step", file=sys.stderr) + return super().validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs): + self.print("test_step") + return super().test_step(*args, **kwargs) + + +@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +def test_progress_bar_print(tqdm_write, tmpdir): + """ Test that printing in the LightningModule redirects arguments to the progress bar. """ + model = PrintModel() + bar = ProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1, + callbacks=[bar], + ) + trainer.fit(model) + trainer.test(model) + assert tqdm_write.call_count == 3 + assert tqdm_write.call_args_list == [ + call("training_step", end="", file=None, nolock=False), + call("validation_step", end=os.linesep, file=sys.stderr, nolock=False), + call("test_step", end=os.linesep, file=None, nolock=False), + ] + + +@mock.patch('builtins.print') +@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): + """ Test that printing in LightningModule goes through built-in print function when progress bar is disabled. """ + model = PrintModel() + bar = ProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1, + callbacks=[bar], + ) + bar.disable() + trainer.fit(model) + trainer.test(model) + + mock_print.assert_has_calls([ + call("training_step", end=""), + call("validation_step", file=ANY), + call("test_step"), + ]) + tqdm_write.assert_not_called() diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py new file mode 100644 index 00000000000000..e42689a25d8aa3 --- /dev/null +++ b/tests/callbacks/test_pruning.py @@ -0,0 +1,311 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import OrderedDict +from logging import INFO +from typing import Union + +import pytest +import torch +import torch.nn.utils.prune as pytorch_prune +from torch import nn +from torch.nn import Sequential + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class TestModel(BoringModel): + test_step = None + + def __init__(self): + super().__init__() + self.layer = Sequential( + OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 32)), + ("mlp_3", nn.Linear(32, 2)), + ]) + ) + + def training_step(self, batch, batch_idx): + self.log("test", -batch_idx) + return super().training_step(batch, batch_idx) + + +class TestPruningMethod(pytorch_prune.BasePruningMethod): + PRUNING_TYPE = "unstructured" + + def compute_mask(self, _, default_mask): + mask = default_mask.clone() + # Prune every other entry in a tensor + mask.view(-1)[::2] = 0 + return mask + + @classmethod + def apply(cls, module, name, amount): + return super(TestPruningMethod, cls).apply(module, name, amount=amount) + + +def train_with_pruning_callback( + tmpdir, + parameters_to_prune=False, + use_global_unstructured=False, + pruning_fn="l1_unstructured", + use_lottery_ticket_hypothesis=False, + accelerator=None, + gpus=None, + num_processes=1, +): + model = TestModel() + + # Weights are random. None is 0 + assert torch.all(model.layer.mlp_2.weight != 0) + + pruning_kwargs = { + "pruning_fn": pruning_fn, + "amount": 0.3, + "use_global_unstructured": use_global_unstructured, + "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis, + "verbose": 1, + } + if parameters_to_prune: + pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")] + else: + pruning_kwargs["parameter_names"] = ["weight"] + if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"): + pruning_kwargs["pruning_dim"] = 0 + if pruning_fn == "ln_structured": + pruning_kwargs["pruning_norm"] = 1 + + # Misconfiguration checks + if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured") and use_global_unstructured: + with pytest.raises(MisconfigurationException, match="is supported with `use_global_unstructured=True`"): + ModelPruning(**pruning_kwargs) + return + if ModelPruning._is_pruning_method(pruning_fn) and not use_global_unstructured: + with pytest.raises(MisconfigurationException, match="currently only supported with"): + ModelPruning(**pruning_kwargs) + return + + pruning = ModelPruning(**pruning_kwargs) + + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + weights_summary=None, + checkpoint_callback=False, + logger=False, + limit_train_batches=10, + limit_val_batches=2, + max_epochs=10, + accelerator=accelerator, + gpus=gpus, + num_processes=num_processes, + callbacks=pruning, + ) + trainer.fit(model) + trainer.test(model) + + if not accelerator: + # Check some have been pruned + assert torch.any(model.layer.mlp_2.weight == 0) + + +def test_pruning_misconfiguration(): + with pytest.raises(MisconfigurationException, match=r"chocolate isn't in \('weight', 'bias'\)"): + ModelPruning(pruning_fn="l1_unstructured", parameter_names=["chocolate"]) + with pytest.raises(MisconfigurationException, match=r"expected to be a str in \["): + ModelPruning(pruning_fn={}) # noqa + with pytest.raises(MisconfigurationException, match="should be provided"): + ModelPruning(pruning_fn="random_structured") + with pytest.raises(MisconfigurationException, match=r"must be any of \(0, 1, 2\)"): + ModelPruning(pruning_fn="l1_unstructured", verbose=3) + with pytest.raises(MisconfigurationException, match="requesting `ln_structured` pruning, the `pruning_norm`"): + ModelPruning(pruning_fn="ln_structured", pruning_dim=0) + + +@pytest.mark.parametrize("parameters_to_prune", [False, True]) +@pytest.mark.parametrize("use_global_unstructured", [False, True]) +@pytest.mark.parametrize( + "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod] +) +@pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True]) +def test_pruning_callback( + tmpdir, use_global_unstructured: bool, parameters_to_prune: bool, + pruning_fn: Union[str, pytorch_prune.BasePruningMethod], use_lottery_ticket_hypothesis: bool +): + train_with_pruning_callback( + tmpdir, + parameters_to_prune=parameters_to_prune, + use_global_unstructured=use_global_unstructured, + pruning_fn=pruning_fn, + use_lottery_ticket_hypothesis=use_lottery_ticket_hypothesis, + ) + + +@RunIf(special=True) +@pytest.mark.parametrize("parameters_to_prune", [False, True]) +@pytest.mark.parametrize("use_global_unstructured", [False, True]) +def test_pruning_callback_ddp(tmpdir, use_global_unstructured: bool, parameters_to_prune: bool): + train_with_pruning_callback( + tmpdir, + parameters_to_prune=parameters_to_prune, + use_global_unstructured=use_global_unstructured, + accelerator="ddp", + gpus=2, + ) + + +@RunIf(min_gpus=2, skip_windows=True) +def test_pruning_callback_ddp_spawn(tmpdir): + train_with_pruning_callback(tmpdir, use_global_unstructured=True, accelerator="ddp_spawn", gpus=2) + + +@RunIf(skip_windows=True) +def test_pruning_callback_ddp_cpu(tmpdir): + train_with_pruning_callback(tmpdir, parameters_to_prune=True, accelerator="ddp_cpu", num_processes=2) + + +@pytest.mark.parametrize("resample_parameters", (False, True)) +def test_pruning_lth_callable(tmpdir, resample_parameters: bool): + model = TestModel() + + class ModelPruningTestCallback(ModelPruning): + lth_calls = 0 + + def apply_lottery_ticket_hypothesis(self): + super().apply_lottery_ticket_hypothesis() + self.lth_calls += 1 + + for d in self._original_layers.values(): + copy, names = d["data"], d["names"] + for i, name in names: + curr, curr_name = self._parameters_to_prune[i] + assert name == curr_name + actual, expected = getattr(curr, name).data, getattr(copy, name).data + allclose = torch.allclose(actual, expected) + assert not allclose if self._resample_parameters else allclose + + pruning = ModelPruningTestCallback( + "l1_unstructured", use_lottery_ticket_hypothesis=lambda e: bool(e % 2), resample_parameters=resample_parameters + ) + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + weights_summary=None, + checkpoint_callback=False, + logger=False, + limit_train_batches=10, + limit_val_batches=2, + max_epochs=5, + callbacks=pruning, + ) + trainer.fit(model) + + assert pruning.lth_calls == trainer.max_epochs // 2 + + +@pytest.mark.parametrize("make_pruning_permanent", (False, True)) +def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool): + seed_everything(0) + model = TestModel() + pruning_kwargs = { + 'parameters_to_prune': [(model.layer.mlp_1, "weight"), (model.layer.mlp_3, "weight")], + 'verbose': 2, + "make_pruning_permanent": make_pruning_permanent + } + p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **pruning_kwargs) + p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **pruning_kwargs) + + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + weights_summary=None, + checkpoint_callback=False, + logger=False, + limit_train_batches=10, + limit_val_batches=2, + max_epochs=3, + callbacks=[p1, p2], + ) + with caplog.at_level(INFO): + trainer.fit(model) + + actual = [m.strip() for m in caplog.messages] + actual = [m for m in actual if m.startswith("Applied")] + assert actual == [ + "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)", + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501 + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501 + "Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)", + "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 506 (49.41%) -> 633 (61.82%)", # noqa: E501 + "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 38 (59.38%) -> 47 (73.44%)", # noqa: E501 + "Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)", + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501 + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501 + ] + + filepath = str(tmpdir / "foo.ckpt") + trainer.save_checkpoint(filepath) + + model.load_from_checkpoint(filepath, strict=False) + has_pruning = hasattr(model.layer.mlp_1, "weight_orig") + assert not has_pruning if make_pruning_permanent else has_pruning + + +def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog): + """ + When a model is saved multiple times and make_permanent=True, we need to + make sure a copy is pruned and not the trained model if we want to continue + with the same pruning buffers. + """ + seed_everything(0) + + class TestPruning(ModelPruning): + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + super().on_save_checkpoint(trainer, pl_module, checkpoint) + assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"] + assert hasattr(pl_module.layer.mlp_3, "weight_orig") + + model = TestModel() + pruning_callback = TestPruning( + "random_unstructured", + parameters_to_prune=[(model.layer.mlp_3, "weight")], + verbose=1, + make_pruning_permanent=True + ) + ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True) + trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0) + with caplog.at_level(INFO): + trainer.fit(model) + + actual = [m.strip() for m in caplog.messages] + actual = [m for m in actual if m.startswith("Applied")] + assert actual == [ + "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)", + "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)", + "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)", + ] + + # removed on_train_end + assert not hasattr(model.layer.mlp_3, "weight_orig") + + model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path) + assert not hasattr(model.layer.mlp_3, "weight_orig") + model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path) + assert not hasattr(model.layer.mlp_3, "weight_orig") diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py new file mode 100644 index 00000000000000..3d9c44d1879967 --- /dev/null +++ b/tests/callbacks/test_quantization.py @@ -0,0 +1,140 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from typing import Callable, Union + +import pytest +import torch + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import QuantizationAwareTraining +from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.datamodules import RegressDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import RegressionModel + + +@pytest.mark.parametrize("observe", ['average', pytest.param('histogram', marks=RunIf(min_torch="1.5"))]) +@pytest.mark.parametrize("fuse", [True, False]) +@RunIf(quantization=True) +def test_quantization(tmpdir, observe: str, fuse: bool): + """Parity test for quant model""" + seed_everything(42) + dm = RegressDataModule() + trainer_args = dict( + default_root_dir=tmpdir, + max_epochs=10, + gpus=1 if torch.cuda.is_available() else None, + ) + model = RegressionModel() + qmodel = copy.deepcopy(model) + + trainer = Trainer(**trainer_args) + trainer.fit(model, datamodule=dm) + org_size = model.model_size + org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()])) + + fusing_layers = [(f'layer_{i}', f'layer_{i}a') for i in range(3)] if fuse else None + qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers) + trainer = Trainer(callbacks=[qcb], **trainer_args) + trainer.fit(qmodel, datamodule=dm) + + quant_calls = qcb._forward_calls + assert quant_calls == qcb._forward_calls + + quant_size = qmodel.model_size + quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()])) + # test that the trained model is smaller then initial + size_ratio = quant_size / org_size + assert size_ratio < 0.65 + # test that the test score is almost the same as with pure training + assert torch.allclose(org_score, quant_score, atol=0.45) + + +@RunIf(quantization=True) +def test_quantize_torchscript(tmpdir): + """Test converting to torchscipt """ + dm = RegressDataModule() + qmodel = RegressionModel() + qcb = QuantizationAwareTraining(input_compatible=False) + trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1) + trainer.fit(qmodel, datamodule=dm) + + batch = iter(dm.test_dataloader()).next() + qmodel(qmodel.quant(batch[0])) + + tsmodel = qmodel.to_torchscript() + tsmodel(tsmodel.quant(batch[0])) + + +@RunIf(quantization=True) +def test_quantization_exceptions(tmpdir): + """Test wrong fuse layers""" + with pytest.raises(MisconfigurationException, match='Unsupported qconfig'): + QuantizationAwareTraining(qconfig=['abc']) + + with pytest.raises(MisconfigurationException, match='Unsupported observer type'): + QuantizationAwareTraining(observer_type='abc') + + with pytest.raises(MisconfigurationException, match='Unsupported `collect_quantization`'): + QuantizationAwareTraining(collect_quantization='abc') + + with pytest.raises(MisconfigurationException, match='Unsupported `collect_quantization`'): + QuantizationAwareTraining(collect_quantization=1.2) + + fusing_layers = [(f'layers.mlp_{i}', f'layers.NONE-mlp_{i}a') for i in range(3)] + qcb = QuantizationAwareTraining(modules_to_fuse=fusing_layers) + trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1) + with pytest.raises(MisconfigurationException, match='one or more of them is not your model attributes'): + trainer.fit(RegressionModel(), datamodule=RegressDataModule()) + + +def custom_trigger_never(trainer): + return False + + +def custom_trigger_even(trainer): + return trainer.current_epoch % 2 == 0 + + +def custom_trigger_last(trainer): + return trainer.current_epoch == (trainer.max_epochs - 1) + + +@pytest.mark.parametrize( + "trigger_fn,expected_count", [ + (None, 9), + (3, 3), + (custom_trigger_never, 0), + (custom_trigger_even, 5), + (custom_trigger_last, 2), + ] +) +@RunIf(quantization=True) +def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], expected_count: int): + """Test how many times the quant is called""" + dm = RegressDataModule() + qmodel = RegressionModel() + qcb = QuantizationAwareTraining(collect_quantization=trigger_fn) + trainer = Trainer( + callbacks=[qcb], + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=4, + ) + trainer.fit(qmodel, datamodule=dm) + + assert qcb._forward_calls == expected_count diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py new file mode 100644 index 00000000000000..12121b1f38530d --- /dev/null +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -0,0 +1,181 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock + +import pytest +import torch +from torch import nn +from torch.utils.data import DataLoader + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf + +if _TORCH_GREATER_EQUAL_1_6: + from pytorch_lightning.callbacks import StochasticWeightAveraging + + class SwaTestModel(BoringModel): + + def __init__(self, batchnorm: bool = True): + super().__init__() + layers = [nn.Linear(32, 32)] + if batchnorm: + layers.append(nn.BatchNorm1d(32)) + layers += [nn.ReLU(), nn.Linear(32, 2)] + self.layer = nn.Sequential(*layers) + + def training_step(self, batch, batch_idx): + output = self.forward(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + class SwaTestCallback(StochasticWeightAveraging): + update_parameters_calls: int = 0 + transfer_weights_calls: int = 0 + + def update_parameters(self, *args, **kwargs): + self.update_parameters_calls += 1 + return StochasticWeightAveraging.update_parameters(*args, **kwargs) + + def transfer_weights(self, *args, **kwargs): + self.transfer_weights_calls += 1 + return StochasticWeightAveraging.transfer_weights(*args, **kwargs) + + def on_train_epoch_start(self, trainer, *args): + super().on_train_epoch_start(trainer, *args) + assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end) + + def on_train_epoch_end(self, trainer, *args): + super().on_train_epoch_end(trainer, *args) + if self.swa_start <= trainer.current_epoch <= self.swa_end: + swa_epoch = trainer.current_epoch - self.swa_start + assert self.n_averaged == swa_epoch + 1 + elif trainer.current_epoch > self.swa_end: + assert self.n_averaged == self._max_epochs - self.swa_start + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + + # make sure these are correctly set again + assert not trainer.train_loop._skip_backward + assert trainer.accumulate_grad_batches == 2 + assert trainer.num_training_batches == 5 + + # check backward call count. the batchnorm update epoch should not backward + assert trainer.dev_debugger.count_events( + "backward_call" + ) == trainer.max_epochs * trainer.limit_train_batches + + # check call counts + assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1) + assert self.transfer_weights_calls == 1 + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1): + model = SwaTestModel(batchnorm=batchnorm) + swa_start = 2 + max_epochs = 5 + swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) + assert swa_callback.update_parameters_calls == 0 + assert swa_callback.transfer_weights_calls == 0 + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=max_epochs, + limit_train_batches=5, + limit_val_batches=0, + callbacks=[swa_callback], + accumulate_grad_batches=2, + accelerator=accelerator, + gpus=gpus, + num_processes=num_processes + ) + trainer.fit(model) + + # check the model is the expected + assert trainer.lightning_module == model + + +@RunIf(min_gpus=2, min_torch="1.6.0", special=True) +def test_swa_callback_ddp(tmpdir): + train_with_swa(tmpdir, accelerator="ddp", gpus=2) + + +@RunIf(min_gpus=2, min_torch="1.6.0") +def test_swa_callback_ddp_spawn(tmpdir): + train_with_swa(tmpdir, accelerator="ddp_spawn", gpus=2) + + +@RunIf(min_torch="1.6.0", skip_windows=True) +def test_swa_callback_ddp_cpu(tmpdir): + train_with_swa(tmpdir, accelerator="ddp_cpu", num_processes=2) + + +@RunIf(min_gpus=1, min_torch="1.6.0") +def test_swa_callback_1_gpu(tmpdir): + train_with_swa(tmpdir, gpus=1) + + +@RunIf(min_torch="1.6.0") +@pytest.mark.parametrize("batchnorm", (True, False)) +def test_swa_callback(tmpdir, batchnorm: bool): + train_with_swa(tmpdir, batchnorm=batchnorm) + + +@RunIf(min_torch="1.6.0") +def test_swa_raises(): + with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"): + StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1) + with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"): + StochasticWeightAveraging(swa_epoch_start=1.5, swa_lrs=0.1) + with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"): + StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1) + with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"): + StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1]) + + +@pytest.mark.parametrize('stochastic_weight_avg', [False, True]) +@pytest.mark.parametrize('use_callbacks', [False, True]) +@RunIf(min_torch="1.6.0") +def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool): + """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer""" + + class TestModel(BoringModel): + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None, + stochastic_weight_avg=stochastic_weight_avg, + limit_train_batches=4, + limit_val_batches=4, + max_epochs=2, + ) + trainer.fit(model) + if use_callbacks or stochastic_weight_avg: + assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1 + assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1) + else: + assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks) diff --git a/tests/checkpointing/__init__.py b/tests/checkpointing/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py new file mode 100644 index 00000000000000..7926bc46dd290e --- /dev/null +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -0,0 +1,144 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import callbacks, seed_everything, Trainer +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_mc_called(tmpdir): + seed_everything(1234) + + # ----------------- + # TRAIN LOOP ONLY + # ----------------- + train_step_only_model = BoringModel() + train_step_only_model.validation_step = None + + # no callback + trainer = Trainer(max_epochs=3, checkpoint_callback=False) + trainer.fit(train_step_only_model) + assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + + # ----------------- + # TRAIN + VAL LOOP ONLY + # ----------------- + val_train_model = BoringModel() + # no callback + trainer = Trainer(max_epochs=3, checkpoint_callback=False) + trainer.fit(val_train_model) + assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + + +@mock.patch('torch.save') +@pytest.mark.parametrize( + ['epochs', 'val_check_interval', 'expected'], + [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)], +) +def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int): + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=epochs, + weights_summary=None, + val_check_interval=val_check_interval, + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + + # make sure types are correct + assert save_mock.call_count == expected + + +@mock.patch('torch.save') +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [ + (1, 1, 1.0, 1), + (2, 2, 1.0, 2), + (2, 1, 0.25, 4), + (2, 2, 0.3, 7), +]) +def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.last_coeff = 10.0 + + def training_step(self, batch, batch_idx): + loss = self.step(torch.ones(32)) + loss = loss / (loss + 0.0000001) + loss += self.last_coeff + self.log('my_loss', loss) + self.last_coeff *= 0.999 + return loss + + model = TestModel() + trainer = Trainer( + callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k)], + default_root_dir=tmpdir, + max_epochs=epochs, + weights_summary=None, + val_check_interval=val_check_interval + ) + trainer.fit(model) + + # make sure types are correct + assert save_mock.call_count == expected + + +@mock.patch('torch.save') +@RunIf(special=True, min_gpus=2) +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) +def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + local_rank = int(os.getenv("LOCAL_RANK")) + self.log('my_loss', batch_idx * (1 + local_rank), on_epoch=True) + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + data = str(self.global_rank) + obj = [[data], (data, ), set(data)] + out = self.trainer.training_type_plugin.broadcast(obj) + assert obj == [[str(self.global_rank)], (str(self.global_rank), ), set(str(self.global_rank))] + assert out == [['0'], ('0', ), set('0')] + + model = TestModel() + trainer = Trainer( + callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")], + default_root_dir=tmpdir, + max_epochs=epochs, + weights_summary=None, + val_check_interval=val_check_interval, + accelerator="ddp", + gpus=2, + limit_train_batches=64, + limit_val_batches=32, + ) + if os.getenv("LOCAL_RANK") == "0": + with pytest.raises(UserWarning, match="The value associated to the key my_loss_epoch: [15.5, 31.0]"): + trainer.fit(model) + assert save_mock.call_count == expected + else: + trainer.fit(model) diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py new file mode 100644 index 00000000000000..fc711d1909eac1 --- /dev/null +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import os +import sys + +import pytest + +from pytorch_lightning import Trainer +from tests import PATH_LEGACY + +LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints') +CHECKPOINT_EXTENSION = ".ckpt" + + +# todo: add more legacy checkpoints - for < v0.8 +@pytest.mark.parametrize( + "pl_version", + [ + # "0.8.1", + "0.8.3", + "0.8.4", + # "0.8.5", # this version has problem with loading on PT<=1.4 as it seems to be archive + # "0.9.0", # this version has problem with loading on PT<=1.4 as it seems to be archive + "0.10.0", + "1.0.0", + "1.0.1", + "1.0.2", + "1.0.3", + "1.0.4", + "1.0.5", + "1.0.6", + "1.0.7", + "1.0.8", + "1.1.0", + "1.1.1", + "1.1.2", + "1.1.3", + "1.1.4", + "1.1.5", + "1.1.6", + "1.1.7", + "1.1.8", + "1.2.0", + "1.2.1", + "1.2.2", + "1.2.3", + "1.2.4", + "1.2.5", + ] +) +def test_resume_legacy_checkpoints(tmpdir, pl_version: str): + path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + + # todo: make this as mock, so it is cleaner... + orig_sys_paths = list(sys.path) + sys.path.insert(0, path_dir) + from zero_training import DummyModel + + path_ckpts = sorted(glob.glob(os.path.join(path_dir, f'*{CHECKPOINT_EXTENSION}'))) + assert path_ckpts, 'No checkpoints found in folder "%s"' % path_dir + path_ckpt = path_ckpts[-1] + + model = DummyModel.load_from_checkpoint(path_ckpt) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=6) + result = trainer.fit(model) + assert result + + # todo + # model = DummyModel() + # trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=path_ckpt) + # result = trainer.fit(model) + # assert result + + sys.path = orig_sys_paths diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py new file mode 100644 index 00000000000000..5e50543d37e5fd --- /dev/null +++ b/tests/checkpointing/test_model_checkpoint.py @@ -0,0 +1,1237 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math +import os +import pickle +import re +from argparse import Namespace +from logging import INFO +from pathlib import Path +from typing import Union +from unittest import mock +from unittest.mock import Mock + +import cloudpickle +import pytest +import torch +import yaml +from omegaconf import Container, OmegaConf +from torch import optim + +import pytorch_lightning as pl +import tests.helpers.utils as tutils +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class LogInTwoMethods(BoringModel): + + def training_step(self, batch, batch_idx): + out = super().training_step(batch, batch_idx) + self.log('early_stop_on', out['loss']) + return out + + def validation_epoch_end(self, outputs): + outs = torch.stack([x['x'] for x in outputs]).mean() + self.log('val_acc', outs) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize( + "validation_step_none,val_dataloaders_none,monitor", + [ + (False, False, 'val_log'), + (False, False, 'train_log_epoch'), + (True, False, 'train_log_epoch'), + (False, True, 'train_log_epoch'), + ], +) +@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True]) +def test_model_checkpoint_score_and_ckpt( + tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool +): + """ + Test that when a model checkpoint is saved, it saves with + the correct score appended to ckpt_path and checkpoint data + """ + max_epochs = 3 + limit_train_batches = 5 + limit_val_batches = 7 + lr = 1e-1 + + class CustomBoringModel(BoringModel): + + def __init__(self): + super().__init__() + self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) + self.val_logs = torch.randn(max_epochs, limit_val_batches) + + def training_step(self, batch, batch_idx): + log_value = self.train_log_epochs[self.current_epoch, batch_idx] + self.log('train_log', log_value, on_epoch=True) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + log_value = self.val_logs[self.current_epoch, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return super().validation_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=lr) + + if reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1) + + return [optimizer], [lr_scheduler] + + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' + checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) + + model = CustomBoringModel() + + if validation_step_none: + model.validation_step = None + if val_dataloaders_none: + model.val_dataloaders = None + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint], + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + max_epochs=max_epochs, + progress_bar_refresh_rate=0, + ) + results = trainer.fit(model) + assert results + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + ckpt_files = list(Path(tmpdir).glob('*.ckpt')) + scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] + lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates + assert len(ckpt_files) == len(scores) == max_epochs + assert len(lr_scheduler_debug) == max_epochs + + for epoch in range(max_epochs): + score = scores[epoch] + expected_score = getattr(model, f'{monitor}s')[epoch].mean().item() + expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt' + assert math.isclose(score, expected_score, rel_tol=1e-4) + + chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) + assert chk['epoch'] == epoch + 1 + assert chk['global_step'] == limit_train_batches * (epoch + 1) + + mc_specific_data = chk['callbacks'][type(checkpoint)] + assert mc_specific_data['dirpath'] == checkpoint.dirpath + assert mc_specific_data['monitor'] == monitor + assert mc_specific_data['current_score'] == score + + if not reduce_lr_on_plateau: + lr_scheduler_specific_data = chk['lr_schedulers'][0] + assert lr_scheduler_specific_data['_step_count'] == epoch + 2 + assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1)) + + assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None) + assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize( + "val_check_interval,reduce_lr_on_plateau", + [ + (0.25, True), + (0.25, False), + (0.33, False), + ], +) +def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau): + """ + Test that when a model checkpoint is saved, it saves with the correct + score appended to ckpt_path and checkpoint data with val_check_interval + """ + max_epochs = 3 + limit_train_batches = 12 + limit_val_batches = 7 + lr = 1e-1 + monitor = 'val_log' + per_epoch_steps = int(limit_train_batches * val_check_interval) + per_epoch_call_count = limit_train_batches // per_epoch_steps + + class CustomBoringModel(BoringModel): + + def __init__(self): + super().__init__() + self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches) + self.val_loop_count = 0 + + def validation_step(self, batch, batch_idx): + log_value = self.val_logs[self.val_loop_count, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return super().validation_step(batch, batch_idx) + + def validation_epoch_end(self, outputs): + self.val_loop_count += 1 + super().validation_epoch_end(outputs) + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=lr) + + if reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1) + + return [optimizer], [lr_scheduler] + + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' + checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) + + model = CustomBoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint], + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + max_epochs=max_epochs, + val_check_interval=val_check_interval, + progress_bar_refresh_rate=0, + num_sanity_val_steps=0, + ) + results = trainer.fit(model) + assert results + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + ckpt_files = list(Path(tmpdir).glob('*.ckpt')) + scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] + lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates + assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs + assert len(lr_scheduler_debug) == max_epochs + + for epoch in range(max_epochs): + for ix in range(per_epoch_call_count): + global_ix = ix + per_epoch_call_count * epoch + score = scores[global_ix] + expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item() + expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt' + assert math.isclose(score, expected_score, rel_tol=1e-4) + + chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) + assert chk['epoch'] == epoch + 1 + assert chk['global_step'] == per_epoch_steps * (global_ix + 1) + + mc_specific_data = chk['callbacks'][type(checkpoint)] + assert mc_specific_data['dirpath'] == checkpoint.dirpath + assert mc_specific_data['monitor'] == monitor + assert mc_specific_data['current_score'] == score + + if not reduce_lr_on_plateau: + lr_scheduler_specific_data = chk['lr_schedulers'][0] + did_update = 1 if ix + 1 == per_epoch_call_count else 0 + assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update + assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update)) + + assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None) + assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None) + + +@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) +def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int): + """Test that dirpath=None in checkpoint callback is valid and that ckpt_path is set correctly""" + tutils.reset_seed() + model = LogInTwoMethods() + + checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}', save_top_k=save_top_k) + max_epochs = 2 + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint], + overfit_batches=0.20, + max_epochs=max_epochs, + ) + trainer.fit(model) + assert (checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints") + + if save_top_k == -1: + ckpt_files = os.listdir(checkpoint.dirpath) + expected_ckpt_files = [f'epoch={i}.ckpt' for i in range(max_epochs)] + assert len(ckpt_files) == len(expected_ckpt_files) == max_epochs + assert set(ckpt_files) == set(expected_ckpt_files) + + +@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) +def test_model_checkpoint_to_yaml(tmpdir, save_top_k: int): + """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ + tutils.reset_seed() + model = LogInTwoMethods() + + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=save_top_k) + + trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2) + trainer.fit(model) + + path_yaml = os.path.join(tmpdir, 'best_k_models.yaml') + checkpoint.to_yaml(path_yaml) + d = yaml.full_load(open(path_yaml, 'r')) + best_k = {k: v for k, v in checkpoint.best_k_models.items()} + assert d == best_k + + +@pytest.mark.parametrize( + "logger_version,expected", + [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")], +) +def test_model_checkpoint_path(tmpdir, logger_version: Union[None, int, str], expected: str): + """Test that "version_" prefix is only added when logger's version is an integer""" + tutils.reset_seed() + model = LogInTwoMethods() + logger = TensorBoardLogger(str(tmpdir), version=logger_version) + + trainer = Trainer( + default_root_dir=tmpdir, + overfit_batches=0.2, + max_epochs=2, + logger=logger, + ) + trainer.fit(model) + + ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name + assert ckpt_version == expected + + +def test_pickling(tmpdir): + ckpt = ModelCheckpoint(dirpath=tmpdir) + + ckpt_pickled = pickle.dumps(ckpt) + ckpt_loaded = pickle.loads(ckpt_pickled) + assert vars(ckpt) == vars(ckpt_loaded) + + ckpt_pickled = cloudpickle.dumps(ckpt) + ckpt_loaded = cloudpickle.loads(ckpt_pickled) + assert vars(ckpt) == vars(ckpt_loaded) + + +class ModelCheckpointTestInvocations(ModelCheckpoint): + # this class has to be defined outside the test function, otherwise we get pickle error + # due to the way ddp process is launched + + def __init__(self, expected_count, *args, **kwargs): + super().__init__(*args, **kwargs) + self.expected_count = expected_count + self.on_save_checkpoint_count = 0 + + def on_train_start(self, trainer, pl_module): + torch.save = Mock(wraps=torch.save) + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + # only rank 0 will call ``torch.save`` + super().on_save_checkpoint(trainer, pl_module, checkpoint) + self.on_save_checkpoint_count += 1 + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + assert self.best_model_path + assert self.best_model_score + assert self.on_save_checkpoint_count == self.expected_count + if trainer.is_global_zero: + assert torch.save.call_count == self.expected_count + else: + assert torch.save.call_count == 0 + + +@RunIf(skip_windows=True) +def test_model_checkpoint_no_extraneous_invocations(tmpdir): + """Test to ensure that the model callback saves the checkpoints only once in distributed mode.""" + model = LogInTwoMethods() + num_epochs = 4 + model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1) + trainer = Trainer( + accelerator="ddp_cpu", + num_processes=2, + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_epochs=num_epochs, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +def test_model_checkpoint_format_checkpoint_name(tmpdir): + # empty filename: + ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {}) + assert ckpt_name == 'epoch=3-step=2' + + ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test') + assert ckpt_name == 'test-epoch=3-step=2' + + # no groups case: + ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test') + assert ckpt_name == 'test-ckpt' + + # no prefix + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03}) + assert ckpt_name == 'epoch=003-acc=0.03' + + # prefix + char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR + ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@' + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test') + assert ckpt_name == 'test@epoch=3,acc=0.03000' + ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org + + # no dirpath set + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {}) + assert ckpt_name == 'epoch=3-step=2.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {}) + assert ckpt_name == 'epoch=5-step=4.ckpt' + + # CWD + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt') + + # with version + ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name') + ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3) + assert ckpt_name == tmpdir / 'name-v3.ckpt' + + # using slashes + ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}') + ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03}) + assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' + + # auto_insert_metric_name=False + ckpt_name = ModelCheckpoint._format_checkpoint_name( + 'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False + ) + assert ckpt_name == 'epoch=003-val_acc=0.03' + + +class ModelCheckpointExtensionTest(ModelCheckpoint): + FILE_EXTENSION = '.tpkc' + + +def test_model_checkpoint_file_extension(tmpdir): + """ + Test ModelCheckpoint with different file extension. + """ + + model = LogInTwoMethods() + model_checkpoint = ModelCheckpointExtensionTest( + monitor='early_stop_on', + dirpath=tmpdir, + save_top_k=1, + save_last=True, + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_steps=1, + logger=False, + ) + trainer.fit(model) + + expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] + assert set(expected) == set(os.listdir(tmpdir)) + + +def test_model_checkpoint_save_last(tmpdir): + """Tests that save_last produces only one last checkpoint.""" + seed_everything() + model = LogInTwoMethods() + epochs = 3 + ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' + model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=-1, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_epochs=epochs, + limit_train_batches=10, + limit_val_batches=10, + logger=False, + ) + trainer.fit(model) + last_filename = model_checkpoint._format_checkpoint_name( + ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {} + ) + last_filename = last_filename + '.ckpt' + assert str(tmpdir / last_filename) == model_checkpoint.last_model_path + assert set(os.listdir(tmpdir)) == set([f"epoch={i}-step={j}.ckpt" + for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename]) + + ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last' + + +def test_invalid_top_k(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative save_top_k argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be None or >= -1'): + ModelCheckpoint(dirpath=tmpdir, save_top_k=-3) + + +def test_none_monitor_top_k(tmpdir): + """ Test that a warning appears for positive top_k with monitor=None. """ + with pytest.raises( + MisconfigurationException, match=r'ModelCheckpoint\(save_top_k=3, monitor=None\) is not a valid*' + ): + ModelCheckpoint(dirpath=tmpdir, save_top_k=3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, save_top_k=None) + ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) + ModelCheckpoint(dirpath=tmpdir, save_top_k=0) + + +def test_none_monitor_save_last(tmpdir): + """ Test that a warning appears for save_last=True with monitor=None. """ + with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'): + ModelCheckpoint(dirpath=tmpdir, save_last=True) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, save_last=None) + ModelCheckpoint(dirpath=tmpdir, save_last=False) + + +def test_invalid_every_n_val_epochs(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + +def test_invalid_every_n_train_steps(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + +def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir): + """ + Test that a MisconfigurationException is raised if both + every_n_val_epochs and every_n_train_steps are enabled together. + """ + with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) + + +def test_none_every_n_train_steps_val_epochs(tmpdir): + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) + assert checkpoint_callback.period == 1 + assert checkpoint_callback._every_n_val_epochs == 1 + assert checkpoint_callback._every_n_train_steps == 0 + + +def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): + """ Test that it is possible to save all checkpoints when monitor=None. """ + seed_everything() + model = LogInTwoMethods() + + epochs = 2 + checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + limit_train_batches=10, + limit_val_batches=10, + max_epochs=epochs, + logger=False, + ) + + with caplog.at_level(INFO): + trainer.fit(model) + assert "will duplicate the last checkpoint saved" in caplog.text + + # these should not be set if monitor is None + assert checkpoint_callback.monitor is None + assert checkpoint_callback.best_model_path == tmpdir / 'epoch=1-step=19.ckpt' + assert checkpoint_callback.last_model_path == tmpdir / 'last.ckpt' + assert checkpoint_callback.best_model_score is None + assert checkpoint_callback.best_k_models == {} + assert checkpoint_callback.kth_best_model_path == '' + + # check that the correct ckpts were created + expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])] + expected.append('last.ckpt') + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("period", list(range(4))) +def test_model_checkpoint_period(tmpdir, period: int): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): + """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """ + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=(2 * every_n_val_epochs), + period=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +def test_ckpt_every_n_train_steps(tmpdir): + """ Tests that the checkpoints are saved every n training steps. """ + + model = LogInTwoMethods() + every_n_train_steps = 16 + max_epochs = 2 + epoch_length = 64 + checkpoint_callback = ModelCheckpoint( + filename="{step}", + every_n_val_epochs=0, + every_n_train_steps=every_n_train_steps, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + progress_bar_refresh_rate=0, + callbacks=[checkpoint_callback], + logger=False, + ) + + trainer.fit(model) + expected = [ + f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) + ] + assert set(os.listdir(tmpdir)) == set(expected) + + +def test_model_checkpoint_topk_zero(tmpdir): + """ Test that no checkpoints are saved when save_top_k=0. """ + model = LogInTwoMethods() + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=2, + logger=False, + ) + trainer.fit(model) + # these should not be set if monitor is None + assert checkpoint_callback.monitor is None + assert checkpoint_callback.best_model_path == '' + assert checkpoint_callback.best_model_score is None + assert checkpoint_callback.best_k_models == {} + assert checkpoint_callback.kth_best_model_path == '' + # check that only the last ckpt was created + assert os.listdir(tmpdir) == ['last.ckpt'] + assert checkpoint_callback.last_model_path == tmpdir / 'last.ckpt' + + +def test_model_checkpoint_topk_all(tmpdir): + """ Test that save_top_k=-1 tracks the best models when monitor key is provided. """ + seed_everything(1000) + epochs = 3 + + model = BoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename="{epoch}", + monitor="epoch", + mode='max', + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + logger=False, + val_check_interval=1.0, + ) + trainer.fit(model) + + assert checkpoint_callback.monitor == 'epoch' + assert checkpoint_callback.best_model_path == tmpdir / "epoch=2.ckpt" + assert checkpoint_callback.best_model_score == epochs - 1 + assert len(os.listdir(tmpdir)) == len(checkpoint_callback.best_k_models) == epochs + assert set(checkpoint_callback.best_k_models.keys()) == set(str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)) + assert checkpoint_callback.kth_best_model_path == tmpdir / 'epoch=0.ckpt' + + +def test_ckpt_metric_names(tmpdir): + model = LogInTwoMethods() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + gradient_clip_val=1.0, + overfit_batches=0.20, + progress_bar_refresh_rate=0, + limit_train_batches=0.01, + limit_val_batches=0.01, + callbacks=[ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename="{val_loss:.2f}")], + ) + + trainer.fit(model) + + # make sure the checkpoint we saved has the metric in the name + ckpts = os.listdir(tmpdir) + ckpts = [x for x in ckpts if "val_loss" in x] + assert len(ckpts) == 1 + val = re.sub("[^0-9.]", "", ckpts[0]) + assert len(val) > 3 + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_default_checkpoint_behavior(tmpdir): + seed_everything(1234) + + model = LogInTwoMethods() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + progress_bar_refresh_rate=0, + limit_train_batches=5, + limit_val_batches=5, + ) + + trainer.fit(model) + results = trainer.test() + + assert len(results) == 1 + assert len(trainer.dev_debugger.checkpoint_callback_history) == 3 + + # make sure the checkpoint we saved has the metric in the name + ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints')) + assert len(ckpts) == 1 + assert ckpts[0] == 'epoch=2-step=14.ckpt' + + +@pytest.mark.parametrize('max_epochs', [1, 2]) +@pytest.mark.parametrize('should_validate', [True, False]) +@pytest.mark.parametrize('save_last', [True, False]) +@pytest.mark.parametrize('verbose', [True, False]) +def test_model_checkpoint_save_last_warning( + tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool +): + """Tests 'Saving latest checkpoint...' log""" + model = LogInTwoMethods() + if not should_validate: + model.validation_step = None + ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[ckpt], + max_epochs=max_epochs, + ) + with caplog.at_level(logging.INFO): + trainer.fit(model) + assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last) + + +def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): + """ Tests that the save_last checkpoint contains the latest information. """ + seed_everything(100) + model = LogInTwoMethods() + num_epochs = 3 + model_checkpoint = ModelCheckpoint( + monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_epochs=num_epochs, + ) + trainer.fit(model) + + path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") + path_last = str(tmpdir / "last.ckpt") + assert path_last == model_checkpoint.last_model_path + assert os.path.isfile(path_last_epoch) + + ckpt_last_epoch = torch.load(path_last_epoch) + ckpt_last = torch.load(path_last) + assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step")) + + ch_type = type(model_checkpoint) + assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type] + + # it is easier to load the model objects than to iterate over the raw dict of tensors + model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch) + model_last = LogInTwoMethods.load_from_checkpoint(model_checkpoint.last_model_path) + for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()): + assert w0.eq(w1).all() + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize('mode', ['min', 'max']) +def test_checkpointing_with_nan_as_first(tmpdir, mode: int): + monitor = [float('nan')] + monitor += [5, 7, 8] if mode == 'max' else [8, 7, 5] + + class CurrentModel(LogInTwoMethods): + + def validation_epoch_end(self, outputs): + val_loss = monitor[self.current_epoch] + self.log('abc', val_loss) + + model = CurrentModel() + + trainer = Trainer( + callbacks=[ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, dirpath=tmpdir)], + default_root_dir=tmpdir, + val_check_interval=1.0, + max_epochs=len(monitor), + ) + trainer.fit(model) + + # check that last one is also the best one + assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1 + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_checkpoint_repeated_strategy(tmpdir): + """ + This test validates that the checkpoint can be called when provided to callbacks list + """ + checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}") + + class ExtendedBoringModel(BoringModel): + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("val_loss", loss) + + model = ExtendedBoringModel() + model.validation_epoch_end = None + trainer = Trainer( + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + callbacks=[checkpoint_callback], + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + assert os.listdir(tmpdir) == ['epoch=00.ckpt'] + + for idx in range(4): + # load from checkpoint + model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path) + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + resume_from_checkpoint=checkpoint_callback.best_model_path, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + trainer.test(model, verbose=False) + assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'} + assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)} + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_checkpoint_repeated_strategy_extended(tmpdir): + """ + This test validates checkpoint can be called several times without + increasing internally its global step if nothing run. + """ + + class ExtendedBoringModel(BoringModel): + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"val_loss": loss} + + def validation_epoch_end(self, *_): + ... + + def assert_trainer_init(trainer): + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == 0 + assert trainer.current_epoch == 0 + + def get_last_checkpoint(ckpt_dir): + last = ckpt_dir.listdir(sort=True)[-1] + return str(last) + + def assert_checkpoint_content(ckpt_dir): + chk = pl_load(get_last_checkpoint(ckpt_dir)) + assert chk["epoch"] == epochs + assert chk["global_step"] == 4 + + def assert_checkpoint_log_dir(idx): + lightning_logs = tmpdir / 'lightning_logs' + actual = [d.basename for d in lightning_logs.listdir(sort=True)] + assert actual == [f'version_{i}' for i in range(idx + 1)] + assert len(ckpt_dir.listdir()) == epochs + + ckpt_dir = tmpdir / 'checkpoints' + checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1) + epochs = 2 + limit_train_batches = 2 + trainer_config = dict( + default_root_dir=tmpdir, + max_epochs=epochs, + limit_train_batches=limit_train_batches, + limit_val_batches=3, + limit_test_batches=4, + callbacks=[checkpoint_cb], + ) + trainer = pl.Trainer(**trainer_config) + assert_trainer_init(trainer) + + model = ExtendedBoringModel() + trainer.fit(model) + assert trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs - 1 + assert_checkpoint_log_dir(0) + assert_checkpoint_content(ckpt_dir) + + trainer.validate(model) + assert trainer.current_epoch == epochs - 1 + + trainer.test(model) + assert trainer.current_epoch == epochs - 1 + + for idx in range(1, 5): + chk = get_last_checkpoint(ckpt_dir) + assert_checkpoint_content(ckpt_dir) + + # load from checkpoint + trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)] + trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk) + assert_trainer_init(trainer) + + model = ExtendedBoringModel() + + trainer.test(model) + assert not trainer.checkpoint_connector.has_trained + # resume_from_checkpoint is resumed when calling `.fit` + assert trainer.global_step == 0 + assert trainer.current_epoch == 0 + + trainer.fit(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + assert_checkpoint_log_dir(idx) + + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + + +def test_configure_model_checkpoint(tmpdir): + """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """ + kwargs = dict(default_root_dir=tmpdir) + callback1 = ModelCheckpoint() + callback2 = ModelCheckpoint() + + # no callbacks + trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs) + assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks) + assert trainer.checkpoint_callback is None + + # default configuration + trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs) + assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1 + assert isinstance(trainer.checkpoint_callback, ModelCheckpoint) + + # custom callback passed to callbacks list, checkpoint_callback=True is ignored + trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] + assert trainer.checkpoint_callback == callback1 + + # multiple checkpoint callbacks + trainer = Trainer(callbacks=[callback1, callback2], **kwargs) + assert trainer.checkpoint_callback == callback1 + assert trainer.checkpoint_callbacks == [callback1, callback2] + + with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): + Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs) + + +def test_val_check_interval_checkpoint_files(tmpdir): + """ Test correct checkpoint naming when validating/checkpointing multiple times per epoch. """ + model = LogInTwoMethods() + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=-1, + monitor="val_acc", + mode="max", + ) + trainer = Trainer( + default_root_dir=tmpdir, + val_check_interval=0.2, + max_epochs=1, + limit_train_batches=10, + callbacks=[model_checkpoint], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + files = {p.basename for p in tmpdir.listdir()} + assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]} + + +def test_current_score(tmpdir): + """ Check that the current_score value is correct and was saved """ + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log("foo", (self.current_epoch + 1) / 10) + return super().training_step(*args) + + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=3, + monitor="foo", + mode="min", + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + limit_train_batches=1, + limit_val_batches=1, + callbacks=[model_checkpoint], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(TestModel()) + assert model_checkpoint.current_score == 0.3 + ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] + ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts] + assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] + + +@pytest.mark.parametrize("mode", ["min", "max"]) +def test_current_score_when_nan(tmpdir, mode: str): + """ Check that ModelCheckpoint handles NaN values correctly """ + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log("foo", float("nan")) + return super().training_step(*args) + + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=1, + monitor="foo", + mode=mode, + ) + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + callbacks=[model_checkpoint], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(TestModel()) + expected = float("inf" if mode == "min" else "-inf") + assert model_checkpoint.best_model_score == expected + assert model_checkpoint.current_score == expected + + +@pytest.mark.parametrize("hparams_type", [dict, Container]) +def test_hparams_type(tmpdir, hparams_type): + + class TestModel(BoringModel): + + def __init__(self, hparams): + super().__init__() + self.save_hyperparameters(hparams) + + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=1, + monitor="foo", + ) + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + callbacks=[model_checkpoint], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + hp = {"test_hp_0": 1, "test_hp_1": 2} + hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp) + model = TestModel(hp) + trainer.fit(model) + ckpt = trainer.checkpoint_connector.dump_checkpoint() + if hparams_type == Container: + assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type) + else: + # make sure it's not AttributeDict + assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type + + +def test_ckpt_version_after_rerun_new_trainer(tmpdir): + """ + Check that previous checkpoints are renamed to have the correct + version suffix when new trainer instances are used + """ + epochs = 2 + for i in range(epochs): + mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}") + trainer = Trainer( + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + default_root_dir=tmpdir, + callbacks=[mc], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(BoringModel()) + + # check best_k_models state + expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"} + assert {Path(f).name for f in mc.best_k_models.keys()} == expected + + # check created ckpts + assert set(f.basename for f in tmpdir.listdir()) == { + "epoch=0.ckpt", + "epoch=1.ckpt", + "epoch=0-v1.ckpt", + "epoch=1-v1.ckpt", + } + + +def test_ckpt_version_after_rerun_same_trainer(tmpdir): + """ + Check that previous checkpoints are renamed to have the correct + version suffix when the same trainer instance is used + """ + mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test") + mc.STARTING_VERSION = 9 + trainer = Trainer( + max_epochs=2, + limit_train_batches=1, + limit_val_batches=1, + default_root_dir=tmpdir, + callbacks=[mc], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(BoringModel()) + trainer.max_epochs = 4 + trainer.fit(BoringModel()) + + ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION) + expected = {'test.ckpt', *[f"test-v{i}.ckpt" for i in ckpt_range]} + # check best_k_models state + assert {Path(f).name for f in mc.best_k_models.keys()} == expected + # check created ckpts + assert set(os.listdir(tmpdir)) == expected + + +def test_model_checkpoint_mode_options(): + with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"): + ModelCheckpoint(mode="unknown_option") diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py new file mode 100644 index 00000000000000..8eabc4640046f3 --- /dev/null +++ b/tests/checkpointing/test_torch_saving.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +def test_model_torch_save(tmpdir): + """Test to ensure torch save does not fail for model and trainer.""" + model = BoringModel() + num_epochs = 1 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=num_epochs, + ) + temp_path = os.path.join(tmpdir, 'temp.pt') + trainer.fit(model) + + # Ensure these do not fail + torch.save(trainer.model, temp_path) + torch.save(trainer, temp_path) + trainer = torch.load(temp_path) + + +@RunIf(skip_windows=True) +def test_model_torch_save_ddp_cpu(tmpdir): + """Test to ensure torch save does not fail for model and trainer using cpu ddp.""" + model = BoringModel() + num_epochs = 1 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=num_epochs, + accelerator="ddp_cpu", + num_processes=2, + logger=False, + ) + temp_path = os.path.join(tmpdir, 'temp.pt') + trainer.fit(model) + + # Ensure these do not fail + torch.save(trainer.model, temp_path) + torch.save(trainer, temp_path) + + +@RunIf(min_gpus=2) +def test_model_torch_save_ddp_cuda(tmpdir): + """Test to ensure torch save does not fail for model and trainer using gpu ddp.""" + model = BoringModel() + num_epochs = 1 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=num_epochs, + accelerator="ddp_spawn", + gpus=2, + ) + temp_path = os.path.join(tmpdir, 'temp.pt') + trainer.fit(model) + + # Ensure these do not fail + torch.save(trainer.model, temp_path) + torch.save(trainer, temp_path) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py new file mode 100644 index 00000000000000..393f01fac0c18d --- /dev/null +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy + +import torch + +import pytorch_lightning as pl +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from tests.helpers import BoringModel + + +def test_finetuning_with_resume_from_checkpoint(tmpdir): + """ + This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test + """ + + seed_everything(3) + + checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) + + class ExtendedBoringModel(BoringModel): + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("val_loss", loss, on_epoch=True, prog_bar=True) + + model = ExtendedBoringModel() + model.validation_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=12, + limit_val_batches=6, + limit_test_batches=12, + callbacks=[checkpoint_callback], + logger=False, + ) + trainer.fit(model) + assert os.listdir(tmpdir) == ['epoch=00.ckpt'] + + best_model_paths = [checkpoint_callback.best_model_path] + results = [] + + for idx in range(3, 6): + # load from checkpoint + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=idx, + limit_train_batches=12, + limit_val_batches=12, + limit_test_batches=12, + resume_from_checkpoint=best_model_paths[-1], + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + trainer.test() + results.append(deepcopy(trainer.callback_metrics)) + best_model_paths.append(trainer.checkpoint_callback.best_model_path) + + for idx in range(len(results) - 1): + assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] + + for idx, best_model_path in enumerate(best_model_paths): + if idx == 0: + assert best_model_path.endswith(f"epoch=0{idx}.ckpt") + else: + assert f"epoch={idx + 1}" in best_model_path diff --git a/tests/collect_env_details.py b/tests/collect_env_details.py index ef34948f8e176e..2b8c4b3fafeed3 100644 --- a/tests/collect_env_details.py +++ b/tests/collect_env_details.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Diagnose your system and show basic information This server mainly to get detail info for better bug reporting. @@ -5,12 +18,11 @@ """ import os +import platform import re import sys -import platform import numpy -import tensorboard import torch import tqdm @@ -61,7 +73,6 @@ def info_packages(): "pyTorch_version": torch.__version__, 'pyTorch_debug': torch.version.debug, 'pytorch-lightning': pytorch_lightning.__version__, - 'tensorboard': tensorboard.__version__, 'tqdm': tqdm.__version__, } diff --git a/tests/conftest.py b/tests/conftest.py index 8eb3444ddaaba4..9bc607e119451f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,36 @@ -from functools import wraps +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import threading +from functools import partial, wraps +from http.server import SimpleHTTPRequestHandler import pytest import torch.multiprocessing as mp +@pytest.fixture(scope="function", autouse=True) +def restore_env_variables(): + """ Ensures that environment variables set during the test do not leak out. """ + env_backup = os.environ.copy() + yield + # restore environment as it was before running the test + os.environ.clear() + os.environ.update(env_backup) + + def pytest_configure(config): config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn") @@ -17,3 +44,38 @@ def pytest_pyfunc_call(pyfuncitem): mp.spawn(wraps, (testfunction, testargs)) return True + + +@pytest.fixture +def tmpdir_server(tmpdir): + if sys.version_info >= (3, 7): + Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir)) + from http.server import ThreadingHTTPServer + else: + # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6 + # so we have to hack it like this + + class Handler(SimpleHTTPRequestHandler): + + def translate_path(self, path): + # get the path from cwd + path = super().translate_path(path) + # get the relative path + relpath = os.path.relpath(path, os.getcwd()) + # return the full path from root_dir + return os.path.join(str(tmpdir), relpath) + + # ThreadingHTTPServer was added in 3.7, so we need to define it ourselves + from http.server import HTTPServer + from socketserver import ThreadingMixIn + + class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): + daemon_threads = True + + with ThreadingHTTPServer(('localhost', 0), Handler) as server: + server_thread = threading.Thread(target=server.serve_forever) + # Exit the server thread when the main thread terminates + server_thread.daemon = True + server_thread.start() + yield server.server_address + server.shutdown() diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py new file mode 100644 index 00000000000000..c8808ec37326c8 --- /dev/null +++ b/tests/core/test_datamodules.py @@ -0,0 +1,499 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pickle +from argparse import ArgumentParser +from typing import Any, Dict +from unittest import mock +from unittest.mock import PropertyMock + +import pytest +import torch + +from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.model_helpers import is_overridden +from tests.helpers import BoringDataModule, BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel +from tests.helpers.utils import reset_seed + + +@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) +@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) +def test_can_prepare_data(local_rank, node_rank): + + dm = BoringDataModule() + trainer = Trainer() + trainer.datamodule = dm + + # 1 no DM + # prepare_data_per_node = True + # local rank = 0 (True) + trainer.prepare_data_per_node = True + + local_rank.return_value = 0 + assert trainer.local_rank == 0 + assert trainer.data_connector.can_prepare_data() + + # local rank = 1 (False) + local_rank.return_value = 1 + assert trainer.local_rank == 1 + assert not trainer.data_connector.can_prepare_data() + + # prepare_data_per_node = False (prepare across all nodes) + # global rank = 0 (True) + trainer.prepare_data_per_node = False + node_rank.return_value = 0 + local_rank.return_value = 0 + assert trainer.data_connector.can_prepare_data() + + # global rank = 1 (False) + node_rank.return_value = 1 + local_rank.return_value = 0 + assert not trainer.data_connector.can_prepare_data() + node_rank.return_value = 0 + local_rank.return_value = 1 + assert not trainer.data_connector.can_prepare_data() + + # 2 dm + # prepar per node = True + # local rank = 0 (True) + trainer.prepare_data_per_node = True + local_rank.return_value = 0 + + # is_overridden prepare data = True + # has been called + # False + dm._has_prepared_data = True + assert not trainer.data_connector.can_prepare_data() + + # has not been called + # True + dm._has_prepared_data = False + assert trainer.data_connector.can_prepare_data() + + # is_overridden prepare data = False + # True + dm.prepare_data = None + assert trainer.data_connector.can_prepare_data() + + +def test_hooks_no_recursion_error(tmpdir): + # hooks were appended in cascade every tine a new data module was instantiated leading to a recursion error. + # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652 + class DummyDM(LightningDataModule): + + def setup(self, *args, **kwargs): + pass + + def prepare_data(self, *args, **kwargs): + pass + + for i in range(1005): + dm = DummyDM() + dm.setup() + dm.prepare_data() + + +def test_helper_boringdatamodule(tmpdir): + dm = BoringDataModule() + dm.prepare_data() + dm.setup() + + +def test_helper_boringdatamodule_with_verbose_setup(tmpdir): + dm = BoringDataModule() + dm.prepare_data() + dm.setup('fit') + dm.setup('test') + + +def test_data_hooks_called(tmpdir): + dm = BoringDataModule() + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.prepare_data() + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.setup() + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.teardown() + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict + assert dm.has_teardown_fit + assert dm.has_teardown_test + assert dm.has_teardown_validate + assert not dm.has_teardown_predict + + +@pytest.mark.parametrize("use_kwarg", (False, True)) +def test_data_hooks_called_verbose(tmpdir, use_kwarg): + dm = BoringDataModule() + dm.prepare_data() + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.setup(stage='fit') if use_kwarg else dm.setup('fit') + assert dm.has_setup_fit + assert not dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='validate') if use_kwarg else dm.setup('validate') + assert dm.has_setup_fit + assert dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='test') if use_kwarg else dm.setup('test') + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='predict') if use_kwarg else dm.setup('predict') + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict + + dm.teardown(stage='fit') if use_kwarg else dm.teardown('fit') + assert dm.has_teardown_fit + assert not dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='validate') if use_kwarg else dm.teardown('validate') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='test') if use_kwarg else dm.teardown('test') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='predict') if use_kwarg else dm.teardown('predict') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert dm.has_teardown_predict + + +def test_dm_add_argparse_args(tmpdir): + parser = ArgumentParser() + parser = BoringDataModule.add_argparse_args(parser) + args = parser.parse_args(['--data_dir', str(tmpdir)]) + assert args.data_dir == str(tmpdir) + + +def test_dm_init_from_argparse_args(tmpdir): + parser = ArgumentParser() + parser = BoringDataModule.add_argparse_args(parser) + args = parser.parse_args(['--data_dir', str(tmpdir)]) + dm = BoringDataModule.from_argparse_args(args) + dm.prepare_data() + dm.setup() + assert dm.data_dir == args.data_dir == str(tmpdir) + + +def test_dm_pickle_after_init(tmpdir): + dm = BoringDataModule() + pickle.dumps(dm) + + +def test_train_loop_only(tmpdir): + reset_seed() + + dm = ClassifDataModule() + model = ClassificationModel() + + model.validation_step = None + model.validation_step_end = None + model.validation_epoch_end = None + model.test_step = None + model.test_step_end = None + model.test_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + ) + + # fit model + result = trainer.fit(model, datamodule=dm) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert result + assert trainer.callback_metrics['train_loss'] < 1.0 + + +def test_train_val_loop_only(tmpdir): + reset_seed() + + dm = ClassifDataModule() + model = ClassificationModel() + + model.validation_step = None + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + ) + + # fit model + result = trainer.fit(model, datamodule=dm) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert result + assert trainer.callback_metrics['train_loss'] < 1.0 + + +def test_dm_checkpoint_save(tmpdir): + + class CustomBoringModel(BoringModel): + + def validation_step(self, batch, batch_idx): + out = super().validation_step(batch, batch_idx) + self.log('early_stop_on', out['x']) + return out + + class CustomBoringDataModule(BoringDataModule): + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + checkpoint[self.__class__.__name__] = self.__class__.__name__ + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.checkpoint_state = checkpoint.get(self.__class__.__name__) + + reset_seed() + dm = CustomBoringDataModule() + model = CustomBoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=1, + weights_summary=None, + callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on')], + ) + + # fit model + trainer.fit(model, dm) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0] + checkpoint = torch.load(checkpoint_path) + assert dm.__class__.__name__ in checkpoint + assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ + + +def test_full_loop(tmpdir): + reset_seed() + + dm = ClassifDataModule() + model = ClassificationModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + deterministic=True, + ) + + # fit model + result = trainer.fit(model, dm) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert dm.trainer is not None + assert result + + # validate + result = trainer.validate(datamodule=dm) + assert dm.trainer is not None + assert result[0]['val_acc'] > 0.7 + + # test + result = trainer.test(datamodule=dm) + assert dm.trainer is not None + assert result[0]['test_acc'] > 0.6 + + +@RunIf(min_gpus=1) +@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) +def test_dm_apply_batch_transfer_handler(get_module_mock): + expected_device = torch.device('cuda', 0) + + class CustomBatch: + + def __init__(self, data): + self.samples = data[0] + self.targets = data[1] + + class CurrentTestDM(LightningDataModule): + rank = 0 + transfer_batch_to_device_hook_rank = None + on_before_batch_transfer_hook_rank = None + on_after_batch_transfer_hook_rank = None + + def on_before_batch_transfer(self, batch, dataloader_idx): + self.on_before_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.samples += 1 + return batch + + def on_after_batch_transfer(self, batch, dataloader_idx): + assert batch.samples.device == batch.targets.device == expected_device + self.on_after_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.targets *= 2 + return batch + + def transfer_batch_to_device(self, batch, device): + self.transfer_batch_to_device_hook_rank = self.rank + self.rank += 1 + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + return batch + + dm = CurrentTestDM() + model = BoringModel() + + batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) + + trainer = Trainer(gpus=1) + # running .fit() would require us to implement custom data loaders, we mock the model reference instead + get_module_mock.return_value = model + if is_overridden('transfer_batch_to_device', dm): + model.transfer_batch_to_device = dm.transfer_batch_to_device + + model.on_before_batch_transfer = dm.on_before_batch_transfer + model.transfer_batch_to_device = dm.transfer_batch_to_device + model.on_after_batch_transfer = dm.on_after_batch_transfer + + batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device) + + assert dm.on_before_batch_transfer_hook_rank == 0 + assert dm.transfer_batch_to_device_hook_rank == 1 + assert dm.on_after_batch_transfer_hook_rank == 2 + assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32)) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) + + +def test_dm_reload_dataloaders_every_epoch(tmpdir): + """Test datamodule, where trainer argument + reload_dataloaders_every_epoch is set to True/False""" + + class CustomBoringDataModule(BoringDataModule): + + def __init__(self): + super().__init__() + self._epochs_called_for = [] + + def train_dataloader(self): + assert self.trainer.current_epoch not in self._epochs_called_for + self._epochs_called_for.append(self.trainer.current_epoch) + return super().train_dataloader() + + dm = CustomBoringDataModule() + model = BoringModel() + + model.validation_step = None + model.validation_step_end = None + model.validation_epoch_end = None + model.test_step = None + model.test_step_end = None + model.test_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=0.01, + reload_dataloaders_every_epoch=True, + ) + trainer.fit(model, dm) + + +class DummyDS(torch.utils.data.Dataset): + + def __getitem__(self, index): + return 1 + + def __len__(self): + return 100 + + +def test_dm_init_from_datasets(tmpdir): + + train_ds = DummyDS() + valid_ds = DummyDS() + test_ds = DummyDS() + + valid_dss = [DummyDS(), DummyDS()] + test_dss = [DummyDS(), DummyDS()] + + dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0) + assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4)) + assert dm.val_dataloader() is None + assert dm.test_dataloader() is None + + dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0) + assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4)) + assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4)) + + dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0) + assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4)) + assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4)) + assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4)) + assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4)) diff --git a/tests/core/test_decorators.py b/tests/core/test_decorators.py new file mode 100644 index 00000000000000..7d423f34cda585 --- /dev/null +++ b/tests/core/test_decorators.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from pytorch_lightning.core.decorators import auto_move_data +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize(['src_device', 'dest_device'], [ + pytest.param(torch.device('cpu'), torch.device('cpu')), + pytest.param(torch.device('cpu', 0), torch.device('cuda', 0)), + pytest.param(torch.device('cuda', 0), torch.device('cpu')), + pytest.param(torch.device('cuda', 0), torch.device('cuda', 0)), +]) +def test_auto_move_data(src_device, dest_device): + """ Test that the decorator moves the data to the device the model is on. """ + + class CurrentModel(BoringModel): + + @auto_move_data + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + model = CurrentModel() + model = model.to(dest_device) + model.prepare_data() + loader = model.train_dataloader() + x = next(iter(loader)) + + # test that data on source device gets moved to destination device + x = x.to(src_device) + assert model(x).device == dest_device, "Automoving data to same device as model failed" diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py new file mode 100644 index 00000000000000..191da0a1400c70 --- /dev/null +++ b/tests/core/test_hooks.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +def test_on_val_epoch_end_outputs(tmpdir): + + class TestModel(BoringModel): + + def on_validation_epoch_end(self, outputs): + if trainer.running_sanity_check: + assert len(outputs[0]) == trainer.num_sanity_val_batches[0] + else: + assert len(outputs[0]) == trainer.num_val_batches[0] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_on_test_epoch_end_outputs(tmpdir): + + class TestModel(BoringModel): + + def on_test_epoch_end(self, outputs): + assert len(outputs[0]) == trainer.num_test_batches[0] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + weights_summary=None, + ) + + trainer.test(model) diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py new file mode 100644 index 00000000000000..8270867ae863c0 --- /dev/null +++ b/tests/core/test_lightning_module.py @@ -0,0 +1,360 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +import pytest +from torch import nn +from torch.optim import Adam, SGD + +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel + + +def test_property_current_epoch(): + """ Test that the current_epoch in LightningModule is accessible via the Trainer. """ + model = BoringModel() + assert model.current_epoch == 0 + + trainer = Mock(current_epoch=123) + model.trainer = trainer + assert model.current_epoch == 123 + + +def test_property_global_step(): + """ Test that the global_step in LightningModule is accessible via the Trainer. """ + model = BoringModel() + assert model.global_step == 0 + + trainer = Mock(global_step=123) + model.trainer = trainer + assert model.global_step == 123 + + +def test_property_global_rank(): + """ Test that the global rank in LightningModule is accessible via the Trainer. """ + model = BoringModel() + assert model.global_rank == 0 + + trainer = Mock(global_rank=123) + model.trainer = trainer + assert model.global_rank == 123 + + +def test_property_local_rank(): + """ Test that the local rank in LightningModule is accessible via the Trainer. """ + model = BoringModel() + assert model.local_rank == 0 + + trainer = Mock(local_rank=123) + model.trainer = trainer + assert model.local_rank == 123 + + +def test_property_logger(tmpdir): + """ Test that the logger in LightningModule is accessible via the Trainer. """ + model = BoringModel() + assert model.logger is None + + logger = TensorBoardLogger(tmpdir) + trainer = Mock(logger=logger) + model.trainer = trainer + assert model.logger == logger + + +def test_automatic_optimization_raises(tmpdir): + + class TestModel(BoringModel): + + def optimizer_step(self, *_, **__): + pass + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + accumulate_grad_batches=2, + ) + + with pytest.raises( + MisconfigurationException, match='overriding .* optimizer_step .* `accumulate_grad_batches` .* should be 1' + ): + trainer.fit(model) + + +def test_params_groups_and_state_are_accessible(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def configure_optimizers(self): + optimizer = SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step( + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + on_tpu=False, + using_native_amp=False, + using_lbfgs=False + ): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * 0.01 + + optimizer.step(closure=optimizer_closure) + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + limit_val_batches=1, + accumulate_grad_batches=1, + ) + + trainer.fit(model) + + +def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + return super().training_step(batch, batch_idx) + + def __init__(self): + super().__init__() + self.layer_1 = nn.Sequential( + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + ) + + self.layer_2 = nn.Sequential( + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 2), + ) + + # set some weights to False to check untoggle works as expected. + self.layer_1[2].weight.requires_grad = False + self.layer_1[4].weight.requires_grad = False + + self.layer_2[1].weight.requires_grad = False + self.layer_2[3].weight.requires_grad = False + + def configure_optimizers(self): + optimizer = SGD(self.layer_1.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step( + self, + current_epoch, + batch_nb, + optimizer, + optimizer_idx, + closure, + on_tpu=False, + using_native_amp=False, + using_lbfgs=False + ): + if optimizer_idx == 0: + assert self.layer_1[0].weight.requires_grad is True + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is False + + if optimizer_idx == 1: + assert self.layer_1[0].weight.requires_grad is False + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is True + + optimizer.step(closure=closure) + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + accumulate_grad_batches=1, + limit_val_batches=0, + ) + + results = trainer.fit(model) + assert results + + +def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer_1 = nn.Sequential( + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + ) + + self.layer_2 = nn.Sequential( + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 2), + ) + + self.layer_3 = nn.Sequential( + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 2), + ) + + # set some weights to False to check untoggle works as expected. + self.layer_1[2].weight.requires_grad = False + self.layer_1[4].weight.requires_grad = False + + self.layer_2[1].weight.requires_grad = False + self.layer_2[3].weight.requires_grad = False + + self.layer_3[1].weight.requires_grad = False + self.layer_3[5].weight.requires_grad = False + + def optimizer_step( + self, + current_epoch, + batch_nb, + optimizer, + optimizer_idx, + closure, + on_tpu=False, + using_native_amp=False, + using_lbfgs=False + ): + if optimizer_idx == 0: + assert self.layer_1[0].weight.requires_grad is True + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is True + + assert self.layer_3[1].weight.requires_grad is False + assert self.layer_3[3].weight.requires_grad is False + assert self.layer_3[5].weight.requires_grad is False + + if optimizer_idx == 1: + assert self.layer_1[0].weight.requires_grad is False + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is True + + assert self.layer_3[1].weight.requires_grad is False + assert self.layer_3[3].weight.requires_grad is True + assert self.layer_3[5].weight.requires_grad is False + + if optimizer_idx == 2: + assert self.layer_1[0].weight.requires_grad is True + assert self.layer_1[2].weight.requires_grad is False + assert self.layer_1[4].weight.requires_grad is False + + assert self.layer_2[1].weight.requires_grad is False + assert self.layer_2[3].weight.requires_grad is False + assert self.layer_2[5].weight.requires_grad is False + + assert self.layer_3[1].weight.requires_grad is False + assert self.layer_3[3].weight.requires_grad is True + assert self.layer_3[5].weight.requires_grad is False + + optimizer.step(closure=closure) + + def training_step(self, batch, batch_idx, optimizer_idx=None): + loss = super().training_step(batch, batch_idx) + # make sure the model is untoggle when returning None + return loss if batch_idx % 2 == 0 else None + + @staticmethod + def combine_generators(gen_1, gen_2): + for p in gen_1: + yield p + for p in gen_2: + yield p + + def configure_optimizers(self): + optimizer_1 = SGD(self.combine_generators( + self.layer_1.parameters(), + self.layer_2.parameters(), + ), lr=0.1) + optimizer_2 = Adam(self.combine_generators( + self.layer_2.parameters(), + self.layer_3.parameters(), + ), lr=0.1) + optimizer_3 = SGD(self.combine_generators( + self.layer_3.parameters(), + self.layer_1.parameters(), + ), lr=0.1) + return [optimizer_1, optimizer_2, optimizer_3] + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + accumulate_grad_batches=1, + ) + + trainer.fit(model) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py new file mode 100644 index 00000000000000..8858129b221f91 --- /dev/null +++ b/tests/core/test_lightning_optimizer.py @@ -0,0 +1,383 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +from typing import Any +from unittest.mock import DEFAULT, patch + +import torch +from torch.optim import Adam, Optimizer, SGD + +from pytorch_lightning import Trainer +from pytorch_lightning.core.optimizer import LightningOptimizer +from tests.helpers.boring_model import BoringModel + + +def test_lightning_optimizer(tmpdir): + """ + Test that optimizer are correctly wrapped by our LightningOptimizer + """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + groups = "{'dampening': 0, 'initial_lr': 0.1, 'lr': 0.01, 'momentum': 0, 'nesterov': False, 'weight_decay': 0}" + expected = f"LightningSGD(groups=[{groups}])" + assert trainer._lightning_optimizers[0].__repr__() == expected + + +def test_lightning_optimizer_from_user(tmpdir): + """ + Test that the user can use our LightningOptimizer. Not recommended. + """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.1) + optimizer = LightningOptimizer(optimizer) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + groups = "{'amsgrad': False, 'betas': (0.9, 0.999), 'eps': 1e-08, 'initial_lr': 0.1, 'lr': 0.01, 'weight_decay': 0}" + expected = f"LightningAdam(groups=[{groups}])" + assert trainer._lightning_optimizers[0].__repr__() == expected + + +def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmpdir): + """ + Test that the user can use our LightningOptimizer. Not recommended. + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + opt_1, opt_2 = self.optimizers() + + assert isinstance(opt_1, LightningOptimizer) + assert isinstance(opt_2, LightningOptimizer) + + def closure(opt): + output = self.layer(batch) + loss = self.loss(batch, output) + opt.zero_grad() + self.manual_backward(loss) + + if batch_idx % 2 == 0: + closure(opt_1) + opt_1.step() + + closure(opt_2) + opt_2.step() + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + model.training_step_end = None + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=8, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + accumulate_grad_batches=999, # does not do anything if manual optimization + ) + + with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, \ + patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam: + trainer.fit(model) + + assert sgd["step"].call_count == 4 + assert adam["step"].call_count == 8 + + assert sgd["zero_grad"].call_count == 4 + assert adam["zero_grad"].call_count == 8 + + +def test_state(tmpdir): + model = torch.nn.Linear(3, 4) + optimizer = torch.optim.Adam(model.parameters()) + lightning_optimizer = LightningOptimizer(optimizer) + + # test state + assert optimizer.state == lightning_optimizer.state + lightning_optimizer.state = optimizer.state + assert optimizer.state == lightning_optimizer.state + + # test param_groups + assert optimizer.param_groups == lightning_optimizer.param_groups + lightning_optimizer.param_groups = optimizer.param_groups + assert optimizer.param_groups == lightning_optimizer.param_groups + + # test defaults + assert optimizer.defaults == lightning_optimizer.defaults + lightning_optimizer.defaults = optimizer.defaults + assert optimizer.defaults == lightning_optimizer.defaults + + assert isinstance(lightning_optimizer, LightningOptimizer) + assert isinstance(lightning_optimizer, Adam) + assert isinstance(lightning_optimizer, Optimizer) + + lightning_dict = {} + special_attrs = [ + "_accumulate_grad_batches", + "_optimizer", + "_optimizer_idx", + "_support_closure", + "_trainer", + "__getstate__", + "__setstate__", + "state_dict", + "load_state_dict", + "zero_grad", + "__setstate__", + "add_param_group", + "_total_optimizer_step_calls", + ] + + for k, v in lightning_optimizer.__dict__.items(): + if k not in special_attrs: + lightning_dict[k] = v + + assert lightning_dict == optimizer.__dict__ + assert optimizer.state_dict() == lightning_optimizer.state_dict() + assert optimizer.state == lightning_optimizer.state + + +def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir): + """ + Test overriding zero_grad works in automatic_optimization + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs): + ... + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + if isinstance(optimizer, SGD) and batch_idx % 2 == 0: + optimizer.zero_grad() + if isinstance(optimizer, Adam) and batch_idx % 5 == 0: + optimizer.zero_grad() + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=20, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + + with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: + trainer.fit(model) + + assert adam_zero_grad.call_count == 4 + assert sgd_zero_grad.call_count == 10 + + +def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir): + """ + Test overriding step works in automatic_optimization + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx=None): + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs): + ... + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): + assert optimizer_closure.__name__ == "train_step_and_backward_closure" + # not passing the closure to the optimizer because step is mocked + # zero_grad is called inside the closure + if isinstance(optimizer, SGD) and batch_idx % 2 == 0: + optimizer_closure() + optimizer.step() + if isinstance(optimizer, Adam) and batch_idx % 4 == 0: + optimizer_closure() + optimizer.step() # not passing the closure here because it's a mock + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.Adam(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + return [optimizer_1, optimizer_2], [lr_scheduler] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=8, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + + with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, \ + patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam: + trainer.fit(model) + + assert sgd["step"].call_count == 4 + assert adam["step"].call_count == 2 + + assert sgd["zero_grad"].call_count == 4 + assert adam["zero_grad"].call_count == 2 + + +def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir): + """ + Test zero_grad is called the same number of times as LBFGS requires + for reevaluation of the loss in automatic_optimization. + """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + return torch.optim.LBFGS(self.parameters()) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + + with patch("torch.optim.LBFGS.zero_grad") as zero_grad: + trainer.fit(model) + + lbfgs = model.optimizers() + max_iter = lbfgs.param_groups[0]["max_iter"] + assert zero_grad.call_count == max_iter + + +class OptimizerWithHooks(Optimizer): + + def __init__(self, model): + self._fwd_handles = [] + self._bwd_handles = [] + self.params = [] + for _, mod in model.named_modules(): + mod_class = mod.__class__.__name__ + if mod_class != 'Linear': + continue + + handle = mod.register_forward_pre_hook(self._save_input) # save the inputs + self._fwd_handles.append(handle) # collect forward-save-input hooks in list + handle = mod.register_backward_hook(self._save_grad_output) # save the gradients + self._bwd_handles.append(handle) # collect backward-save-grad hook in list + + # save the parameters + params = [mod.weight] + if mod.bias is not None: + params.append(mod.bias) + + # save a param_group for each module + d = {'params': params, 'mod': mod, 'layer_type': mod_class} + self.params.append(d) + + super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01}) + + def _save_input(self, mod, i): + """Saves input of layer""" + if mod.training: + self.state[mod]['x'] = i[0] + + def _save_grad_output(self, mod, _, grad_output): + """ + Saves grad on output of layer to + grad is scaled with batch_size since gradient is spread over samples in mini batch + """ + batch_size = grad_output[0].shape[0] + if mod.training: + self.state[mod]['grad'] = grad_output[0] * batch_size + + def step(self, closure=None): + closure() + for group in self.param_groups: + _ = self.state[group['mod']]['x'] + _ = self.state[group['mod']]['grad'] + return True + + +def test_lightning_optimizer_keeps_hooks(tmpdir): + + class TestModel(BoringModel): + count_on_train_batch_start = 0 + count_on_train_batch_end = 0 + + def configure_optimizers(self): + return OptimizerWithHooks(self) + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.count_on_train_batch_start += 1 + optimizer = self.optimizers(use_pl_optimizer=False) + assert len(optimizer._fwd_handles) == 1 + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + self.count_on_train_batch_end += 1 + del self.trainer._lightning_optimizers + gc.collect() # not necessary, just in case + + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1) + model = TestModel() + trainer.fit(model) + assert model.count_on_train_batch_start == 4 + assert model.count_on_train_batch_end == 4 diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py new file mode 100644 index 00000000000000..3088743f714889 --- /dev/null +++ b/tests/core/test_memory.py @@ -0,0 +1,304 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +import torch.nn as nn + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.advanced_models import ParityModuleRNN +from tests.helpers.runif import RunIf + + +class EmptyModule(LightningModule): + """ A module that has no layers """ + + def __init__(self): + super().__init__() + self.parameter = torch.rand(3, 3, requires_grad=True) + self.example_input_array = torch.zeros(1, 2, 3, 4, 5) + + def forward(self, *args, **kwargs): + return {'loss': self.parameter.sum()} + + +class PreCalculatedModel(BoringModel): + """ A model with precalculated total params size in MB for FP16 and FP32. """ + + def __init__(self, precision: int = 32): + super().__init__() + # 32K params + self.layer = nn.Linear(32, 1000, bias=False) + # 218K params + self.layer1 = nn.Linear(1000, 218, bias=False) + # calculate model size based on precision. + self.pre_calculated_model_size = 1.0 / (32 / precision) + + def forward(self, x): + x = self.layer(x) + return self.layer1(x) + + +class UnorderedModel(LightningModule): + """ A model in which the layers not defined in order of execution """ + + def __init__(self): + super().__init__() + # note: the definition order is intentionally scrambled for this test + self.layer2 = nn.Linear(10, 2) + self.combine = nn.Linear(7, 9) + self.layer1 = nn.Linear(3, 5) + self.relu = nn.ReLU() + # this layer is unused, therefore input-/output shapes are unknown + self.unused = nn.Conv2d(1, 1, 1) + + self.example_input_array = (torch.rand(2, 3), torch.rand(2, 10)) + + def forward(self, x, y): + out1 = self.layer1(x) + out2 = self.layer2(y) + out = self.relu(torch.cat((out1, out2), 1)) + out = self.combine(out) + return out + + +class MixedDtypeModel(LightningModule): + """ The parameters and inputs of this model have different dtypes. """ + + def __init__(self): + super().__init__() + self.embed = nn.Embedding(10, 20) # expects dtype long as input + self.reduce = nn.Linear(20, 1) # dtype: float + self.example_input_array = torch.tensor([[0, 2, 1], [3, 5, 3]]) # dtype: long + + def forward(self, x): + return self.reduce(self.embed(x)) + + +class PartialScriptModel(LightningModule): + """ A model which contains scripted layers. """ + + def __init__(self): + super().__init__() + self.layer1 = torch.jit.script(nn.Linear(5, 3)) + self.layer2 = nn.Linear(3, 2) + self.example_input_array = torch.rand(2, 5) + + def forward(self, x): + return self.layer2(self.layer1(x)) + + +def test_invalid_weights_summmary(): + """ Test that invalid value for weights_summary raises an error. """ + with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'): + UnorderedModel().summarize(mode='temp') + + with pytest.raises(MisconfigurationException, match='`weights_summary` can be None, .* got temp'): + Trainer(weights_summary='temp') + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_empty_model_summary_shapes(mode: ModelSummary): + """ Test that the summary works for models that have no submodules. """ + model = EmptyModule() + summary = model.summarize(mode=mode) + assert summary.in_sizes == [] + assert summary.out_sizes == [] + assert summary.param_nums == [] + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize(['device'], [ + pytest.param(torch.device('cpu')), + pytest.param(torch.device('cuda', 0)), + pytest.param(torch.device('cuda', 0)), +]) +def test_linear_model_summary_shapes(device, mode): + """ Test that the model summary correctly computes the input- and output shapes. """ + model = UnorderedModel().to(device) + model.train() + summary = model.summarize(mode=mode) + assert summary.in_sizes == [ + [2, 10], # layer 2 + [2, 7], # combine + [2, 3], # layer 1 + [2, 7], # relu + UNKNOWN_SIZE, + ] + assert summary.out_sizes == [ + [2, 2], # layer 2 + [2, 9], # combine + [2, 5], # layer 1 + [2, 7], # relu + UNKNOWN_SIZE, + ] + assert model.training + assert model.device == device + + +def test_mixed_dtype_model_summary(): + """ Test that the model summary works with models that have mixed input- and parameter dtypes. """ + model = MixedDtypeModel() + summary = model.summarize() + assert summary.in_sizes == [ + [2, 3], # embed + [2, 3, 20], # reduce + ] + assert summary.out_sizes == [ + [2, 3, 20], # embed + [2, 3, 1], # reduce + ] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_hooks_removed_after_summarize(mode): + """ Test that all hooks were properly removed after summary, even ones that were not run. """ + model = UnorderedModel() + summary = ModelSummary(model, mode=mode) + # hooks should be removed + for _, layer in summary.summarize().items(): + handle = layer._hook_handle + assert handle.id not in handle.hooks_dict_ref() + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_rnn_summary_shapes(mode): + """ Test that the model summary works for RNNs. """ + model = ParityModuleRNN() + + b = 3 + t = 5 + i = model.rnn.input_size + h = model.rnn.hidden_size + o = model.linear_out.out_features + + model.example_input_array = torch.zeros(b, t, 10) + + summary = model.summarize(mode=mode) + assert summary.in_sizes == [ + [b, t, i], # rnn + [b, t, h], # linear + ] + assert summary.out_sizes == [ + [[b, t, h], [[1, b, h], [1, b, h]]], # rnn + [b, t, o] # linear + ] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_summary_parameter_count(mode): + """ Test that the summary counts the number of parameters in every submodule. """ + model = UnorderedModel() + summary = model.summarize(mode=mode) + assert summary.param_nums == [ + model.layer2.weight.numel() + model.layer2.bias.numel(), + model.combine.weight.numel() + model.combine.bias.numel(), + model.layer1.weight.numel() + model.layer1.bias.numel(), + 0, # ReLU + model.unused.weight.numel() + model.unused.bias.numel(), + ] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_summary_layer_types(mode): + """ Test that the summary displays the layer names correctly. """ + model = UnorderedModel() + summary = model.summarize(mode=mode) + assert summary.layer_types == [ + 'Linear', + 'Linear', + 'Linear', + 'ReLU', + 'Conv2d', + ] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_summary_with_scripted_modules(mode): + model = PartialScriptModel() + summary = model.summarize(mode=mode) + assert summary.layer_types == ["RecursiveScriptModule", "Linear"] + assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]] + assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize(['example_input', 'expected_size'], [ + pytest.param([], UNKNOWN_SIZE), + pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3), + pytest.param(torch.tensor(0), UNKNOWN_SIZE), + pytest.param(dict(tensor=torch.zeros(1, 2, 3)), UNKNOWN_SIZE), + pytest.param(torch.zeros(2, 3, 4), [2, 3, 4]), + pytest.param([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]), + pytest.param((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]), +]) +def test_example_input_array_types(example_input, expected_size, mode): + """ Test the types of example inputs supported for display in the summary. """ + + class DummyModule(nn.Module): + + def forward(self, *args, **kwargs): + return None + + class DummyLightningModule(LightningModule): + + def __init__(self): + super().__init__() + self.layer = DummyModule() + + # this LightningModule and submodule accept any type of input + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + model = DummyLightningModule() + model.example_input_array = example_input + summary = model.summarize(mode=mode) + assert summary.in_sizes == [expected_size] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_model_size(mode): + """ Test model size is calculated correctly. """ + model = PreCalculatedModel() + summary = model.summarize(mode=mode) + assert model.pre_calculated_model_size == summary.model_size + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_empty_model_size(mode): + """ Test empty model size is zero. """ + model = EmptyModule() + summary = model.summarize(mode=mode) + assert 0.0 == summary.model_size + + +@RunIf(min_gpus=1, amp_native=True) +def test_model_size_precision(tmpdir): + """ Test model size for half and full precision. """ + model = PreCalculatedModel() + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + max_steps=1, + max_epochs=1, + precision=32, + ) + trainer.fit(model) + summary = model.summarize() + assert model.pre_calculated_model_size == summary.model_size diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py new file mode 100644 index 00000000000000..0b797dff0e42f0 --- /dev/null +++ b/tests/core/test_metric_result_integration.py @@ -0,0 +1,140 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torchmetrics import Metric + +import tests.helpers.utils as tutils +from pytorch_lightning.core.step_result import Result +from tests.helpers.runif import RunIf + + +class DummyMetric(Metric): + + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0), dist_reduce_fx="sum") + + def update(self, x): + self.x += x + + def compute(self): + return self.x + + +def _setup_ddp(rank, worldsize): + import os + + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def _ddp_test_fn(rank, worldsize): + _setup_ddp(rank, worldsize) + torch.tensor([1.0]) + + metric_a = DummyMetric() + metric_b = DummyMetric() + metric_c = DummyMetric() + + # dist_sync_on_step is False by default + result = Result() + + for epoch in range(3): + cumulative_sum = 0 + + for i in range(5): + metric_a(i) + metric_b(i) + metric_c(i) + + cumulative_sum += i + + result.log('a', metric_a, on_step=True, on_epoch=True) + result.log('b', metric_b, on_step=False, on_epoch=True) + result.log('c', metric_c, on_step=True, on_epoch=False) + + batch_log = result.get_batch_log_metrics() + batch_expected = {"a_step": i, "a": i, "c": i} + assert set(batch_log.keys()) == set(batch_expected.keys()) + for k in batch_expected.keys(): + assert batch_expected[k] == batch_log[k] + + epoch_log = result.get_epoch_log_metrics() + + # assert metric state reset to default values + assert metric_a.x == metric_a._defaults['x'] + assert metric_b.x == metric_b._defaults['x'] + assert metric_c.x == metric_c._defaults['x'] + + epoch_expected = {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize} + + assert set(epoch_log.keys()) == set(epoch_expected.keys()) + for k in epoch_expected.keys(): + assert epoch_expected[k] == epoch_log[k] + + +@RunIf(skip_windows=True) +def test_result_reduce_ddp(): + """Make sure result logging works with DDP""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize) + + +def test_result_metric_integration(): + metric_a = DummyMetric() + metric_b = DummyMetric() + metric_c = DummyMetric() + + result = Result() + + for epoch in range(3): + cumulative_sum = 0 + + for i in range(5): + metric_a(i) + metric_b(i) + metric_c(i) + + cumulative_sum += i + + result.log('a', metric_a, on_step=True, on_epoch=True) + result.log('b', metric_b, on_step=False, on_epoch=True) + result.log('c', metric_c, on_step=True, on_epoch=False) + + batch_log = result.get_batch_log_metrics() + batch_expected = {"a_step": i, "a": i, "c": i} + assert set(batch_log.keys()) == set(batch_expected.keys()) + for k in batch_expected.keys(): + assert batch_expected[k] == batch_log[k] + + epoch_log = result.get_epoch_log_metrics() + + # assert metric state reset to default values + assert metric_a.x == metric_a._defaults['x'] + assert metric_b.x == metric_b._defaults['x'] + assert metric_c.x == metric_c._defaults['x'] + + epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum} + + assert set(epoch_log.keys()) == set(epoch_expected.keys()) + for k in epoch_expected.keys(): + assert epoch_expected[k] == epoch_log[k] diff --git a/tests/core/test_results.py b/tests/core/test_results.py new file mode 100644 index 00000000000000..9586344d8c0d91 --- /dev/null +++ b/tests/core/test_results.py @@ -0,0 +1,286 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.utils.data import DataLoader + +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringDataModule, BoringModel +from tests.helpers.runif import RunIf + + +def _setup_ddp(rank, worldsize): + import os + + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def _ddp_test_fn(rank, worldsize, result_cls: Result): + _setup_ddp(rank, worldsize) + tensor = torch.tensor([1.0]) + + res = result_cls() + res.log("test_tensor", tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM) + + assert res["test_tensor"].item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors" + + +@RunIf(skip_windows=True) +def test_result_reduce_ddp(): + """Make sure result logging works with DDP""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(_ddp_test_fn, args=(worldsize, Result), nprocs=worldsize) + + +@pytest.mark.parametrize( + "test_option,do_train,gpus", [ + pytest.param(0, True, 0, id='full_loop'), + pytest.param(0, False, 0, id='test_only'), + pytest.param( + 1, False, 0, id='test_only_mismatching_tensor', marks=pytest.mark.xfail(raises=ValueError, match="Mism.*") + ), + pytest.param(2, False, 0, id='mix_of_tensor_dims'), + pytest.param(3, False, 0, id='string_list_predictions'), + pytest.param(4, False, 0, id='int_list_predictions'), + pytest.param(5, False, 0, id='nested_list_predictions'), + pytest.param(6, False, 0, id='dict_list_predictions'), + pytest.param(7, True, 0, id='write_dict_predictions'), + pytest.param(0, True, 1, id='full_loop_single_gpu', marks=RunIf(min_gpus=1)) + ] +) +def test_result_obj_predictions(tmpdir, test_option: int, do_train: bool, gpus: int): + + class CustomBoringModel(BoringModel): + + def test_step(self, batch, batch_idx, optimizer_idx=None): + output = self(batch) + test_loss = self.loss(batch, output) + self.log('test_loss', test_loss) + + batch_size = batch.size(0) + lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)] + lst_of_int = [random.randint(500, 1000) for i in range(batch_size)] + lst_of_lst = [[x] for x in lst_of_int] + lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)] + + # This is passed in from pytest via parameterization + option = getattr(self, 'test_option', 0) + prediction_file = getattr(self, 'prediction_file', 'predictions.pt') + + lazy_ids = torch.arange(batch_idx * batch_size, batch_idx * batch_size + batch_size) + + # Base + if option == 0: + self.write_prediction('idxs', lazy_ids, prediction_file) + self.write_prediction('preds', output, prediction_file) + + # Check mismatching tensor len + elif option == 1: + self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file) + self.write_prediction('preds', output, prediction_file) + + # write multi-dimension + elif option == 2: + self.write_prediction('idxs', lazy_ids, prediction_file) + self.write_prediction('preds', output, prediction_file) + self.write_prediction('x', batch, prediction_file) + + # write str list + elif option == 3: + self.write_prediction('idxs', lazy_ids, prediction_file) + self.write_prediction('vals', lst_of_str, prediction_file) + + # write int list + elif option == 4: + self.write_prediction('idxs', lazy_ids, prediction_file) + self.write_prediction('vals', lst_of_int, prediction_file) + + # write nested list + elif option == 5: + self.write_prediction('idxs', lazy_ids, prediction_file) + self.write_prediction('vals', lst_of_lst, prediction_file) + + # write dict list + elif option == 6: + self.write_prediction('idxs', lazy_ids, prediction_file) + self.write_prediction('vals', lst_of_dict, prediction_file) + + elif option == 7: + self.write_prediction_dict({'idxs': lazy_ids, 'preds': output}, prediction_file) + + class CustomBoringDataModule(BoringDataModule): + + def train_dataloader(self): + return DataLoader(self.random_train, batch_size=4) + + def val_dataloader(self): + return DataLoader(self.random_val, batch_size=4) + + def test_dataloader(self): + return DataLoader(self.random_test, batch_size=4) + + tutils.reset_seed() + prediction_file = Path(tmpdir) / 'predictions.pt' + + dm = BoringDataModule() + model = CustomBoringModel() + model.test_step_end = None + model.test_epoch_end = None + model.test_end = None + + model.test_option = test_option + model.prediction_file = prediction_file.as_posix() + + if prediction_file.exists(): + prediction_file.unlink() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + deterministic=True, + gpus=gpus, + ) + + # Prediction file shouldn't exist yet because we haven't done anything + assert not prediction_file.exists() + + if do_train: + result = trainer.fit(model, dm) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert result + result = trainer.test(datamodule=dm) + # TODO: add end-to-end test + # assert result[0]['test_loss'] < 0.6 + else: + result = trainer.test(model, datamodule=dm) + + # check prediction file now exists and is of expected length + assert prediction_file.exists() + predictions = torch.load(prediction_file) + assert len(predictions) == len(dm.random_test) + + +def test_result_gather_stack(): + """ Test that tensors get concatenated when they all have the same shape. """ + outputs = [ + { + "foo": torch.zeros(4, 5) + }, + { + "foo": torch.zeros(4, 5) + }, + { + "foo": torch.zeros(4, 5) + }, + ] + result = Result.gather(outputs) + assert isinstance(result["foo"], torch.Tensor) + assert list(result["foo"].shape) == [12, 5] + + +def test_result_gather_concatenate(): + """ Test that tensors get concatenated when they have varying size in first dimension. """ + outputs = [ + { + "foo": torch.zeros(4, 5) + }, + { + "foo": torch.zeros(8, 5) + }, + { + "foo": torch.zeros(3, 5) + }, + ] + result = Result.gather(outputs) + assert isinstance(result["foo"], torch.Tensor) + assert list(result["foo"].shape) == [15, 5] + + +def test_result_gather_scalar(): + """ Test that 0-dim tensors get gathered and stacked correctly. """ + outputs = [ + { + "foo": torch.tensor(1) + }, + { + "foo": torch.tensor(2) + }, + { + "foo": torch.tensor(3) + }, + ] + result = Result.gather(outputs) + assert isinstance(result["foo"], torch.Tensor) + assert list(result["foo"].shape) == [3] + + +def test_result_gather_different_shapes(): + """ Test that tensors of varying shape get gathered into a list. """ + outputs = [ + { + "foo": torch.tensor(1) + }, + { + "foo": torch.zeros(2, 3) + }, + { + "foo": torch.zeros(1, 2, 3) + }, + ] + result = Result.gather(outputs) + expected = [torch.tensor(1), torch.zeros(2, 3), torch.zeros(1, 2, 3)] + assert isinstance(result["foo"], list) + assert all(torch.eq(r, e).all() for r, e in zip(result["foo"], expected)) + + +def test_result_gather_mixed_types(): + """ Test that a collection of mixed types gets gathered into a list. """ + outputs = [ + { + "foo": 1.2 + }, + { + "foo": ["bar", None] + }, + { + "foo": torch.tensor(1) + }, + ] + result = Result.gather(outputs) + expected = [1.2, ["bar", None], torch.tensor(1)] + assert isinstance(result["foo"], list) + assert result["foo"] == expected + + +def test_result_retrieve_last_logged_item(): + result = Result() + result.log('a', 5., on_step=True, on_epoch=True) + assert result['a_epoch'] == 5. + assert result['a_step'] == 5. + assert result['a'] == 5. diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py new file mode 100644 index 00000000000000..ccfae3ec8dcf2c --- /dev/null +++ b/tests/deprecated_api/__init__.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in vX.Y.Z""" +import sys +from contextlib import contextmanager +from typing import Optional + +import pytest + + +def _soft_unimport_module(str_module): + # once the module is imported e.g with parsing with pytest it lives in memory + if str_module in sys.modules: + del sys.modules[str_module] + + +@contextmanager +def no_deprecated_call(match: Optional[str] = None): + with pytest.warns(None) as record: + yield + try: + w = record.pop(DeprecationWarning) + if match is not None and match not in str(w.message): + return + except AssertionError: + # no DeprecationWarning raised + return + raise AssertionError(f"`DeprecationWarning` was raised: {w}") diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py new file mode 100644 index 00000000000000..99e1b31f6edad9 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-4.py @@ -0,0 +1,228 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v1.4.0""" + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.overrides.data_parallel import ( + LightningDataParallel, + LightningDistributedDataParallel, + LightningParallelModule, +) +from pytorch_lightning.overrides.distributed import LightningDistributedModule +from pytorch_lightning.plugins import DDPSpawnPlugin +from pytorch_lightning.plugins.environments import LightningEnvironment +from tests.deprecated_api import _soft_unimport_module +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +def test_v1_4_0_deprecated_trainer_attributes(): + with pytest.deprecated_call(match="will be removed in v1.4."): + trainer = Trainer() + _ = trainer.accelerator_backend + assert trainer.accelerator == trainer.accelerator_backend + + +def test_v1_4_0_deprecated_trainer_methods(): + with pytest.deprecated_call(match='will be removed in v1.4'): + trainer = Trainer() + _ = trainer.get_model() + assert trainer.get_model() == trainer.lightning_module + + +def test_v1_4_0_deprecated_imports(): + _soft_unimport_module('pytorch_lightning.utilities.argparse_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.argparse_utils import from_argparse_args # noqa: F811 F401 + + _soft_unimport_module('pytorch_lightning.utilities.model_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.model_utils import is_overridden # noqa: F811 F401 + + _soft_unimport_module('pytorch_lightning.utilities.warning_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.warning_utils import WarningCache # noqa: F811 F401 + + _soft_unimport_module('pytorch_lightning.utilities.xla_device_utils') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401 + + +def test_v1_4_0_deprecated_trainer_device_distrib(): + """Test that Trainer attributes works fine.""" + trainer = Trainer() + trainer.accelerator_connector._distrib_type = None + trainer.accelerator_connector._device_type = None + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.on_cpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_cpu + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.on_gpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_gpu + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.on_tpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_tpu + trainer.accelerator_connector._device_type = None + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_tpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_tpu + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_dp = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_dp + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_ddp = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_ddp + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_ddp2 = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_ddp2 + + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_horovod = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_horovod + + +def test_v1_4_0_deprecated_metrics(): + from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes + with pytest.deprecated_call(match='will be removed in v1.4'): + stat_scores_multiple_classes(pred=torch.tensor([0, 1]), target=torch.tensor([0, 1])) + + from pytorch_lightning.metrics.functional.classification import iou + with pytest.deprecated_call(match='will be removed in v1.4'): + iou(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import recall + with pytest.deprecated_call(match='will be removed in v1.4'): + recall(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import precision + with pytest.deprecated_call(match='will be removed in v1.4'): + precision(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import precision_recall + with pytest.deprecated_call(match='will be removed in v1.4'): + precision_recall(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + from pytorch_lightning.metrics.functional.classification import auc + with pytest.deprecated_call(match='will be removed in v1.4'): + auc(torch.rand(10, ).sort().values, torch.rand(10, )) + + from pytorch_lightning.metrics.functional.classification import auroc + with pytest.deprecated_call(match='will be removed in v1.4'): + auroc(torch.rand(10, ), torch.randint(0, 2, (10, ))) + + from pytorch_lightning.metrics.functional.classification import multiclass_auroc + with pytest.deprecated_call(match='will be removed in v1.4'): + multiclass_auroc(torch.rand(20, 5).softmax(dim=-1), torch.randint(0, 5, (20, )), num_classes=5) + + +class CustomDDPPlugin(DDPSpawnPlugin): + + def configure_ddp(self): + # old, deprecated implementation + with pytest.deprecated_call( + match='`LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4.' + ): + self._model = LightningDistributedDataParallel( + module=self.lightning_module, + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + assert isinstance(self.model, torch.nn.parallel.DistributedDataParallel) + assert isinstance(self.model.module, LightningDistributedModule) + + +@RunIf(min_gpus=2, skip_windows=True) +def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + gpus=2, + accelerator="ddp_spawn", + plugins=[ + CustomDDPPlugin( + parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)], + cluster_environment=LightningEnvironment(), + ) + ] + ) + trainer.fit(model) + + +@RunIf(min_gpus=1) +def test_v1_4_0_deprecated_lightning_data_parallel(): + model = BoringModel() + with pytest.deprecated_call(match="`LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."): + dp_model = LightningDataParallel(model, device_ids=[0]) + assert isinstance(dp_model, torch.nn.DataParallel) + assert isinstance(dp_model.module, LightningParallelModule) + + +def test_v1_4_0_deprecated_manual_optimization_optimizer(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, *_, **kwargs): + opt = self.optimizers() + output = self.layer(batch) + loss = self.loss(batch, output) + self.manual_backward(loss, opt) + + @property + def automatic_optimization(self): + return False + + model = TestModel() + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + with pytest.deprecated_call( + match="`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" + ): + trainer.fit(model) + + +def test_v1_4_0_deprecated_checkpoint_on(tmpdir): + from pytorch_lightning.callbacks.model_checkpoint import warning_cache + warning_cache.clear() + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log("val_loss", -batch_idx) + return super().training_step(batch, batch_idx) + + trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1) + + with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"): + trainer.fit(TestModel()) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py new file mode 100644 index 00000000000000..fc3fe3112e71e4 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-5.py @@ -0,0 +1,220 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v1.5.0""" +from unittest import mock + +import pytest +from torch import optim + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache +from tests.deprecated_api import no_deprecated_call +from tests.helpers import BoringModel +from tests.helpers.utils import no_warning_call + + +def test_v1_5_0_model_checkpoint_save_checkpoint(): + model_ckpt = ModelCheckpoint() + model_ckpt.save_function = lambda *_, **__: None + with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"): + model_ckpt.save_checkpoint(Trainer(), object()) + + +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_v1_5_0_wandb_unused_sync_step(tmpdir): + with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): + WandbLogger(sync_step=True) + + +def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir): + + class OldSignature(Callback): + + def on_save_checkpoint(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer_kwargs = { + "default_root_dir": tmpdir, + "checkpoint_callback": False, + "max_epochs": 1, + } + filepath = tmpdir / "test.ckpt" + + trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()]) + trainer.fit(model) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.save_checkpoint(filepath) + + class NewSignature(Callback): + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + ... + + class ValidSignature1(Callback): + + def on_save_checkpoint(self, trainer, *args): + ... + + class ValidSignature2(Callback): + + def on_save_checkpoint(self, *args): + ... + + trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()] + with no_warning_call(DeprecationWarning): + trainer.save_checkpoint(filepath) + + +def test_v1_5_0_legacy_profiler_argument(): + with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): + PyTorchProfiler(profiled_functions=[]) + + +def test_v1_5_0_running_sanity_check(): + trainer = Trainer() + with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): + assert not trainer.running_sanity_check + + +def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir): + + class OldSignatureModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx, optimizer_idx): + assert optimizer_idx is not None + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)] + + model = OldSignatureModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) + + with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): + trainer.fit(model) + + +def test_v1_5_0_model_checkpoint_period(tmpdir): + with no_warning_call(DeprecationWarning): + ModelCheckpoint(dirpath=tmpdir) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + ModelCheckpoint(dirpath=tmpdir, period=1) + + +def test_v1_5_0_old_on_validation_epoch_end(tmpdir): + callback_warning_cache.clear() + + class OldSignature(Callback): + + def on_validation_epoch_end(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + class OldSignatureModel(BoringModel): + + def on_validation_epoch_end(self): # noqa + ... + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + callback_warning_cache.clear() + + class NewSignature(Callback): + + def on_validation_epoch_end(self, trainer, pl_module, outputs): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + class NewSignatureModel(BoringModel): + + def on_validation_epoch_end(self, outputs): + ... + + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + +def test_v1_5_0_old_on_test_epoch_end(tmpdir): + callback_warning_cache.clear() + + class OldSignature(Callback): + + def on_test_epoch_end(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.test(model) + + class OldSignatureModel(BoringModel): + + def on_test_epoch_end(self): # noqa + ... + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.test(model) + + callback_warning_cache.clear() + + class NewSignature(Callback): + + def on_test_epoch_end(self, trainer, pl_module, outputs): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."): + trainer.test(model) + + class NewSignatureModel(BoringModel): + + def on_test_epoch_end(self, outputs): + ... + + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): + trainer.test(model) + + +@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_v1_5_0_profiler_output_filename(tmpdir, cls): + filepath = str(tmpdir / "test.txt") + with pytest.deprecated_call(match="`output_filename` parameter has been removed"): + profiler = cls(output_filename=filepath) + assert profiler.dirpath == tmpdir + assert profiler.filename == "test" diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 00000000000000..e6fa5cfa707953 --- /dev/null +++ b/tests/helpers/__init__.py @@ -0,0 +1,2 @@ +from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset # noqa: F401 +from tests.helpers.datasets import TrialMNIST # noqa: F401 diff --git a/tests/helpers/advanced_models.py b/tests/helpers/advanced_models.py new file mode 100644 index 00000000000000..2b0146e1ee0998 --- /dev/null +++ b/tests/helpers/advanced_models.py @@ -0,0 +1,230 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS +from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST + + +class Generator(nn.Module): + + def __init__(self, latent_dim: int, img_shape: tuple): + super().__init__() + self.img_shape = img_shape + + def block(in_feat, out_feat, normalize=True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *block(latent_dim, 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(np.prod(img_shape))), + nn.Tanh(), + ) + + def forward(self, z): + img = self.model(z) + img = img.view(img.size(0), *self.img_shape) + return img + + +class Discriminator(nn.Module): + + def __init__(self, img_shape: tuple): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(int(np.prod(img_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + nn.Sigmoid(), + ) + + def forward(self, img): + img_flat = img.view(img.size(0), -1) + validity = self.model(img_flat) + + return validity + + +class BasicGAN(LightningModule): + """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" + + def __init__( + self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs + ): + super().__init__() + self.hidden_dim = hidden_dim + self.learning_rate = learning_rate + self.b1 = b1 + self.b2 = b2 + + # networks + mnist_shape = (1, 28, 28) + self.generator = Generator(latent_dim=self.hidden_dim, img_shape=mnist_shape) + self.discriminator = Discriminator(img_shape=mnist_shape) + + # cache for generated images + self.generated_imgs = None + self.last_imgs = None + + self.example_input_array = torch.rand(2, self.hidden_dim) + + def forward(self, z): + return self.generator(z) + + def adversarial_loss(self, y_hat, y): + return F.binary_cross_entropy(y_hat, y) + + def training_step(self, batch, batch_idx, optimizer_idx=None): + imgs, _ = batch + self.last_imgs = imgs + + # train generator + if optimizer_idx == 0: + # sample noise + z = torch.randn(imgs.shape[0], self.hidden_dim) + z = z.type_as(imgs) + + # generate images + self.generated_imgs = self(z) + + # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + # adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) + tqdm_dict = {'g_loss': g_loss} + output = OrderedDict({ + 'loss': g_loss, + 'progress_bar': tqdm_dict, + 'log': tqdm_dict, + }) + return output + + # train discriminator + if optimizer_idx == 1: + # Measure discriminator's ability to classify real from generated samples + + # how well can it label as real? + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + real_loss = self.adversarial_loss(self.discriminator(imgs), valid) + + # how well can it label as fake? + fake = torch.zeros(imgs.size(0), 1) + fake = fake.type_as(fake) + + fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake) + + # discriminator loss is the average of these + d_loss = (real_loss + fake_loss) / 2 + tqdm_dict = {'d_loss': d_loss} + output = OrderedDict({ + 'loss': d_loss, + 'progress_bar': tqdm_dict, + 'log': tqdm_dict, + }) + return output + + def configure_optimizers(self): + lr = self.learning_rate + b1 = self.b1 + b2 = self.b2 + + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + return [opt_g, opt_d], [] + + def train_dataloader(self): + return DataLoader(TrialMNIST(root=PATH_DATASETS, train=True, download=True), batch_size=16) + + +class ParityModuleRNN(LightningModule): + + def __init__(self): + super().__init__() + self.rnn = nn.LSTM(10, 20, batch_first=True) + self.linear_out = nn.Linear(in_features=20, out_features=5) + self.example_input_array = torch.rand(2, 3, 10) + + def forward(self, x): + seq, last = self.rnn(x) + return self.linear_out(seq) + + def training_step(self, batch, batch_nb): + x, y = batch + y_hat = self(x) + loss = F.mse_loss(y_hat, y) + return {'loss': loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=30) + + +class ParityModuleMNIST(LightningModule): + + def __init__(self): + super().__init__() + self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128) + self.c_d1_bn = nn.BatchNorm1d(128) + self.c_d1_drop = nn.Dropout(0.3) + self.c_d2 = nn.Linear(in_features=128, out_features=10) + self.example_input_array = torch.rand(2, 1, 28, 28) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + x = self.c_d2(x) + return x + + def training_step(self, batch, batch_nb): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return {'loss': loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + def train_dataloader(self): + return DataLoader(MNIST( + root=PATH_DATASETS, + train=True, + download=True, + ), batch_size=128, num_workers=1) diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py new file mode 100644 index 00000000000000..6ef2518bbef11b --- /dev/null +++ b/tests/helpers/boring_model.py @@ -0,0 +1,170 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torch.utils.data import DataLoader, Dataset, Subset + +from pytorch_lightning import LightningDataModule, LightningModule + + +class RandomDictDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + a = self.data[index] + b = a + 2 + return {'a': a, 'b': b} + + def __len__(self): + return self.len + + +class RandomDictStringDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return {"id": str(index), "x": self.data[index]} + + def __len__(self): + return self.len + + +class RandomDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + + def __init__(self): + """ + Testing PL Module + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x['x'] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class BoringDataModule(LightningDataModule): + + def __init__(self, data_dir: str = './'): + super().__init__() + self.data_dir = data_dir + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + def prepare_data(self): + self.random_full = RandomDataset(32, 192) + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + self.random_train = Subset(self.random_full, indices=range(64)) + self.dims = self.random_train[0].shape + + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 128)) + + if stage == "test" or stage is None: + self.random_test = Subset(self.random_full, indices=range(128, 192)) + self.dims = getattr(self, "dims", self.random_test[0].shape) + + def train_dataloader(self): + return DataLoader(self.random_train) + + def val_dataloader(self): + return DataLoader(self.random_val) + + def test_dataloader(self): + return DataLoader(self.random_test) diff --git a/tests/helpers/dataloaders.py b/tests/helpers/dataloaders.py new file mode 100644 index 00000000000000..fa6009afece351 --- /dev/null +++ b/tests/helpers/dataloaders.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom dataloaders for testing""" + + +class CustomInfDataloader: + + def __init__(self, dataloader): + self.dataloader = dataloader + self.iter = iter(dataloader) + self.count = 0 + self.dataloader.num_workers = 0 # reduce chance for hanging pytest + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + if self.count >= 50: + raise StopIteration + self.count = self.count + 1 + try: + return next(self.iter) + except StopIteration: + self.iter = iter(self.dataloader) + return next(self.iter) + + +class CustomNotImplementedErrorDataloader(CustomInfDataloader): + + def __len__(self): + """raise NotImplementedError""" + raise NotImplementedError + + def __next__(self): + if self.count >= 2: + raise StopIteration + self.count = self.count + 1 + try: + return next(self.iter) + except StopIteration: + self.iter = iter(self.dataloader) + return next(self.iter) diff --git a/tests/helpers/datamodules.py b/tests/helpers/datamodules.py new file mode 100644 index 00000000000000..12ec16261159db --- /dev/null +++ b/tests/helpers/datamodules.py @@ -0,0 +1,114 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torch.utils.data import DataLoader + +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.utilities import _module_available +from tests.helpers.datasets import MNIST, SklearnDataset, TrialMNIST + +_SKLEARN_AVAILABLE = _module_available("sklearn") +if _SKLEARN_AVAILABLE: + from sklearn.datasets import make_classification, make_regression + from sklearn.model_selection import train_test_split + + +class MNISTDataModule(LightningDataModule): + + def __init__(self, data_dir: str = "./", batch_size: int = 32, use_trials: bool = False) -> None: + super().__init__() + + self.data_dir = data_dir + self.batch_size = batch_size + + # TrialMNIST is a constrained MNIST dataset + self.dataset_cls = TrialMNIST if use_trials else MNIST + + # self.dims is returned when you call dm.size() + # Setting default dims here because we know them. + # Could optionally be assigned dynamically in dm.setup() + self.dims = (1, 28, 28) + + def prepare_data(self): + # download only + self.dataset_cls(self.data_dir, train=True, download=True) + self.dataset_cls(self.data_dir, train=False, download=True) + + def setup(self, stage: Optional[str] = None): + # TODO: need to split using random_split once updated to torch >= 1.6 + if stage == "fit" or stage is None: + self.mnist_train = self.dataset_cls(self.data_dir, train=True) + if stage == "test" or stage is None: + self.mnist_test = self.dataset_cls(self.data_dir, train=False) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=False) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False) + + +class SklearnDataModule(LightningDataModule): + + def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 10): + super().__init__() + self.batch_size = batch_size + self._x, self._y = sklearn_dataset + self._split_data() + self._x_type = x_type + self._y_type = y_type + + def _split_data(self): + self.x_train, self.x_test, self.y_train, self.y_test = \ + train_test_split(self._x, self._y, test_size=0.20, random_state=42) + self.x_train, self.x_valid, self.y_train, self.y_valid = \ + train_test_split(self.x_train, self.y_train, test_size=0.40, random_state=42) + + def train_dataloader(self): + return DataLoader( + SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type), batch_size=self.batch_size + ) + + def val_dataloader(self): + return DataLoader( + SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size + ) + + def test_dataloader(self): + return DataLoader( + SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size + ) + + @property + def sample(self): + return torch.tensor([self._x[0]], dtype=self._x_type) + + +class ClassifDataModule(SklearnDataModule): + + def __init__(self, num_features=32, length=800, num_classes=3, batch_size=10): + data = make_classification( + n_samples=length, n_features=num_features, n_classes=num_classes, n_clusters_per_class=1, random_state=42 + ) + super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size) + + +class RegressDataModule(SklearnDataModule): + + def __init__(self, num_features=16, length=800, batch_size=10): + x, y = make_regression(n_samples=length, n_features=num_features, random_state=42) + y = [[v] for v in y] + super().__init__((x, y), x_type=torch.float32, y_type=torch.float32, batch_size=batch_size) diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py new file mode 100644 index 00000000000000..77035796ca3b18 --- /dev/null +++ b/tests/helpers/datasets.py @@ -0,0 +1,223 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import random +import time +import urllib.request +from typing import Optional, Sequence, Tuple + +import torch +from torch import Tensor +from torch.utils.data import Dataset + + +class MNIST(Dataset): + """ + Customized `MNIST `_ dataset for testing Pytorch Lightning + without the torchvision dependency. + + Part of the code was copied from + https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py + + Args: + root: Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + normalize: mean and std deviation of the MNIST dataset. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + Examples: + >>> dataset = MNIST(".", download=True) + >>> len(dataset) + 60000 + >>> torch.bincount(dataset.targets) + tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]) + """ + + RESOURCES = ( + "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt", + "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt", + ) + + TRAIN_FILE_NAME = 'training.pt' + TEST_FILE_NAME = 'test.pt' + cache_folder_name = 'complete' + + def __init__( + self, + root: str, + train: bool = True, + normalize: tuple = (0.1307, 0.3081), + download: bool = True, + **kwargs, + ): + super().__init__() + self.root = root + self.train = train # training set or test set + self.normalize = normalize + + self.prepare_data(download) + + data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME + self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) + + def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + img = self.data[idx].float().unsqueeze(0) + target = int(self.targets[idx]) + + if self.normalize is not None and len(self.normalize) == 2: + img = self.normalize_tensor(img, *self.normalize) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + @property + def cached_folder_path(self) -> str: + return os.path.join(self.root, 'MNIST', self.cache_folder_name) + + def _check_exists(self, data_folder: str) -> bool: + existing = True + for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): + existing = existing and os.path.isfile(os.path.join(data_folder, fname)) + return existing + + def prepare_data(self, download: bool = True): + if download and not self._check_exists(self.cached_folder_path): + self._download(self.cached_folder_path) + if not self._check_exists(self.cached_folder_path): + raise RuntimeError('Dataset not found.') + + def _download(self, data_folder: str) -> None: + os.makedirs(data_folder) + for url in self.RESOURCES: + logging.info(f'Downloading {url}') + fpath = os.path.join(data_folder, os.path.basename(url)) + urllib.request.urlretrieve(url, fpath) + + @staticmethod + def _try_load(path_data, trials: int = 30, delta: float = 1.): + """Resolving loading from the same time from multiple concurrent processes.""" + res, exception = None, None + assert trials, "at least some trial has to be set" + assert os.path.isfile(path_data), f'missing file: {path_data}' + for _ in range(trials): + try: + res = torch.load(path_data) + # todo: specify the possible exception + except Exception as e: + exception = e + time.sleep(delta * random.random()) + else: + break + if exception is not None: + # raise the caught exception + raise exception + return res + + @staticmethod + def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor: + mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) + return tensor.sub(mean).div(std) + + +class TrialMNIST(MNIST): + """Constrained MNIST dataset + + Args: + num_samples: number of examples per selected class/digit + digits: list selected MNIST digits/classes + kwargs: Same as MNIST + + Examples: + >>> dataset = TrialMNIST(".", download=True) + >>> len(dataset) + 300 + >>> sorted(set([d.item() for d in dataset.targets])) + [0, 1, 2] + >>> torch.bincount(dataset.targets) + tensor([100, 100, 100]) + """ + + def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): + # number of examples per class + self.num_samples = num_samples + # take just a subset of MNIST dataset + self.digits = sorted(digits) if digits else list(range(10)) + + self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}" + + super().__init__(root, normalize=(0.5, 1.0), **kwargs) + + @staticmethod + def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence): + classes = {d: 0 for d in digits} + indexes = [] + for idx, target in enumerate(full_targets): + label = target.item() + if classes.get(label, float('inf')) >= num_samples: + continue + indexes.append(idx) + classes[label] += 1 + if all(classes[k] >= num_samples for k in classes): + break + data = full_data[indexes] + targets = full_targets[indexes] + return data, targets + + def _download(self, data_folder: str) -> None: + super()._download(data_folder) + for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME): + path_fname = os.path.join(self.cached_folder_path, fname) + assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}' + data, targets = self._try_load(path_fname) + data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits) + torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) + + +class AverageDataset(Dataset): + + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] + + +class SklearnDataset(Dataset): + + def __init__(self, x, y, x_type, y_type): + self.x = x + self.y = y + self._x_type = x_type + self._y_type = y_type + + def __getitem__(self, idx): + return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type) + + def __len__(self): + return len(self.y) diff --git a/tests/helpers/deterministic_model.py b/tests/helpers/deterministic_model.py new file mode 100644 index 00000000000000..20830a1bc6fc23 --- /dev/null +++ b/tests/helpers/deterministic_model.py @@ -0,0 +1,233 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import DistributedType + + +class DeterministicModel(LightningModule): + + def __init__(self, weights=None): + super().__init__() + + self.training_step_called = False + self.training_step_end_called = False + self.training_epoch_end_called = False + + self.validation_step_called = False + self.validation_step_end_called = False + self.validation_epoch_end_called = False + + self.assert_backward = True + + self.l1 = nn.Linear(2, 3, bias=False) + if weights is None: + weights = torch.tensor([[4, 3, 5], [10, 11, 13]]).float() + p = torch.nn.Parameter(weights, requires_grad=True) + self.l1.weight = p + + def forward(self, x): + return self.l1(x) + + def step(self, batch, batch_idx): + x = batch + bs = x.size(0) + y_hat = self.l1(x) + + test_hat = y_hat.cpu().detach() + assert torch.all(test_hat[:, 0] == 15.0) + assert torch.all(test_hat[:, 1] == 42.0) + out = y_hat.sum() + assert out == (42.0 * bs) + (15.0 * bs) + + return out + + def count_num_graphs(self, result, num_graphs=0): + for k, v in result.items(): + if isinstance(v, torch.Tensor) and v.grad_fn is not None: + num_graphs += 1 + if isinstance(v, dict): + num_graphs += self.count_num_graphs(v) + + return num_graphs + + # --------------------------- + # scalar return + # --------------------------- + def training_step__scalar_return(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + self.training_step_called = True + return acc + + def training_step_end__scalar(self, output): + self.training_step_end_called = True + + # make sure loss has the grad + assert isinstance(output, torch.Tensor) + assert output.grad_fn is not None + + # make sure nothing else has grads + assert self.count_num_graphs({'loss': output}) == 1 + + assert output == 171 + + return output + + def training_epoch_end__scalar(self, outputs): + """ + There should be an array of scalars without graphs that are all 171 (4 of them) + """ + self.training_epoch_end_called = True + + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): + pass + else: + # only saw 4 batches + assert len(outputs) == 4 + for batch_out in outputs: + batch_out = batch_out['loss'] + assert batch_out == 171 + assert batch_out.grad_fn is None + assert isinstance(batch_out, torch.Tensor) + + # -------------------------- + # dictionary returns + # -------------------------- + def training_step__dict_return(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + + self.training_step_called = True + return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_test': torch.tensor(549).type_as(acc)} + + def training_step_end__dict(self, output): + self.training_step_end_called = True + + # make sure loss has the grad + assert 'loss' in output + assert output['loss'].grad_fn is not None + + # make sure nothing else has grads + assert self.count_num_graphs(output) == 1 + + # make sure the other keys are there + assert 'log_acc1' in output + assert 'log_acc2' in output + assert 'pbar_acc1' in output + assert 'pbar_acc2' in output + + logs = {'log_acc1': output['log_acc1'] + 2, 'log_acc2': output['log_acc2'] + 2} + pbar = {'pbar_acc1': output['pbar_acc1'] + 2, 'pbar_acc2': output['pbar_acc2'] + 2} + + acc = output['loss'] + return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_end': acc} + + def validation_step__no_return(self, batch, batch_idx): + self.validation_step_called = True + self.step(batch, batch_idx) + + def validation_step__scalar_return(self, batch, batch_idx): + self.validation_step_called = True + acc = self.step(batch, batch_idx) + return acc + + def validation_step__dummy_dict_return(self, batch, batch_idx): + self.validation_step_called = True + acc = self.step(batch, batch_idx) + return {'some': acc, 'value': 'a'} + + def validation_step__dict_return(self, batch, batch_idx): + self.validation_step_called = True + acc = self.step(batch, batch_idx) + + logs = {'log_acc1': torch.tensor(12 + batch_idx).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)} + pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)} + return {'val_loss': acc, 'log': logs, 'progress_bar': pbar} + + def validation_step_end__no_return(self, val_step_output): + assert len(val_step_output) == 3 + assert val_step_output['val_loss'] == 171 + assert val_step_output['log']['log_acc1'] >= 12 + assert val_step_output['progress_bar']['pbar_acc1'] == 17 + self.validation_step_end_called = True + + def validation_step_end(self, val_step_output): + assert len(val_step_output) == 3 + assert val_step_output['val_loss'] == 171 + assert val_step_output['log']['log_acc1'] >= 12 + assert val_step_output['progress_bar']['pbar_acc1'] == 17 + self.validation_step_end_called = True + + val_step_output['val_step_end'] = torch.tensor(1802) + + return val_step_output + + def validation_epoch_end(self, outputs): + assert len(outputs) == self.trainer.num_val_batches[0] + + for i, out in enumerate(outputs): + assert out['log']['log_acc1'] >= 12 + i + + self.validation_epoch_end_called = True + + result = outputs[-1] + result['val_epoch_end'] = torch.tensor(1233) + return result + + # ----------------------------- + # DATA + # ----------------------------- + def train_dataloader(self): + return DataLoader(DummyDataset(), batch_size=3, shuffle=False) + + def val_dataloader(self): + return DataLoader(DummyDataset(), batch_size=3, shuffle=False) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0) + + def configure_optimizers__lr_on_plateau_epoch(self): + optimizer = torch.optim.Adam(self.parameters(), lr=0) + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + scheduler = {'scheduler': lr_scheduler, 'interval': 'epoch', 'monitor': 'epoch_end_log_1'} + return [optimizer], [scheduler] + + def configure_optimizers__lr_on_plateau_step(self): + optimizer = torch.optim.Adam(self.parameters(), lr=0) + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + scheduler = {'scheduler': lr_scheduler, 'interval': 'step', 'monitor': 'pbar_acc1'} + return [optimizer], [scheduler] + + def backward(self, loss, optimizer, optimizer_idx): + if self.assert_backward: + if self.trainer.precision == 16: + assert loss > 171 * 1000 + else: + assert loss == 171.0 + + super().backward(loss, optimizer, optimizer_idx) + + +class DummyDataset(Dataset): + + def __len__(self): + return 12 + + def __getitem__(self, idx): + return torch.tensor([0.5, 1.0, 2.0]) diff --git a/tests/helpers/imports.py b/tests/helpers/imports.py new file mode 100644 index 00000000000000..4db9c00d45eabb --- /dev/null +++ b/tests/helpers/imports.py @@ -0,0 +1,8 @@ +import operator + +from pytorch_lightning.utilities.imports import _compare_version + +if _compare_version("torchtext", operator.ge, "0.9.0"): + from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401 +else: + from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401 diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py new file mode 100644 index 00000000000000..403bcdfee8c1d7 --- /dev/null +++ b/tests/helpers/pipelines.py @@ -0,0 +1,111 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.metrics.functional import accuracy +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import DistributedType +from tests.helpers import BoringModel +from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed + + +def run_model_test_without_loggers( + trainer_options: dict, model: LightningModule, data: LightningDataModule = None, min_acc: float = 0.50 +): + reset_seed() + + # fit model + trainer = Trainer(**trainer_options) + trainer.fit(model, datamodule=data) + + # correct result and ok accuracy + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + model2 = load_model_from_checkpoint(trainer.logger, trainer.checkpoint_callback.best_model_path, type(model)) + + # test new model accuracy + test_loaders = model2.test_dataloader() if not data else data.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + if not isinstance(model2, BoringModel): + for dataloader in test_loaders: + run_prediction_eval_model_template(model2, dataloader, min_acc=min_acc) + + +def run_model_test( + trainer_options, + model: LightningModule, + data: LightningDataModule = None, + on_gpu: bool = True, + version=None, + with_hpc: bool = True, + min_acc: float = 0.25 +): + reset_seed() + save_dir = trainer_options['default_root_dir'] + + # logger file to get meta + logger = get_default_logger(save_dir, version=version) + trainer_options.update(logger=logger) + trainer = Trainer(**trainer_options) + initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()]) + trainer.fit(model, datamodule=data) + post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()]) + + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + # Check that the model is actually changed post-training + change_ratio = torch.norm(initial_values - post_train_values) + assert change_ratio > 0.1, f"the model is changed of {change_ratio}" + + # test model loading + pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model)) + + # test new model accuracy + test_loaders = model.test_dataloader() if not data else data.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + if not isinstance(model, BoringModel): + for dataloader in test_loaders: + run_prediction_eval_model_template(model, dataloader, min_acc=min_acc) + + if with_hpc: + if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): + # on hpc this would work fine... but need to hack it for the purpose of the test + trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ + trainer.init_optimizers(pretrained_model) + + # test HPC saving + trainer.checkpoint_connector.hpc_save(save_dir, logger) + # test HPC loading + checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) + trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + + +@torch.no_grad() +def run_prediction_eval_model_template(trained_model, dataloader, min_acc=0.50): + # run prediction on 1 batch + trained_model.cpu() + trained_model.eval() + + batch = next(iter(dataloader)) + x, y = batch + x = x.flatten(1) + + y_hat = trained_model(x) + acc = accuracy(y_hat.cpu(), y.cpu(), top_k=2).item() + + assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})" diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py new file mode 100644 index 00000000000000..5483e33d9cddb4 --- /dev/null +++ b/tests/helpers/runif.py @@ -0,0 +1,184 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from distutils.version import LooseVersion +from typing import Optional + +import pytest +import torch +from pkg_resources import get_distribution + +from pytorch_lightning.utilities import ( + _APEX_AVAILABLE, + _DEEPSPEED_AVAILABLE, + _FAIRSCALE_AVAILABLE, + _FAIRSCALE_PIPE_AVAILABLE, + _HOROVOD_AVAILABLE, + _NATIVE_AMP_AVAILABLE, + _RPC_AVAILABLE, + _TORCH_QUANTIZE_AVAILABLE, + _TPU_AVAILABLE, +) + +try: + from horovod.common.util import nccl_built + nccl_built() +except (ImportError, ModuleNotFoundError, AttributeError): + _HOROVOD_NCCL_AVAILABLE = False +finally: + _HOROVOD_NCCL_AVAILABLE = True + + +class RunIf: + """ + RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: + + @RunIf(min_torch="0.0") + @pytest.mark.parametrize("arg1", [1, 2.0]) + def test_wrapper(arg1): + assert arg1 > 0.0 + """ + + def __new__( + self, + *args, + min_gpus: int = 0, + min_torch: Optional[str] = None, + max_torch: Optional[str] = None, + min_python: Optional[str] = None, + quantization: bool = False, + amp_apex: bool = False, + amp_native: bool = False, + tpu: bool = False, + horovod: bool = False, + horovod_nccl: bool = False, + skip_windows: bool = False, + special: bool = False, + rpc: bool = False, + fairscale: bool = False, + fairscale_pipe: bool = False, + deepspeed: bool = False, + **kwargs + ): + """ + Args: + args: native pytest.mark.skipif arguments + min_gpus: min number of gpus required to run test + min_torch: minimum pytorch version to run test + max_torch: maximum pytorch version to run test + min_python: minimum python version required to run test + quantization: if `torch.quantization` package is required to run test + amp_apex: NVIDIA Apex is installed + amp_native: if native PyTorch native AMP is supported + tpu: if TPU is available + horovod: if Horovod is installed + horovod_nccl: if Horovod is installed with NCCL support + skip_windows: skip test for Windows platform (typically fo some limited torch functionality) + special: running in special mode, outside pytest suit + rpc: requires Remote Procedure Call (RPC) + fairscale: if `fairscale` module is required to run the test + deepspeed: if `deepspeed` module is required to run the test + kwargs: native pytest.mark.skipif keyword arguments + """ + conditions = [] + reasons = [] + + if min_gpus: + conditions.append(torch.cuda.device_count() < min_gpus) + reasons.append(f"GPUs>={min_gpus}") + + if min_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version < LooseVersion(min_torch)) + reasons.append(f"torch>={min_torch}") + + if max_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version >= LooseVersion(max_torch)) + reasons.append(f"torch<{max_torch}") + + if min_python: + py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + conditions.append(py_version < LooseVersion(min_python)) + reasons.append(f"python>={min_python}") + + if quantization: + _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines + conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default) + reasons.append("PyTorch quantization") + + if amp_native: + conditions.append(not _NATIVE_AMP_AVAILABLE) + reasons.append("native AMP") + + if amp_apex: + conditions.append(not _APEX_AVAILABLE) + reasons.append("NVIDIA Apex") + + if skip_windows: + conditions.append(sys.platform == "win32") + reasons.append("unimplemented on Windows") + + if tpu: + conditions.append(not _TPU_AVAILABLE) + reasons.append("TPU") + + if horovod: + conditions.append(not _HOROVOD_AVAILABLE) + reasons.append("Horovod") + + if horovod_nccl: + conditions.append(not _HOROVOD_NCCL_AVAILABLE) + reasons.append("Horovod with NCCL") + + if special: + env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') + conditions.append(env_flag != '1') + reasons.append("Special execution") + + if rpc: + conditions.append(not _RPC_AVAILABLE) + reasons.append("RPC") + + if fairscale: + conditions.append(not _FAIRSCALE_AVAILABLE) + reasons.append("Fairscale") + + if fairscale_pipe: + conditions.append(not _FAIRSCALE_PIPE_AVAILABLE) + reasons.append("Fairscale Pipe") + + if deepspeed: + conditions.append(not _DEEPSPEED_AVAILABLE) + reasons.append("Deepspeed") + + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] + return pytest.mark.skipif( + *args, + condition=any(conditions), + reason=f"Requires: [{' + '.join(reasons)}]", + **kwargs, + ) + + +@RunIf(min_torch="99") +def test_always_skip(): + exit(1) + + +@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) +@RunIf(min_torch="0.0") +def test_wrapper(arg1: float): + assert arg1 > 0.0 diff --git a/tests/helpers/simple_models.py b/tests/helpers/simple_models.py new file mode 100644 index 00000000000000..1abeb1f00206a7 --- /dev/null +++ b/tests/helpers/simple_models.py @@ -0,0 +1,120 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +from torch import nn + +from pytorch_lightning import LightningModule +from pytorch_lightning.metrics import Accuracy, MeanSquaredError + + +class ClassificationModel(LightningModule): + + def __init__(self, lr=0.01): + super().__init__() + + self.lr = lr + for i in range(3): + setattr(self, f"layer_{i}", nn.Linear(32, 32)) + setattr(self, f"layer_{i}a", torch.nn.ReLU()) + setattr(self, "layer_end", nn.Linear(32, 3)) + + self.train_acc = Accuracy() + self.valid_acc = Accuracy() + self.test_acc = Accuracy() + + def forward(self, x): + x = self.layer_0(x) + x = self.layer_0a(x) + x = self.layer_1(x) + x = self.layer_1a(x) + x = self.layer_2(x) + x = self.layer_2a(x) + x = self.layer_end(x) + logits = F.softmax(x, dim=1) + return logits + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return [optimizer], [] + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.cross_entropy(logits, y) + self.log('train_loss', loss, prog_bar=True) + self.log('train_acc', self.train_acc(logits, y), prog_bar=True) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + self.log('val_loss', F.cross_entropy(logits, y), prog_bar=False) + self.log('val_acc', self.valid_acc(logits, y), prog_bar=True) + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + self.log('test_loss', F.cross_entropy(logits, y), prog_bar=False) + self.log('test_acc', self.test_acc(logits, y), prog_bar=True) + + +class RegressionModel(LightningModule): + + def __init__(self): + super().__init__() + setattr(self, "layer_0", nn.Linear(16, 64)) + setattr(self, "layer_0a", torch.nn.ReLU()) + for i in range(1, 3): + setattr(self, f"layer_{i}", nn.Linear(64, 64)) + setattr(self, f"layer_{i}a", torch.nn.ReLU()) + setattr(self, "layer_end", nn.Linear(64, 1)) + + self.train_mse = MeanSquaredError() + self.valid_mse = MeanSquaredError() + self.test_mse = MeanSquaredError() + + def forward(self, x): + x = self.layer_0(x) + x = self.layer_0a(x) + x = self.layer_1(x) + x = self.layer_1a(x) + x = self.layer_2(x) + x = self.layer_2a(x) + x = self.layer_end(x) + return x + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=0.01) + return [optimizer], [] + + def training_step(self, batch, batch_idx): + x, y = batch + out = self.forward(x) + loss = F.mse_loss(out, y) + self.log('train_loss', loss, prog_bar=False) + self.log('train_MSE', self.train_mse(out, y), prog_bar=True) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + x, y = batch + out = self.forward(x) + self.log('val_loss', F.mse_loss(out, y), prog_bar=False) + self.log('val_MSE', self.valid_mse(out, y), prog_bar=True) + + def test_step(self, batch, batch_idx): + x, y = batch + out = self.forward(x) + self.log('test_loss', F.mse_loss(out, y), prog_bar=False) + self.log('test_MSE', self.test_mse(out, y), prog_bar=True) diff --git a/tests/helpers/test_datasets.py b/tests/helpers/test_datasets.py new file mode 100644 index 00000000000000..8c866bdbab789c --- /dev/null +++ b/tests/helpers/test_datasets.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pickle + +import cloudpickle +import pytest + +from tests import PATH_DATASETS +from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST + + +@pytest.mark.parametrize( + 'dataset_cls,args', [ + (MNIST, dict(root=PATH_DATASETS)), + (TrialMNIST, dict(root=PATH_DATASETS)), + (AverageDataset, dict()), + ] +) +def test_pickling_dataset_mnist(tmpdir, dataset_cls, args): + mnist = dataset_cls(**args) + + mnist_pickled = pickle.dumps(mnist) + pickle.loads(mnist_pickled) + # assert vars(mnist) == vars(mnist_loaded) + + mnist_pickled = cloudpickle.dumps(mnist) + cloudpickle.loads(mnist_pickled) + # assert vars(mnist) == vars(mnist_loaded) diff --git a/tests/helpers/test_models.py b/tests/helpers/test_models.py new file mode 100644 index 00000000000000..e4bb7e7df08277 --- /dev/null +++ b/tests/helpers/test_models.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest + +from pytorch_lightning import Trainer +from tests.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN +from tests.helpers.boring_model import BoringModel +from tests.helpers.datamodules import ClassifDataModule, RegressDataModule +from tests.helpers.simple_models import ClassificationModel, RegressionModel + + +@pytest.mark.parametrize( + "data_class,model_class", [ + (None, BoringModel), + (None, BasicGAN), + (None, ParityModuleRNN), + (None, ParityModuleMNIST), + (ClassifDataModule, ClassificationModel), + (RegressDataModule, RegressionModel), + ] +) +def test_models(tmpdir, data_class, model_class): + """Test simple models""" + dm = data_class() if data_class else data_class + model = model_class() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + trainer.fit(model, datamodule=dm) + trainer.test(model, datamodule=dm) + + model.to_torchscript() + if data_class: + model.to_onnx(os.path.join(tmpdir, 'my-model.onnx'), input_sample=dm.sample) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py new file mode 100644 index 00000000000000..f5c1726a423bbb --- /dev/null +++ b/tests/helpers/utils.py @@ -0,0 +1,132 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import os +import traceback +from contextlib import contextmanager +from typing import Optional + +import pytest + +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger +from tests import _TEMP_PATH, RANDOM_PORTS +from tests.base.model_template import EvalModelTemplate + + +def get_default_logger(save_dir, version=None): + # set up logger object without actually saving logs + logger = TensorBoardLogger(save_dir, name='lightning_logs', version=version) + return logger + + +def get_data_path(expt_logger, path_dir=None): + # some calls contain only experiment not complete logger + + # each logger has to have these attributes + name, version = expt_logger.name, expt_logger.version + + # only the test-tube experiment has such attribute + if isinstance(expt_logger, TestTubeLogger): + expt = expt_logger.experiment if hasattr(expt_logger, 'experiment') else expt_logger + return expt.get_data_path(name, version) + + # the other experiments... + if not path_dir: + if hasattr(expt_logger, 'save_dir') and expt_logger.save_dir: + path_dir = expt_logger.save_dir + else: + path_dir = _TEMP_PATH + path_expt = os.path.join(path_dir, name, 'version_%s' % version) + + # try if the new sub-folder exists, typical case for test-tube + if not os.path.isdir(path_expt): + path_expt = path_dir + return path_expt + + +def load_model_from_checkpoint(logger, root_weights_dir, module_class=EvalModelTemplate): + trained_model = module_class.load_from_checkpoint(root_weights_dir) + assert trained_model is not None, 'loading model failed' + return trained_model + + +def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): + # this model should get 0.80+ acc + acc = trainer.callback_metrics[key] + assert acc > thr, f"Model failed to get expected {thr} accuracy. {key} = {acc}" + + +def reset_seed(seed=0): + seed_everything(seed) + + +def set_random_master_port(): + reset_seed() + port = RANDOM_PORTS.pop() + os.environ['MASTER_PORT'] = str(port) + + +def init_checkpoint_callback(logger): + checkpoint = ModelCheckpoint(dirpath=logger.save_dir) + return checkpoint + + +def pl_multi_process_test(func): + """Wrapper for running multi-processing tests.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + + from multiprocessing import Process, Queue + queue = Queue() + + def inner_f(queue, **kwargs): + try: + func(**kwargs) + queue.put(1) + except Exception: + _trace = traceback.format_exc() + print(_trace) + # code 17 means RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : + # Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14) + if "terminated with exit code 17" in _trace: + queue.put(1) + else: + queue.put(-1) + + proc = Process(target=inner_f, args=(queue, ), kwargs=kwargs) + proc.start() + proc.join() + + result = queue.get() + assert result == 1, 'expected 1, but returned %s' % result + + return wrapper + + +@contextmanager +def no_warning_call(warning_type, match: Optional[str] = None): + with pytest.warns(None) as record: + yield + + try: + w = record.pop(warning_type) + if not (match and match in str(w.message)): + return + except AssertionError: + # no warning raised + return + raise AssertionError(f"`{warning_type}` was raised: {w}") diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 06e93fa6a23f41..eefeac31bffb88 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -1,13 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect +import os import pickle +from unittest import mock +from unittest.mock import ANY import pytest +import torch -import tests.base.utils as tutils -from pytorch_lightning import Trainer +import tests.helpers.utils as tutils +from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import ( - TensorBoardLogger, MLFlowLogger, NeptuneLogger, TestTubeLogger, CometLogger) -from tests.base import EvalModelTemplate + CometLogger, + MLFlowLogger, + NeptuneLogger, + TensorBoardLogger, + TestTubeLogger, + WandbLogger, +) +from pytorch_lightning.loggers.base import DummyExperiment +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf +from tests.loggers.test_comet import _patch_comet_atexit +from tests.loggers.test_mlflow import mock_mlflow_run_creation def _get_logger_args(logger_class, save_dir): @@ -16,28 +44,64 @@ def _get_logger_args(logger_class, save_dir): logger_args.update(save_dir=str(save_dir)) if 'offline_mode' in inspect.getfullargspec(logger_class).args: logger_args.update(offline_mode=True) + if 'offline' in inspect.getfullargspec(logger_class).args: + logger_args.update(offline=True) return logger_args -@pytest.mark.parametrize("logger_class", [ - TensorBoardLogger, - CometLogger, - MLFlowLogger, - NeptuneLogger, - TestTubeLogger, - # TrainsLogger, # TODO: add this one - # WandbLogger, # TODO: add this one -]) -def test_loggers_fit_test(tmpdir, monkeypatch, logger_class): - """Verify that basic functionality of all loggers.""" - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - import atexit - monkeypatch.setattr(atexit, 'register', lambda _: None) - - model = EvalModelTemplate(tutils.get_default_hparams()) +def _instantiate_logger(logger_class, save_idr, **override_kwargs): + args = _get_logger_args(logger_class, save_idr) + args.update(**override_kwargs) + logger = logger_class(**args) + return logger + + +def test_loggers_fit_test_all(tmpdir, monkeypatch): + """ Verify that basic functionality of all loggers. """ + + _test_loggers_fit_test(tmpdir, TensorBoardLogger) + + with mock.patch('pytorch_lightning.loggers.comet.comet_ml'), \ + mock.patch('pytorch_lightning.loggers.comet.CometOfflineExperiment'): + _patch_comet_atexit(monkeypatch) + _test_loggers_fit_test(tmpdir, CometLogger) + + with mock.patch('pytorch_lightning.loggers.mlflow.mlflow'), \ + mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient'): + _test_loggers_fit_test(tmpdir, MLFlowLogger) + + with mock.patch('pytorch_lightning.loggers.neptune.neptune'): + _test_loggers_fit_test(tmpdir, NeptuneLogger) + + with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'): + _test_loggers_fit_test(tmpdir, TestTubeLogger) + + with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb: + wandb.run = None + wandb.init().step = 0 + _test_loggers_fit_test(tmpdir, WandbLogger) + + +def _test_loggers_fit_test(tmpdir, logger_class): + + class CustomModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('train_some_val', loss) + return {"loss": loss} + + def validation_epoch_end(self, outputs) -> None: + avg_val_loss = torch.stack([x['x'] for x in outputs]).mean() + self.log_dict({'early_stop_on': avg_val_loss, 'val_loss': avg_val_loss**0.5}) + + def test_epoch_end(self, outputs) -> None: + avg_test_loss = torch.stack([x["y"] for x in outputs]).mean() + self.log('test_loss', avg_test_loss) class StoreHistoryLogger(logger_class): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.history = [] @@ -49,50 +113,295 @@ def log_metrics(self, metrics, step): logger_args = _get_logger_args(logger_class, tmpdir) logger = StoreHistoryLogger(**logger_args) + if logger_class == WandbLogger: + # required mocks for Trainer + logger.experiment.id = 'foo' + logger.experiment.project_name.return_value = 'bar' + + if logger_class == CometLogger: + logger.experiment.id = 'foo' + logger.experiment.project_name = 'bar' + + if logger_class == TestTubeLogger: + logger.experiment.version = 'foo' + logger.experiment.name = 'bar' + + if logger_class == MLFlowLogger: + logger = mock_mlflow_run_creation(logger, experiment_id="foo", run_id="bar") + + model = CustomModel() trainer = Trainer( max_epochs=1, logger=logger, - train_percent_check=0.2, - val_percent_check=0.5, - fast_dev_run=True, + limit_train_batches=1, + limit_val_batches=1, + log_every_n_steps=1, + default_root_dir=tmpdir, ) trainer.fit(model) - trainer.test() log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history] - assert log_metric_names == [(0, ['epoch', 'val_acc', 'val_loss']), - (0, ['epoch', 'train_some_val']), - (1, ['epoch', 'test_acc', 'test_loss'])] + if logger_class == TensorBoardLogger: + expected = [ + (0, ['hp_metric']), + (0, ['epoch', 'train_some_val']), + (0, ['early_stop_on', 'epoch', 'val_loss']), + (0, ['hp_metric']), + (1, ['epoch', 'test_loss']), + ] + assert log_metric_names == expected + else: + expected = [ + (0, ['epoch', 'train_some_val']), + (0, ['early_stop_on', 'epoch', 'val_loss']), + (1, ['epoch', 'test_loss']), + ] + assert log_metric_names == expected -@pytest.mark.parametrize("logger_class", [ - TensorBoardLogger, - CometLogger, - MLFlowLogger, - NeptuneLogger, - TestTubeLogger, - # TrainsLogger, # TODO: add this one - # WandbLogger, # TODO: add this one -]) -def test_loggers_pickle(tmpdir, monkeypatch, logger_class): +def test_loggers_save_dir_and_weights_save_path_all(tmpdir, monkeypatch): + """ Test the combinations of save_dir, weights_save_path and default_root_dir. """ + + _test_loggers_save_dir_and_weights_save_path(tmpdir, TensorBoardLogger) + + with mock.patch('pytorch_lightning.loggers.comet.comet_ml'), \ + mock.patch('pytorch_lightning.loggers.comet.CometOfflineExperiment'): + _patch_comet_atexit(monkeypatch) + _test_loggers_save_dir_and_weights_save_path(tmpdir, CometLogger) + + with mock.patch('pytorch_lightning.loggers.mlflow.mlflow'), \ + mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient'): + _test_loggers_save_dir_and_weights_save_path(tmpdir, MLFlowLogger) + + with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'): + _test_loggers_save_dir_and_weights_save_path(tmpdir, TestTubeLogger) + + with mock.patch('pytorch_lightning.loggers.wandb.wandb'): + _test_loggers_save_dir_and_weights_save_path(tmpdir, WandbLogger) + + +def _test_loggers_save_dir_and_weights_save_path(tmpdir, logger_class): + + class TestLogger(logger_class): + # for this test it does not matter what these attributes are + # so we standardize them to make testing easier + @property + def version(self): + return 'version' + + @property + def name(self): + return 'name' + + model = BoringModel() + trainer_args = dict( + default_root_dir=tmpdir, + max_steps=1, + ) + + # no weights_save_path given + save_dir = tmpdir / 'logs' + weights_save_path = None + logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) + trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) + trainer.fit(model) + assert trainer.weights_save_path == trainer.default_root_dir + assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, 'name', 'version', 'checkpoints') + assert trainer.default_root_dir == tmpdir + + # with weights_save_path given, the logger path and checkpoint path should be different + save_dir = tmpdir / 'logs' + weights_save_path = tmpdir / 'weights' + logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) + trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) + trainer.fit(model) + assert trainer.weights_save_path == weights_save_path + assert trainer.logger.save_dir == save_dir + assert trainer.checkpoint_callback.dirpath == weights_save_path / 'name' / 'version' / 'checkpoints' + assert trainer.default_root_dir == tmpdir + + # no logger given + weights_save_path = tmpdir / 'weights' + trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path) + trainer.fit(model) + assert trainer.weights_save_path == weights_save_path + assert trainer.checkpoint_callback.dirpath == weights_save_path / 'checkpoints' + assert trainer.default_root_dir == tmpdir + + +@pytest.mark.parametrize( + "logger_class", + [ + CometLogger, + MLFlowLogger, + NeptuneLogger, + TensorBoardLogger, + TestTubeLogger, + # The WandbLogger gets tested for pickling in its own test. + ] +) +def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class): + """ Test that the logger objects can be pickled. This test only makes sense if the packages are installed. """ + _patch_comet_atexit(monkeypatch) + try: + _test_loggers_pickle(tmpdir, monkeypatch, logger_class) + except (ImportError, ModuleNotFoundError): + pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") + + +def _test_loggers_pickle(tmpdir, monkeypatch, logger_class): """Verify that pickling trainer with logger works.""" - # prevent comet logger from trying to print at exit, since - # pytest's stdout/stderr redirection breaks it - import atexit - monkeypatch.setattr(atexit, 'register', lambda _: None) + _patch_comet_atexit(monkeypatch) logger_args = _get_logger_args(logger_class, tmpdir) logger = logger_class(**logger_args) + # this can cause pickle error if the experiment object is not picklable + # the logger needs to remove it from the state before pickle + _ = logger.experiment + # test pickling loggers pickle.dumps(logger) trainer = Trainer( max_epochs=1, - logger=logger + logger=logger, ) pkl_bytes = pickle.dumps(trainer) trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({'acc': 1.0}) + + # make sure we restord properly + assert trainer2.logger.name == logger.name + assert trainer2.logger.save_dir == logger.save_dir + + +@pytest.mark.parametrize( + "extra_params", [ + pytest.param(dict(max_epochs=1, auto_scale_batch_size=True), id='Batch-size-Finder'), + pytest.param(dict(max_epochs=3, auto_lr_find=True), id='LR-Finder'), + ] +) +def test_logger_reset_correctly(tmpdir, extra_params): + """ Test that the tuners do not alter the logger reference """ + + class CustomModel(BoringModel): + + def __init__(self, lr=0.1, batch_size=1): + super().__init__() + self.save_hyperparameters() + + tutils.reset_seed() + model = CustomModel() + trainer = Trainer( + default_root_dir=tmpdir, + **extra_params, + ) + logger1 = trainer.logger + trainer.tune(model) + logger2 = trainer.logger + logger3 = model.logger + + assert logger1 == logger2, \ + 'Finder altered the logger of trainer' + assert logger2 == logger3, \ + 'Finder altered the logger of model' + + +class RankZeroLoggerCheck(Callback): + # this class has to be defined outside the test function, otherwise we get pickle error + # due to the way ddp process is launched + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + is_dummy = isinstance(trainer.logger.experiment, DummyExperiment) + if trainer.is_global_zero: + assert not is_dummy + else: + assert is_dummy + assert pl_module.logger.experiment.something(foo="bar") is None + + +@pytest.mark.parametrize( + "logger_class", [ + CometLogger, + MLFlowLogger, + NeptuneLogger, + TensorBoardLogger, + TestTubeLogger, + ] +) +@RunIf(skip_windows=True) +def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class): + """ Test that loggers get replaced by dummy loggers on global rank > 0""" + _patch_comet_atexit(monkeypatch) + try: + _test_logger_created_on_rank_zero_only(tmpdir, logger_class) + except (ImportError, ModuleNotFoundError): + pytest.xfail(f"multi-process test requires {logger_class.__class__} dependencies to be installed.") + + +def _test_logger_created_on_rank_zero_only(tmpdir, logger_class): + logger_args = _get_logger_args(logger_class, tmpdir) + logger = logger_class(**logger_args) + model = BoringModel() + trainer = Trainer( + logger=logger, + default_root_dir=tmpdir, + accelerator='ddp_cpu', + num_processes=2, + max_steps=1, + checkpoint_callback=True, + callbacks=[RankZeroLoggerCheck()], + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +def test_logger_with_prefix_all(tmpdir, monkeypatch): + """ + Test that prefix is added at the beginning of the metric keys. + """ + prefix = 'tmp' + + # Comet + with mock.patch('pytorch_lightning.loggers.comet.comet_ml'), \ + mock.patch('pytorch_lightning.loggers.comet.CometOfflineExperiment'): + _patch_comet_atexit(monkeypatch) + logger = _instantiate_logger(CometLogger, save_idr=tmpdir, prefix=prefix) + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0) + + # MLflow + with mock.patch('pytorch_lightning.loggers.mlflow.mlflow'), \ + mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient'): + logger = _instantiate_logger(MLFlowLogger, save_idr=tmpdir, prefix=prefix) + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0) + + # Neptune + with mock.patch('pytorch_lightning.loggers.neptune.neptune'): + logger = _instantiate_logger(NeptuneLogger, save_idr=tmpdir, prefix=prefix) + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.log_metric.assert_called_once_with("tmp-test", 1.0) + + # TensorBoard + with mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter'): + logger = _instantiate_logger(TensorBoardLogger, save_idr=tmpdir, prefix=prefix) + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0) + + # TestTube + with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'): + logger = _instantiate_logger(TestTubeLogger, save_idr=tmpdir, prefix=prefix) + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.log.assert_called_once_with({"tmp-test": 1.0}, global_step=0) + + # WandB + with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb: + logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix) + wandb.run = None + wandb.init().step = 0 + logger.log_metrics({"test": 1.0}, step=0) + logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0}) diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 595ca0ab093968..cf3a0cb74b3f44 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -1,13 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pickle +from argparse import Namespace +from typing import Optional from unittest.mock import MagicMock import numpy as np -import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection +from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger +from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_only -from tests.base import EvalModelTemplate +from tests.helpers import BoringModel def test_logger_collection(): @@ -22,16 +38,27 @@ def test_logger_collection(): assert logger.experiment[0] == mock1.experiment assert logger.experiment[1] == mock2.experiment + assert logger.save_dir is None + + logger.update_agg_funcs({'test': np.mean}, np.sum) + mock1.update_agg_funcs.assert_called_once_with({'test': np.mean}, np.sum) + mock2.update_agg_funcs.assert_called_once_with({'test': np.mean}, np.sum) + + logger.agg_and_log_metrics({'test': 2.0}, 4) + mock1.agg_and_log_metrics.assert_called_once_with({'test': 2.0}, 4) + mock2.agg_and_log_metrics.assert_called_once_with({'test': 2.0}, 4) + logger.close() mock1.close.assert_called_once() mock2.close.assert_called_once() class CustomLogger(LightningLoggerBase): + def __init__(self): super().__init__() self.hparams_logged = None - self.metrics_logged = None + self.metrics_logged = {} self.finalized = False @property @@ -50,6 +77,14 @@ def log_metrics(self, metrics, step): def finalize(self, status): self.finalized_status = status + @property + def save_dir(self) -> Optional[str]: + """ + Return the root directory where experiment logs get saved, or `None` if the logger does not + save data locally. + """ + return None + @property def name(self): return "name" @@ -60,45 +95,58 @@ def version(self): def test_custom_logger(tmpdir): - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(tutils.get_default_hparams()) - logger = CustomLogger() + class CustomModel(BoringModel): + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('train_loss', loss) + return {"loss": loss} + + logger = CustomLogger() + model = CustomModel() trainer = Trainer( - max_epochs=1, - train_percent_check=0.05, + max_steps=2, + log_every_n_steps=1, logger=logger, - default_root_dir=tmpdir + default_root_dir=tmpdir, ) - result = trainer.fit(model) - assert result == 1, "Training failed" - assert logger.hparams_logged == hparams + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert logger.hparams_logged == model.hparams assert logger.metrics_logged != {} assert logger.finalized_status == "success" def test_multiple_loggers(tmpdir): - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) + class CustomModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('train_loss', loss) + return {"loss": loss} + + model = CustomModel() logger1 = CustomLogger() logger2 = CustomLogger() trainer = Trainer( - max_epochs=1, - train_percent_check=0.05, + max_steps=2, + log_every_n_steps=1, logger=[logger1, logger2], - default_root_dir=tmpdir + default_root_dir=tmpdir, ) - result = trainer.fit(model) - assert result == 1, "Training failed" + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert logger1.hparams_logged == hparams + assert logger1.hparams_logged == model.hparams assert logger1.metrics_logged != {} assert logger1.finalized_status == "success" - assert logger2.hparams_logged == hparams + assert logger2.hparams_logged == model.hparams assert logger2.metrics_logged != {} assert logger2.finalized_status == "success" @@ -109,48 +157,48 @@ def test_multiple_loggers_pickle(tmpdir): logger1 = CustomLogger() logger2 = CustomLogger() - trainer = Trainer(max_epochs=1, logger=[logger1, logger2]) + trainer = Trainer(logger=[logger1, logger2], ) pkl_bytes = pickle.dumps(trainer) trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}, 0) - assert logger1.metrics_logged != {} - assert logger2.metrics_logged != {} + assert trainer2.logger[0].metrics_logged == {"acc": 1.0} + assert trainer2.logger[1].metrics_logged == {"acc": 1.0} def test_adding_step_key(tmpdir): - logged_step = 0 - def _validation_epoch_end(outputs): - nonlocal logged_step - logged_step += 1 - return {"log": {"step": logged_step, "val_acc": logged_step / 10}} + class CustomTensorBoardLogger(TensorBoardLogger): - def _training_epoch_end(outputs): - nonlocal logged_step - logged_step += 1 - return {"log": {"step": logged_step, "train_acc": logged_step / 10}} + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logged_step = 0 - def _log_metrics_decorator(log_metrics_fn): - def decorated(metrics, step): + def log_metrics(self, metrics, step): if "val_acc" in metrics: - assert step == logged_step - return log_metrics_fn(metrics, step) + assert step == self.logged_step - return decorated + super().log_metrics(metrics, step) - model = EvalModelTemplate(tutils.get_default_hparams()) - model.validation_epoch_end = _validation_epoch_end - model.training_epoch_end = _training_epoch_end + class CustomModel(BoringModel): + + def training_epoch_end(self, outputs): + self.logger.logged_step += 1 + self.log_dict({"step": self.logger.logged_step, "train_acc": self.logger.logged_step / 10}) + + def validation_epoch_end(self, outputs): + self.logger.logged_step += 1 + self.log_dict({"step": self.logger.logged_step, "val_acc": self.logger.logged_step / 10}) + + model = CustomModel() trainer = Trainer( - max_epochs=4, + max_epochs=3, + logger=CustomTensorBoardLogger(save_dir=tmpdir), default_root_dir=tmpdir, - train_percent_check=0.001, - val_percent_check=0.01, + limit_train_batches=0.1, + limit_val_batches=0.1, num_sanity_val_steps=0, ) - trainer.logger.log_metrics = _log_metrics_decorator( - trainer.logger.log_metrics) trainer.fit(model) @@ -158,6 +206,7 @@ def test_with_accumulate_grad_batches(): """Checks if the logging is performed once for `accumulate_grad_batches` steps.""" class StoreHistoryLogger(CustomLogger): + def __init__(self): super().__init__() self.history = {} @@ -177,3 +226,63 @@ def log_metrics(self, metrics, step): assert logger.history == {0: {'loss': 0.5623850983416314}} logger.close() assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}} + + +def test_dummyexperiment_support_indexing(): + """ Test that the DummyExperiment can imitate indexing the experiment in a LoggerCollection. """ + experiment = DummyExperiment() + assert experiment[0] == experiment + + +def test_dummylogger_support_indexing(): + """ Test that the DummyLogger can imitate indexing of a LoggerCollection. """ + logger = DummyLogger() + assert logger[0] == logger + + +def test_dummylogger_noop_method_calls(): + """ Test that the DummyLogger methods can be called with arbitrary arguments. """ + logger = DummyLogger() + logger.log_hyperparams("1", 2, three="three") + logger.log_metrics("1", 2, three="three") + + +def test_np_sanitization(): + + class CustomParamsLogger(CustomLogger): + + def __init__(self): + super().__init__() + self.logged_params = None + + @rank_zero_only + def log_hyperparams(self, params): + params = self._convert_params(params) + params = self._sanitize_params(params) + self.logged_params = params + + logger = CustomParamsLogger() + np_params = { + "np.bool_": np.bool_(1), + "np.byte": np.byte(2), + "np.intc": np.intc(3), + "np.int_": np.int_(4), + "np.longlong": np.longlong(5), + "np.single": np.single(6.0), + "np.double": np.double(8.9), + "np.csingle": np.csingle(7 + 2j), + "np.cdouble": np.cdouble(9 + 4j), + } + sanitized_params = { + "np.bool_": True, + "np.byte": 2, + "np.intc": 3, + "np.int_": 4, + "np.longlong": 5, + "np.single": 6.0, + "np.double": 8.9, + "np.csingle": "(7+2j)", + "np.cdouble": "(9+4j)", + } + logger.log_hyperparams(Namespace(**np_params)) + assert logger.logged_params == sanitized_params diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index aeab10cd0fbd95..1d686c6ba8c158 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -1,52 +1,51 @@ -from unittest.mock import patch +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest.mock import DEFAULT, patch import pytest +from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel -def test_comet_logger_online(): +def _patch_comet_atexit(monkeypatch): + """ Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it. """ + import atexit + monkeypatch.setattr(atexit, "register", lambda _: None) + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_online(comet): """Test comet online with mocks.""" # Test api_key given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: - logger = CometLogger( - api_key='key', - workspace='dummy-test', - project_name='general' - ) + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: + logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general') _ = logger.experiment - comet.assert_called_once_with( - api_key='key', - workspace='dummy-test', - project_name='general' - ) + comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') # Test both given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: - logger = CometLogger( - save_dir='test', - api_key='key', - workspace='dummy-test', - project_name='general' - ) + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: + logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general') _ = logger.experiment - comet.assert_called_once_with( - api_key='key', - workspace='dummy-test', - project_name='general' - ) - - # Test neither given - with pytest.raises(MisconfigurationException): - CometLogger( - workspace='dummy-test', - project_name='general' - ) + comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') # Test already exists with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing: @@ -55,26 +54,171 @@ def test_comet_logger_online(): experiment_name='experiment', api_key='key', workspace='dummy-test', - project_name='general' + project_name='general', ) _ = logger.experiment comet_existing.assert_called_once_with( - api_key='key', - workspace='dummy-test', - project_name='general', - previous_experiment='test' + api_key='key', workspace='dummy-test', project_name='general', previous_experiment='test' ) comet_existing().set_name.assert_called_once_with('experiment') with patch('pytorch_lightning.loggers.comet.API') as api: - CometLogger( - api_key='key', - workspace='dummy-test', - project_name='general', - rest_api_key='rest' - ) + CometLogger(api_key='key', workspace='dummy-test', project_name='general', rest_api_key='rest') api.assert_called_once_with('rest') + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_no_api_key_given(comet): + """ Test that CometLogger fails to initialize if both api key and save_dir are missing. """ + with pytest.raises(MisconfigurationException, match='requires either api_key or save_dir'): + comet.config.get_api_key.return_value = None + CometLogger(workspace='dummy-test', project_name='general') + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_experiment_name(comet): + """Test that Comet Logger experiment name works correctly.""" + + api_key = "key" + experiment_name = "My Name" + + # Test api_key given + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: + logger = CometLogger( + api_key=api_key, + experiment_name=experiment_name, + ) + assert logger._experiment is None + + _ = logger.experiment + comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) + comet_experiment().set_name.assert_called_once_with(experiment_name) + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_manual_experiment_key(comet): + """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY.""" + + api_key = "key" + experiment_key = "96346da91469407a85641afe5766b554" + + instantation_environ = {} + + def save_os_environ(*args, **kwargs): + nonlocal instantation_environ + instantation_environ = os.environ.copy() + + return DEFAULT + + # Test api_key given + with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}): + with patch('pytorch_lightning.loggers.comet.CometExperiment', side_effect=save_os_environ) as comet_experiment: + logger = CometLogger(api_key=api_key) + assert logger.version == experiment_key + assert logger._experiment is None + + _ = logger.experiment + comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) + + assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key + + +@patch('pytorch_lightning.loggers.comet.CometOfflineExperiment') +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch): + """ Test that the logger creates the folders and files in the right place. """ + _patch_comet_atexit(monkeypatch) + + comet.config.get_api_key.return_value = None + comet.generate_guid.return_value = "4321" + + logger = CometLogger(project_name='test', save_dir=tmpdir) + assert not os.listdir(tmpdir) + assert logger.mode == 'offline' + assert logger.save_dir == tmpdir + assert logger.name == 'test' + assert logger.version == "4321" + + _ = logger.experiment + + comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test') + + # mock return values of experiment + logger.experiment.id = '1' + logger.experiment.project_name = 'test' + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3) + assert trainer.log_dir == logger.save_dir + trainer.fit(model) + + assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints') + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=2.ckpt'} + assert trainer.log_dir == logger.save_dir + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_name_default(comet): + """ Test that CometLogger.name don't create an Experiment and returns a default value. """ + + api_key = "key" + + with patch('pytorch_lightning.loggers.comet.CometExperiment'): + logger = CometLogger(api_key=api_key) + assert logger._experiment is None + assert logger.name == "comet-default" + assert logger._experiment is None + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_name_project_name(comet): + """ Test that CometLogger.name does not create an Experiment and returns project name if passed. """ + + api_key = "key" + project_name = "My Project Name" + + with patch('pytorch_lightning.loggers.comet.CometExperiment'): + logger = CometLogger(api_key=api_key, project_name=project_name) + assert logger._experiment is None + assert logger.name == project_name + assert logger._experiment is None + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_version_without_experiment(comet): + """ Test that CometLogger.version does not create an Experiment. """ + + api_key = "key" + experiment_name = "My Name" + comet.generate_guid.return_value = "1234" + + with patch('pytorch_lightning.loggers.comet.CometExperiment'): + logger = CometLogger(api_key=api_key, experiment_name=experiment_name) + assert logger._experiment is None + + first_version = logger.version + assert first_version is not None + assert logger.version == first_version + assert logger._experiment is None + + _ = logger.experiment + + logger.reset_experiment() + + second_version = logger.version == "1234" + assert second_version is not None + assert second_version != first_version + + +@patch("pytorch_lightning.loggers.comet.CometExperiment") +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): + """ Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """ + _patch_comet_atexit(monkeypatch) + logger = CometLogger(project_name="test", save_dir=tmpdir) + logger.log_metrics({"test": 1, "epoch": 1}, step=123) + logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) diff --git a/tests/loggers/test_csv.py b/tests/loggers/test_csv.py new file mode 100644 index 00000000000000..dcdb6421c517f5 --- /dev/null +++ b/tests/loggers/test_csv.py @@ -0,0 +1,114 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from argparse import Namespace + +import pytest +import torch + +from pytorch_lightning.core.saving import load_hparams_from_yaml +from pytorch_lightning.loggers import CSVLogger +from pytorch_lightning.loggers.csv_logs import ExperimentWriter + + +def test_file_logger_automatic_versioning(tmpdir): + """Verify that automatic versioning works""" + + root_dir = tmpdir.mkdir("exp") + root_dir.mkdir("version_0") + root_dir.mkdir("version_1") + + logger = CSVLogger(save_dir=tmpdir, name="exp") + + assert logger.version == 2 + + +def test_file_logger_manual_versioning(tmpdir): + """Verify that manual versioning works""" + + root_dir = tmpdir.mkdir("exp") + root_dir.mkdir("version_0") + root_dir.mkdir("version_1") + root_dir.mkdir("version_2") + + logger = CSVLogger(save_dir=tmpdir, name="exp", version=1) + + assert logger.version == 1 + + +def test_file_logger_named_version(tmpdir): + """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """ + + exp_name = "exp" + tmpdir.mkdir(exp_name) + expected_version = "2020-02-05-162402" + + logger = CSVLogger(save_dir=tmpdir, name=exp_name, version=expected_version) + logger.log_hyperparams({"a": 1, "b": 2}) + logger.save() + assert logger.version == expected_version + assert os.listdir(tmpdir / exp_name) == [expected_version] + assert os.listdir(tmpdir / exp_name / expected_version) + + +@pytest.mark.parametrize("name", ['', None]) +def test_file_logger_no_name(tmpdir, name): + """Verify that None or empty name works""" + logger = CSVLogger(save_dir=tmpdir, name=name) + logger.save() + assert logger.root_dir == tmpdir + assert os.listdir(tmpdir / 'version_0') + + +@pytest.mark.parametrize("step_idx", [10, None]) +def test_file_logger_log_metrics(tmpdir, step_idx): + logger = CSVLogger(tmpdir) + metrics = { + "float": 0.3, + "int": 1, + "FloatTensor": torch.tensor(0.1), + "IntTensor": torch.tensor(1), + } + logger.log_metrics(metrics, step_idx) + logger.save() + + path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE) + with open(path_csv, 'r') as fp: + lines = fp.readlines() + assert len(lines) == 2 + assert all([n in lines[0] for n in metrics]) + + +def test_file_logger_log_hyperparams(tmpdir): + logger = CSVLogger(tmpdir) + hparams = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "dict": { + 'a': { + 'b': 'c' + } + }, + "list": [1, 2, 3], + "namespace": Namespace(foo=Namespace(bar='buzz')), + "layer": torch.nn.BatchNorm1d + } + logger.log_hyperparams(hparams) + logger.save() + + path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE) + params = load_hparams_from_yaml(path_yaml) + assert all([n in params for n in hparams]) diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 81ce25ca6347ed..35bad766798b1d 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -1,9 +1,229 @@ -from pytorch_lightning.loggers import MLFlowLogger +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock +from unittest.mock import MagicMock +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger +from tests.helpers import BoringModel + + +def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None): + """ Helper function to simulate mlflow client creating a new (or existing) experiment. """ + run = MagicMock() + run.info.run_id = run_id + logger._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name) + logger._mlflow_client.create_experiment = MagicMock(return_value=experiment_id) + logger._mlflow_client.create_run = MagicMock(return_value=run) + return logger + + +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_exists(client, mlflow, tmpdir): + """ Test launching three independent loggers with either same or different experiment name. """ + + run1 = MagicMock() + run1.info.run_id = "run-id-1" + + run2 = MagicMock() + run2.info.run_id = "run-id-2" + + run3 = MagicMock() + run3.info.run_id = "run-id-3" + + # simulate non-existing experiment creation + client.return_value.get_experiment_by_name = MagicMock(return_value=None) + client.return_value.create_experiment = MagicMock(return_value="exp-id-1") # experiment_id + client.return_value.create_run = MagicMock(return_value=run1) -def test_mlflow_logger_exists(tmpdir): - """Verify that basic functionality of mlflow logger works.""" logger = MLFlowLogger('test', save_dir=tmpdir) - # Test already exists + assert logger._experiment_id is None + assert logger._run_id is None + _ = logger.experiment + assert logger.experiment_id == "exp-id-1" + assert logger.run_id == "run-id-1" + assert logger.experiment.create_experiment.asset_called_once() + client.reset_mock(return_value=True) + + # simulate existing experiment returns experiment id + exp1 = MagicMock() + exp1.experiment_id = "exp-id-1" + client.return_value.get_experiment_by_name = MagicMock(return_value=exp1) + client.return_value.create_run = MagicMock(return_value=run2) + + # same name leads to same experiment id, but different runs get recorded logger2 = MLFlowLogger('test', save_dir=tmpdir) - assert logger.run_id != logger2.run_id + assert logger2.experiment_id == logger.experiment_id + assert logger2.run_id == "run-id-2" + assert logger2.experiment.create_experiment.call_count == 0 + assert logger2.experiment.create_run.asset_called_once() + client.reset_mock(return_value=True) + + # simulate a 3rd experiment with new name + client.return_value.get_experiment_by_name = MagicMock(return_value=None) + client.return_value.create_experiment = MagicMock(return_value="exp-id-3") + client.return_value.create_run = MagicMock(return_value=run3) + + # logger with new experiment name causes new experiment id and new run id to be created + logger3 = MLFlowLogger('new', save_dir=tmpdir) + assert logger3.experiment_id == "exp-id-3" != logger.experiment_id + assert logger3.run_id == "run-id-3" + + +@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_log_dir(client, mlflow, tmpdir): + """ Test that the trainer saves checkpoints in the logger's save dir. """ + + # simulate experiment creation with mlflow client mock + run = MagicMock() + run.info.run_id = "run-id" + client.return_value.get_experiment_by_name = MagicMock(return_value=None) + client.return_value.create_experiment = MagicMock(return_value="exp-id") + client.return_value.create_run = MagicMock(return_value=run) + + # test construction of default log dir path + logger = MLFlowLogger("test", save_dir=tmpdir) + assert logger.save_dir == tmpdir + assert logger.version == "run-id" + assert logger.name == "exp-id" + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + logger=logger, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=3, + ) + assert trainer.log_dir == logger.save_dir + trainer.fit(model) + assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints') + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'} + assert trainer.log_dir == logger.save_dir + + +def test_mlflow_logger_dirs_creation(tmpdir): + """ Test that the logger creates the folders and files in the right place. """ + if not _MLFLOW_AVAILABLE: + pytest.xfail("test for explicit file creation requires mlflow dependency to be installed.") + + assert not os.listdir(tmpdir) + logger = MLFlowLogger('test', save_dir=tmpdir) + assert logger.save_dir == tmpdir + assert set(os.listdir(tmpdir)) == {'.trash'} + run_id = logger.run_id + exp_id = logger.experiment_id + + # multiple experiment calls should not lead to new experiment folders + for i in range(2): + _ = logger.experiment + assert set(os.listdir(tmpdir)) == {'.trash', exp_id} + assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} + + class CustomModel(BoringModel): + + def training_epoch_end(self, *args, **kwargs): + super().training_epoch_end(*args, **kwargs) + self.log('epoch', self.current_epoch) + + model = CustomModel() + limit_batches = 5 + trainer = Trainer( + default_root_dir=tmpdir, + logger=logger, + max_epochs=1, + limit_train_batches=limit_batches, + limit_val_batches=limit_batches, + log_gpu_memory=True, + ) + trainer.fit(model) + assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} + assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics') + assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() + assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints') + assert os.listdir(trainer.checkpoint_callback.dirpath) == [f'epoch=0-step={limit_batches - 1}.ckpt'] + + +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): + """ + Test that the logger experiment_id retrieved only once. + """ + logger = MLFlowLogger('test', save_dir=tmpdir) + _ = logger.experiment + _ = logger.experiment + _ = logger.experiment + assert logger.experiment.get_experiment_by_name.call_count == 1 + + +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir): + """ + Test that the logger raises warning with special characters not accepted by MLFlow. + """ + logger = MLFlowLogger('test', save_dir=tmpdir) + metrics = {'[some_metric]': 10} + + with pytest.warns(RuntimeWarning, match='special characters in metric name'): + logger.log_metrics(metrics) + + +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): + """ + Test that the logger raises warning with special characters not accepted by MLFlow. + """ + logger = MLFlowLogger('test', save_dir=tmpdir) + value = 'test' * 100 + key = 'test_param' + params = {key: value} + + with pytest.warns(RuntimeWarning, match=f'Discard {key}={value}'): + logger.log_hyperparams(params) + + +@mock.patch('pytorch_lightning.loggers.mlflow.time') +@mock.patch('pytorch_lightning.loggers.mlflow.mlflow') +@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient') +def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): + """ + Test that the logger calls methods on the mlflow experiment correctly. + """ + time.return_value = 1 + + logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location') + logger._mlflow_client.get_experiment_by_name.return_value = None + + params = {'test': 'test_param'} + logger.log_hyperparams(params) + + logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param') + + metrics = {'some_metric': 10} + logger.log_metrics(metrics) + + logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None) + + logger._mlflow_client.create_experiment.assert_called_once_with( + name='test', + artifact_location='my_artifact_location', + ) diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index 2ca3eaf513da77..3ac763cc87b4f5 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -1,77 +1,122 @@ -from unittest.mock import patch, MagicMock +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import MagicMock, patch import torch -import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.loggers import NeptuneLogger -from tests.base import EvalModelTemplate +from tests.helpers import BoringModel @patch('pytorch_lightning.loggers.neptune.neptune') def test_neptune_online(neptune): - logger = NeptuneLogger(api_key='test', offline_mode=False, project_name='project') - neptune.init.assert_called_once_with(api_token='test', project_qualified_name='project') + logger = NeptuneLogger(api_key='test', project_name='project') - assert logger.name == neptune.create_experiment().name - assert logger.version == neptune.create_experiment().id + created_experiment = neptune.Session.with_default_backend().get_project().create_experiment() + + # It's important to check if the internal variable _experiment was initialized in __init__. + # Calling logger.experiment would cause a side-effect of initializing _experiment, + # if it wasn't already initialized. + assert logger._experiment is None + _ = logger.experiment + assert logger._experiment == created_experiment + assert logger.name == created_experiment.name + assert logger.version == created_experiment.id @patch('pytorch_lightning.loggers.neptune.neptune') -def test_neptune_additional_methods(neptune): +def test_neptune_existing_experiment(neptune): + logger = NeptuneLogger(experiment_id='TEST-123') + neptune.Session.with_default_backend().get_project().get_experiments.assert_not_called() + experiment = logger.experiment + neptune.Session.with_default_backend().get_project().get_experiments.assert_called_once_with(id='TEST-123') + assert logger.experiment_name == experiment.get_system_properties()['name'] + assert logger.params == experiment.get_parameters() + assert logger.properties == experiment.get_properties() + assert logger.tags == experiment.get_tags() + + +@patch('pytorch_lightning.loggers.neptune.neptune') +def test_neptune_offline(neptune): logger = NeptuneLogger(offline_mode=True) + neptune.Session.assert_not_called() + _ = logger.experiment + neptune.Session.assert_called_once_with(backend=neptune.OfflineBackend()) + assert logger.experiment == neptune.Session().get_project().create_experiment() + + +@patch('pytorch_lightning.loggers.neptune.neptune') +def test_neptune_additional_methods(neptune): + logger = NeptuneLogger(api_key='test', project_name='project') + + created_experiment = neptune.Session.with_default_backend().get_project().create_experiment() logger.log_metric('test', torch.ones(1)) - neptune.create_experiment().log_metric.assert_called_once_with('test', torch.ones(1)) - neptune.create_experiment().log_metric.reset_mock() + created_experiment.log_metric.assert_called_once_with('test', torch.ones(1)) + created_experiment.log_metric.reset_mock() logger.log_metric('test', 1.0) - neptune.create_experiment().log_metric.assert_called_once_with('test', 1.0) - neptune.create_experiment().log_metric.reset_mock() + created_experiment.log_metric.assert_called_once_with('test', 1.0) + created_experiment.log_metric.reset_mock() logger.log_metric('test', 1.0, step=2) - neptune.create_experiment().log_metric.assert_called_once_with('test', x=2, y=1.0) - neptune.create_experiment().log_metric.reset_mock() + created_experiment.log_metric.assert_called_once_with('test', x=2, y=1.0) + created_experiment.log_metric.reset_mock() logger.log_text('test', 'text') - neptune.create_experiment().log_metric.assert_called_once_with('test', 'text') - neptune.create_experiment().log_metric.reset_mock() + created_experiment.log_text.assert_called_once_with('test', 'text', step=None) + created_experiment.log_text.reset_mock() logger.log_image('test', 'image file') - neptune.create_experiment().log_image.assert_called_once_with('test', 'image file') - neptune.create_experiment().log_image.reset_mock() + created_experiment.log_image.assert_called_once_with('test', 'image file') + created_experiment.log_image.reset_mock() logger.log_image('test', 'image file', step=2) - neptune.create_experiment().log_image.assert_called_once_with('test', x=2, y='image file') - neptune.create_experiment().log_image.reset_mock() + created_experiment.log_image.assert_called_once_with('test', x=2, y='image file') + created_experiment.log_image.reset_mock() logger.log_artifact('file') - neptune.create_experiment().log_artifact.assert_called_once_with('file', None) + created_experiment.log_artifact.assert_called_once_with('file', None) logger.set_property('property', 10) - neptune.create_experiment().set_property.assert_called_once_with('property', 10) + created_experiment.set_property.assert_called_once_with('property', 10) logger.append_tags('one tag') - neptune.create_experiment().append_tags.assert_called_once_with('one tag') - neptune.create_experiment().append_tags.reset_mock() + created_experiment.append_tags.assert_called_once_with('one tag') + created_experiment.append_tags.reset_mock() logger.append_tags(['two', 'tags']) - neptune.create_experiment().append_tags.assert_called_once_with('two', 'tags') + created_experiment.append_tags.assert_called_once_with('two', 'tags') -def test_neptune_leave_open_experiment_after_fit(tmpdir): +@patch('pytorch_lightning.loggers.neptune.neptune') +def test_neptune_leave_open_experiment_after_fit(neptune, tmpdir): """Verify that neptune experiment was closed after training""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = BoringModel() def _run_training(logger): logger._experiment = MagicMock() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - train_percent_check=0.05, - logger=logger + limit_train_batches=0.05, + logger=logger, ) + assert trainer.log_dir is None trainer.fit(model) + assert trainer.log_dir is None return logger logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True)) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index a17cedc4a3b88b..1a85270c6dcbbd 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -1,30 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os from argparse import Namespace +from distutils.version import LooseVersion +from unittest import mock import pytest import torch +import yaml +from omegaconf import OmegaConf +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +@RunIf(min_torch="1.5.0") +def test_tensorboard_hparams_reload(tmpdir): + + class CustomModel(BoringModel): + + def __init__(self, b1=0.5, b2=0.999): + super().__init__() + self.save_hyperparameters() + + trainer = Trainer(max_steps=1, default_root_dir=tmpdir) + model = CustomModel() + assert trainer.log_dir == trainer.logger.log_dir + trainer.fit(model) + + assert trainer.log_dir == trainer.logger.log_dir + folder_path = trainer.log_dir + + # make sure yaml is there + with open(os.path.join(folder_path, "hparams.yaml")) as file: + # The FullLoader parameter handles the conversion from YAML + # scalar values to Python the dictionary format + yaml_params = yaml.safe_load(file) + assert yaml_params["b1"] == 0.5 + assert yaml_params["b2"] == 0.999 + assert len(yaml_params.keys()) == 2 + + # verify artifacts + assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1 + + # verify tb logs + event_acc = EventAccumulator(folder_path) + event_acc.Reload() + + data_pt_1_5 = b'\x12\x1b"\x04\n\x02b1"\x04\n\x02b2*\r\n\x0b\x12\thp_metric' + data_pt_1_6 = b'\x12\x1f"\x06\n\x02b1 \x03"\x06\n\x02b2 \x03*\r\n\x0b\x12\thp_metric' + hparams_data = data_pt_1_6 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0") else data_pt_1_5 + + assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.plugin_name == 'hparams' + assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.content == hparams_data def test_tensorboard_automatic_versioning(tmpdir): """Verify that automatic versioning works""" - root_dir = tmpdir.mkdir("tb_versioning") - root_dir.mkdir("version_0") - root_dir.mkdir("version_1") + root_dir = tmpdir / "tb_versioning" + root_dir.mkdir() + (root_dir / "version_0").mkdir() + (root_dir / "version_1").mkdir() logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning") - assert logger.version == 2 def test_tensorboard_manual_versioning(tmpdir): """Verify that manual versioning works""" - root_dir = tmpdir.mkdir("tb_versioning") - root_dir.mkdir("version_0") - root_dir.mkdir("version_1") - root_dir.mkdir("version_2") + root_dir = tmpdir / "tb_versioning" + root_dir.mkdir() + (root_dir / "version_0").mkdir() + (root_dir / "version_1").mkdir() + (root_dir / "version_2").mkdir() logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1) @@ -34,22 +98,25 @@ def test_tensorboard_manual_versioning(tmpdir): def test_tensorboard_named_version(tmpdir): """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """ - tmpdir.mkdir("tb_versioning") + name = "tb_versioning" + (tmpdir / name).mkdir() expected_version = "2020-02-05-162402" - logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=expected_version) - logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written + logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version) + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert logger.version == expected_version - # Could also test existence of the directory but this fails - # in the "minimum requirements" test setup + assert os.listdir(tmpdir / name) == [expected_version] + assert os.listdir(tmpdir / name / expected_version) -@pytest.mark.parametrize("name", ['', None]) +@pytest.mark.parametrize("name", ["", None]) def test_tensorboard_no_name(tmpdir, name): """Verify that None or empty name works""" logger = TensorBoardLogger(save_dir=tmpdir, name=name) + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert logger.root_dir == tmpdir + assert os.listdir(tmpdir / "version_0") @pytest.mark.parametrize("step_idx", [10, None]) @@ -59,7 +126,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): "float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), - "IntTensor": torch.tensor(1) + "IntTensor": torch.tensor(1), } logger.log_metrics(metrics, step_idx) @@ -71,25 +138,166 @@ def test_tensorboard_log_hyperparams(tmpdir): "int": 1, "string": "abc", "bool": True, - "dict": {'a': {'b': 'c'}}, + "dict": { + "a": { + "b": "c" + } + }, "list": [1, 2, 3], - "namespace": Namespace(foo=Namespace(bar='buzz')), - "layer": torch.nn.BatchNorm1d + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, } logger.log_hyperparams(hparams) def test_tensorboard_log_hparams_and_metrics(tmpdir): - logger = TensorBoardLogger(tmpdir) + logger = TensorBoardLogger(tmpdir, default_hp_metric=False) + hparams = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "dict": { + "a": { + "b": "c" + } + }, + "list": [1, 2, 3], + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, + } + metrics = {"abc": torch.tensor([0.54])} + logger.log_hyperparams(hparams, metrics) + + +def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir): + logger = TensorBoardLogger(tmpdir, default_hp_metric=False) hparams = { "float": 0.3, "int": 1, "string": "abc", "bool": True, - "dict": {'a': {'b': 'c'}}, + "dict": { + "a": { + "b": "c" + } + }, "list": [1, 2, 3], - "namespace": Namespace(foo=Namespace(bar='buzz')), - "layer": torch.nn.BatchNorm1d + # "namespace": Namespace(foo=Namespace(bar="buzz")), + # "layer": torch.nn.BatchNorm1d, } - metrics = {'abc': torch.tensor([0.54])} + hparams = OmegaConf.create(hparams) + + metrics = {"abc": torch.tensor([0.54])} logger.log_hyperparams(hparams, metrics) + + +@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)]) +def test_tensorboard_log_graph(tmpdir, example_input_array): + """ test that log graph works with both model.example_input_array and + if array is passed externaly + """ + model = BoringModel() + if example_input_array is not None: + model.example_input_array = None + + logger = TensorBoardLogger(tmpdir, log_graph=True) + logger.log_graph(model, example_input_array) + + +def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): + """ test that log graph throws warning if model.example_input_array is None """ + model = BoringModel() + model.example_input_array = None + logger = TensorBoardLogger(tmpdir, log_graph=True) + with pytest.warns( + UserWarning, + match='Could not log computational graph since the `model.example_input_array`' + ' attribute is not set or `input_array` was not given' + ): + logger.log_graph(model) + + +@mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics') +@pytest.mark.parametrize('expected', [ + ([5, 11, 17]), +]) +def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir): + """ + Tests to ensure that tensorboard log properly when accumulated_gradients > 1 + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self._count = 0 + self._indexes = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('count', self._count, on_step=True, on_epoch=True) + self.log('loss', loss, on_step=True, on_epoch=True) + + if not self.trainer.train_loop.should_accumulate(): + if self.trainer.logger_connector.should_update_logs: + self._indexes.append(self.trainer.global_step) + + return loss + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('val_loss', loss, on_step=True, on_epoch=True) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + model = TestModel() + model.training_epoch_end = None + model.validation_epoch_end = None + + logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False) + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=12, + limit_val_batches=0, + max_epochs=3, + gpus=0, + accumulate_grad_batches=2, + logger=[logger_0], + log_every_n_steps=3, + ) + trainer.fit(model) + + mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]] + assert mock_count_epochs == expected + + mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]] + assert model._indexes == mock_count_steps + + +@mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter') +def test_tensorboard_finalize(summary_writer, tmpdir): + """ Test that the SummaryWriter closes in finalize. """ + logger = TensorBoardLogger(save_dir=tmpdir) + logger.finalize("any") + summary_writer().flush.assert_called() + summary_writer().close.assert_called() + + +def test_tensorboard_save_hparams_to_yaml_once(tmpdir): + model = BoringModel() + logger = TensorBoardLogger(save_dir=tmpdir, default_hp_metric=False) + trainer = Trainer(max_steps=1, default_root_dir=tmpdir, logger=logger) + assert trainer.log_dir == trainer.logger.log_dir + trainer.fit(model) + + hparams_file = "hparams.yaml" + assert os.path.isfile(os.path.join(trainer.log_dir, hparams_file)) + assert not os.path.isfile(os.path.join(tmpdir, hparams_file)) diff --git a/tests/loggers/test_trains.py b/tests/loggers/test_trains.py deleted file mode 100644 index 738a0d9bcf8673..00000000000000 --- a/tests/loggers/test_trains.py +++ /dev/null @@ -1,50 +0,0 @@ -import pickle - -import tests.base.utils as tutils -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TrainsLogger -from tests.base import EvalModelTemplate - - -def test_trains_logger(tmpdir): - """Verify that basic functionality of TRAINS logger works.""" - model = EvalModelTemplate(tutils.get_default_hparams()) - TrainsLogger.set_bypass_mode(True) - TrainsLogger.set_credentials(api_host='http://integration.trains.allegro.ai:8008', - files_host='http://integration.trains.allegro.ai:8081', - web_host='http://integration.trains.allegro.ai:8080', ) - logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test") - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - train_percent_check=0.05, - logger=logger - ) - result = trainer.fit(model) - - print('result finished') - logger.finalize() - assert result == 1, "Training failed" - - -def test_trains_pickle(tmpdir): - """Verify that pickling trainer with TRAINS logger works.""" - # hparams = tutils.get_default_hparams() - # model = LightningTestModel(hparams) - TrainsLogger.set_bypass_mode(True) - TrainsLogger.set_credentials(api_host='http://integration.trains.allegro.ai:8008', - files_host='http://integration.trains.allegro.ai:8081', - web_host='http://integration.trains.allegro.ai:8080', ) - logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test") - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - logger=logger - ) - pkl_bytes = pickle.dumps(trainer) - trainer2 = pickle.loads(pkl_bytes) - trainer2.logger.log_metrics({"acc": 1.0}) - trainer2.logger.finalize() - logger.finalize() diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 4cd0eff431adc4..0eefb9625ddc74 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -1,27 +1,74 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import pickle -from unittest.mock import patch +import types +from argparse import ArgumentParser +from unittest import mock + +import pytest from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel + +def get_warnings(recwarn): + warnings_text = '\n'.join(str(w.message) for w in recwarn.list) + recwarn.clear() + return warnings_text -@patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_logger(wandb): + +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_logger_init(wandb, recwarn): """Verify that basic functionality of wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here.""" - logger = WandbLogger(anonymous=True, offline=True) + # test wandb.init called when there is no W&B run + wandb.run = None + logger = WandbLogger() logger.log_metrics({'acc': 1.0}) - wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) + wandb.init.assert_called_once() + wandb.init().log.assert_called_once_with({'acc': 1.0}) + # test wandb.init not called if there is a W&B run wandb.init().log.reset_mock() + wandb.init.reset_mock() + wandb.run = wandb.init() + logger = WandbLogger() logger.log_metrics({'acc': 1.0}, step=3) - wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3) - - logger.log_hyperparams({'test': None}) - wandb.init().config.update.assert_called_once_with({'test': None}, allow_val_change=True) - + wandb.init.assert_called_once() + wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3}) + + # continue training on same W&B run and offset step + logger.finalize('success') + logger.log_metrics({'acc': 1.0}, step=6) + wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6}) + + # log hyper parameters + logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]}) + wandb.init().config.update.assert_called_once_with( + { + 'test': 'None', + 'nested/a': 1, + 'b': [2, 3, 4] + }, + allow_val_change=True, + ) + + # watch a model logger.watch('model', 'log', 10) wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10) @@ -29,22 +76,34 @@ def test_wandb_logger(wandb): assert logger.version == wandb.init().id -@patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_pickle(wandb): - """Verify that pickling trainer with wandb logger works. - +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_pickle(wandb, tmpdir): + """ + Verify that pickling trainer with wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here. """ + class Experiment: + """ """ id = 'the_id' + step = 0 + dir = 'wandb' - wandb.init.return_value = Experiment() + def project_name(self): + return 'the_project_name' + wandb.run = None + wandb.init.return_value = Experiment() logger = WandbLogger(id='the_id', offline=True) - trainer = Trainer(max_epochs=1, logger=logger) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=logger, + ) # Access the experiment to ensure it's created assert trainer.logger.experiment, 'missing experiment' + assert trainer.log_dir == logger.save_dir pkl_bytes = pickle.dumps(trainer) trainer2 = pickle.loads(pkl_bytes) @@ -57,3 +116,74 @@ class Experiment: assert wandb.init.call_args[1]['id'] == 'the_id' del os.environ['WANDB_MODE'] + + +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_logger_dirs_creation(wandb, tmpdir): + """ Test that the logger creates the folders and files in the right place. """ + logger = WandbLogger(save_dir=str(tmpdir), offline=True) + assert logger.version is None + assert logger.name is None + + # mock return values of experiment + wandb.run = None + wandb.init().step = 0 + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' + logger.experiment.step = 0 + + for _ in range(2): + _ = logger.experiment + + assert logger.version == '1' + assert logger.name == 'project' + assert str(tmpdir) == logger.save_dir + assert not os.listdir(tmpdir) + + version = logger.version + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3) + assert trainer.log_dir == logger.save_dir + trainer.fit(model) + + assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints') + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=2.ckpt'} + assert trainer.log_dir == logger.save_dir + + +def test_wandb_sanitize_callable_params(tmpdir): + """ + Callback function are not serializiable. Therefore, we get them a chance to return + something and if the returned type is not accepted, return None. + """ + opt = "--max_epochs 1".split(" ") + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parent_parser=parser) + params = parser.parse_args(opt) + + def return_something(): + return "something" + + params.something = return_something + + def wrapper_something(): + return return_something + + params.wrapper_something_wo_name = lambda: lambda: '1' + params.wrapper_something = wrapper_something + + assert isinstance(params.gpus, types.FunctionType) + params = WandbLogger._convert_params(params) + params = WandbLogger._flatten_dict(params) + params = WandbLogger._sanitize_callable_params(params) + assert params["gpus"] == '_gpus_arg_default' + assert params["something"] == "something" + assert params["wrapper_something"] == "wrapper_something" + assert params["wrapper_something_wo_name"] == "" + + +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_logger_offline_log_model(wandb, tmpdir): + """ Test that log_model=True raises an error in offline mode """ + with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'): + _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py new file mode 100644 index 00000000000000..e52e39cb164880 --- /dev/null +++ b/tests/metrics/test_metric_lightning.py @@ -0,0 +1,190 @@ +import torch +from torchmetrics import Metric as TMetric + +from pytorch_lightning import Trainer +from pytorch_lightning.metrics import Metric as PLMetric +from pytorch_lightning.metrics import MetricCollection +from tests.helpers.boring_model import BoringModel + + +class SumMetric(TMetric): + + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, x): + self.x += x + + def compute(self): + return self.x + + +class DiffMetric(PLMetric): + + def __init__(self): + super().__init__() + self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, x): + self.x -= x + + def compute(self): + return self.x + + +def test_metric_lightning(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.metric = SumMetric() + self.sum = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + self.metric(x.sum()) + self.sum += x.sum() + + return self.step(x) + + def training_epoch_end(self, outs): + assert torch.allclose(self.sum, self.metric.compute()) + self.sum = 0.0 + self.metric.reset() + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + +def test_metric_lightning_log(tmpdir): + """ Test logging a metric object and that the metric state gets reset after each epoch.""" + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.metric_step = SumMetric() + self.metric_epoch = SumMetric() + self.sum = 0.0 + + def on_epoch_start(self): + self.sum = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + self.metric_step(x.sum()) + self.sum += x.sum() + self.log("sum_step", self.metric_step, on_epoch=True, on_step=False) + return {'loss': self.step(x), 'data': x} + + def training_epoch_end(self, outs): + self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum())) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + logged = trainer.logged_metrics + assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum) + assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum) + + +def test_scriptable(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + # the metric is not used in the module's `forward` + # so the module should be exportable to TorchScript + self.metric = SumMetric() + self.sum = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + self.metric(x.sum()) + self.sum += x.sum() + self.log("sum", self.metric, on_epoch=True, on_step=False) + return self.step(x) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + rand_input = torch.randn(10, 32) + + script_model = model.to_torchscript() + + # test that we can still do inference + output = model(rand_input) + script_output = script_model(rand_input) + assert torch.allclose(output, script_output) + + +def test_metric_collection_lightning_log(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.metric = MetricCollection([SumMetric(), DiffMetric()]) + self.sum = 0.0 + self.diff = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + metric_vals = self.metric(x.sum()) + self.sum += x.sum() + self.diff -= x.sum() + self.log_dict({f'{k}_step': v for k, v in metric_vals.items()}) + return self.step(x) + + def training_epoch_end(self, outputs): + metric_vals = self.metric.compute() + self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()}) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + logged = trainer.logged_metrics + assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum) + assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py new file mode 100644 index 00000000000000..d3703bf3691c95 --- /dev/null +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -0,0 +1,348 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v1.5.0""" + +import pytest +import torch + +from pytorch_lightning.metrics import ( + Accuracy, + AUC, + AUROC, + AveragePrecision, + ConfusionMatrix, + ExplainedVariance, + F1, + FBeta, + HammingDistance, + IoU, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, + MetricCollection, + Precision, + PrecisionRecallCurve, + PSNR, + R2Score, + Recall, + ROC, + SSIM, + StatScores, +) +from pytorch_lightning.metrics.functional import ( + auc, + auroc, + average_precision, + bleu_score, + confusion_matrix, + embedding_similarity, + explained_variance, + f1, + fbeta, + hamming_distance, + iou, + mean_absolute_error, + mean_squared_error, + mean_squared_log_error, + precision, + precision_recall, + precision_recall_curve, + psnr, + r2score, + recall, + roc, + ssim, + stat_scores, +) +from pytorch_lightning.metrics.functional.accuracy import accuracy +from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error +from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot + + +def test_v1_5_metrics_utils(): + x = torch.tensor([1, 2, 3]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(to_onehot(x), torch.Tensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).to(int)) + + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert get_num_classes(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0])) == 4 + + x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(select_topk(x, topk=2), torch.Tensor([[0, 1, 1], [1, 1, 0]]).to(torch.int32)) + + x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) + + +def test_v1_5_metrics_collection(): + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + + MetricCollection.__init__._warned = False + with pytest.deprecated_call(match="It will be removed in v1.5.0."): + metrics = MetricCollection([Accuracy()]) + assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)} + + +def test_v1_5_metric_accuracy(): + accuracy._warned = False + + preds = torch.tensor([0, 0, 1, 0, 1]) + target = torch.tensor([0, 0, 1, 1, 1]) + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert accuracy(preds, target) == torch.tensor(0.8) + + Accuracy.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Accuracy() + + +def test_v1_5_metric_auc_auroc(): + AUC.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AUC() + + ROC.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ROC() + + AUROC.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AUROC() + + x = torch.tensor([0, 1, 2, 3]) + y = torch.tensor([0, 1, 2, 2]) + auc._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert auc(x, y) == torch.tensor(4.) + + preds = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 1, 1]) + roc._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + fpr, tpr, thrs = roc(preds, target, pos_label=1) + assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.])) + assert torch.allclose(tpr, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4) + assert torch.equal(thrs, torch.tensor([4, 3, 2, 1, 0])) + + preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) + target = torch.tensor([0, 0, 1, 1, 1]) + auroc._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert auroc(preds, target) == torch.tensor(0.5) + + +def test_v1_5_metric_precision_recall(): + AveragePrecision.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AveragePrecision() + + Precision.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Precision() + + Recall.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Recall() + + PrecisionRecallCurve.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PrecisionRecallCurve() + + pred = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 1, 1]) + average_precision._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert average_precision(pred, target) == torch.tensor(1.) + + precision._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert precision(pred, target) == torch.tensor(0.5) + + recall._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert recall(pred, target) == torch.tensor(0.5) + + precision_recall._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + prec, rc = precision_recall(pred, target) + assert prec == torch.tensor(0.5) + assert rc == torch.tensor(0.5) + + precision_recall_curve._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + prec, rc, thrs = precision_recall_curve(pred, target) + assert torch.equal(prec, torch.tensor([1., 1., 1., 1.])) + assert torch.allclose(rc, torch.tensor([1., 0.6667, 0.3333, 0.]), atol=1e-4) + assert torch.equal(thrs, torch.tensor([1, 2, 3])) + + +def test_v1_5_metric_classif_mix(): + ConfusionMatrix.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ConfusionMatrix(num_classes=1) + + FBeta.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + FBeta(num_classes=1) + + F1.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + F1(num_classes=1) + + HammingDistance.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + HammingDistance() + + StatScores.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + StatScores() + + target = torch.tensor([1, 1, 0, 0]) + preds = torch.tensor([0, 1, 0, 0]) + confusion_matrix._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.equal(confusion_matrix(preds, target, num_classes=2), torch.tensor([[2., 0.], [1., 1.]])) + + target = torch.tensor([0, 1, 2, 0, 1, 2]) + preds = torch.tensor([0, 2, 1, 0, 0, 1]) + fbeta._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.allclose(fbeta(preds, target, num_classes=3, beta=0.5), torch.tensor(0.3333), atol=1e-4) + + f1._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.allclose(f1(preds, target, num_classes=3), torch.tensor(0.3333), atol=1e-4) + + target = torch.tensor([[0, 1], [1, 1]]) + preds = torch.tensor([[0, 1], [0, 1]]) + hamming_distance._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert hamming_distance(preds, target) == torch.tensor(0.25) + + preds = torch.tensor([1, 0, 2, 1]) + target = torch.tensor([1, 1, 2, 0]) + stat_scores._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert torch.equal(stat_scores(preds, target, reduce='micro'), torch.tensor([2, 2, 6, 2, 4])) + + +def test_v1_5_metric_detect(): + IoU.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + IoU(num_classes=1) + + target = torch.randint(0, 2, (10, 25, 25)) + preds = torch.tensor(target) + preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] + iou._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = iou(preds, target) + assert torch.allclose(res, torch.tensor(0.9660), atol=1e-4) + + +def test_v1_5_metric_regress(): + ExplainedVariance.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ExplainedVariance() + + MeanAbsoluteError.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanAbsoluteError() + + MeanSquaredError.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredError() + + MeanSquaredLogError.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredLogError() + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + explained_variance._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = explained_variance(preds, target) + assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) + + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 2]) + mean_absolute_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_absolute_error(x, y) == 0.25 + + mean_relative_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_relative_error(x, y) == 0.125 + + mean_squared_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_squared_error(x, y) == 0.25 + + mean_squared_log_error._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = mean_squared_log_error(x, y) + assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) + + PSNR.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PSNR() + + R2Score.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + R2Score() + + SSIM.__init__._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + SSIM() + + preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + psnr._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = psnr(preds, target) + assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4) + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + r2score._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = r2score(preds, target) + assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4) + + preds = torch.rand([16, 1, 16, 16]) + target = preds * 0.75 + ssim._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = ssim(preds, target) + assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4) + + +def test_v1_5_metric_others(): + translate_corpus = ['the cat is on the mat'.split()] + reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + bleu_score._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = bleu_score(translate_corpus, reference_corpus) + assert torch.allclose(res, torch.tensor(0.7598), atol=1e-4) + + embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) + embedding_similarity._warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = embedding_similarity(embeddings) + assert torch.allclose( + res, torch.tensor([[0.0000, 1.0000, 0.9759], [1.0000, 0.0000, 0.9759], [0.9759, 0.9759, 0.0000]]), atol=1e-4 + ) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py new file mode 100644 index 00000000000000..f1f17d0624936f --- /dev/null +++ b/tests/metrics/utils.py @@ -0,0 +1,270 @@ +import os +import pickle +import sys +from functools import partial +from typing import Callable + +import numpy as np +import pytest +import torch +from torch.multiprocessing import Pool, set_start_method +from torchmetrics import Metric + +try: + set_start_method("spawn") +except RuntimeError: + pass + +NUM_PROCESSES = 2 +NUM_BATCHES = 10 +BATCH_SIZE = 32 +NUM_CLASSES = 5 +EXTRA_DIM = 3 +THRESHOLD = 0.5 + + +def setup_ddp(rank, world_size): + """ Setup ddp enviroment """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" + + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _assert_allclose(pl_result, sk_result, atol: float = 1e-8): + """ Utility function for recursively asserting that two results are within + a certain tolerance + """ + # single output compare + if isinstance(pl_result, torch.Tensor): + assert np.allclose(pl_result.numpy(), sk_result, atol=atol, equal_nan=True) + # multi output compare + elif isinstance(pl_result, (tuple, list)): + for pl_res, sk_res in zip(pl_result, sk_result): + _assert_allclose(pl_res, sk_res, atol=atol) + else: + raise ValueError('Unknown format for comparison') + + +def _assert_tensor(pl_result): + """ Utility function for recursively checking that some input only consist of + torch tensors + """ + if isinstance(pl_result, (list, tuple)): + for plr in pl_result: + _assert_tensor(plr) + else: + assert isinstance(pl_result, torch.Tensor) + + +def _class_test( + rank: int, + worldsize: int, + preds: torch.Tensor, + target: torch.Tensor, + metric_class: Metric, + sk_metric: Callable, + dist_sync_on_step: bool, + metric_args: dict = {}, + check_dist_sync_on_step: bool = True, + check_batch: bool = True, + atol: float = 1e-8, +): + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) + """ + # Instanciate lightning metric + metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) + + # verify metrics work after being loaded from pickled state + pickled_metric = pickle.dumps(metric) + metric = pickle.loads(pickled_metric) + + for i in range(rank, NUM_BATCHES, worldsize): + batch_result = metric(preds[i], target[i]) + + if metric.dist_sync_on_step: + if rank == 0: + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) + sk_batch_result = sk_metric(ddp_preds, ddp_target) + # assert for dist_sync_on_step + if check_dist_sync_on_step: + _assert_allclose(batch_result, sk_batch_result, atol=atol) + else: + sk_batch_result = sk_metric(preds[i], target[i]) + # assert for batch + if check_batch: + _assert_allclose(batch_result, sk_batch_result, atol=atol) + + # check on all batches on all ranks + result = metric.compute() + _assert_tensor(result) + + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) + sk_result = sk_metric(total_preds, total_target) + + # assert after aggregation + _assert_allclose(result, sk_result, atol=atol) + + +def _functional_test( + preds: torch.Tensor, + target: torch.Tensor, + metric_functional: Callable, + sk_metric: Callable, + metric_args: dict = {}, + atol: float = 1e-8, +): + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization + """ + metric = partial(metric_functional, **metric_args) + + for i in range(NUM_BATCHES): + lightning_result = metric(preds[i], target[i]) + sk_result = sk_metric(preds[i], target[i]) + + # assert its the same + _assert_allclose(lightning_result, sk_result, atol=atol) + + +class MetricTester: + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. + + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. + """ + + atol = 1e-8 + + def setup_class(self): + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp + """ + + self.poolSize = NUM_PROCESSES + self.pool = Pool(processes=self.poolSize) + self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)]) + + def teardown_class(self): + """ Close pool of workers """ + self.pool.close() + self.pool.join() + + def run_functional_metric_test( + self, + preds: torch.Tensor, + target: torch.Tensor, + metric_functional: Callable, + sk_metric: Callable, + metric_args: dict = {}, + ): + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization + """ + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) + + def run_class_metric_test( + self, + ddp: bool, + preds: torch.Tensor, + target: torch.Tensor, + metric_class: Metric, + sk_metric: Callable, + dist_sync_on_step: bool, + metric_args: dict = {}, + check_dist_sync_on_step: bool = True, + check_batch: bool = True, + ): + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) + """ + if ddp: + if sys.platform == "win32": + pytest.skip("DDP not supported on windows") + + self.pool.starmap( + partial( + _class_test, + preds=preds, + target=target, + metric_class=metric_class, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + check_dist_sync_on_step=check_dist_sync_on_step, + check_batch=check_batch, + atol=self.atol, + ), + [(rank, self.poolSize) for rank in range(self.poolSize)], + ) + else: + _class_test( + 0, + 1, + preds=preds, + target=target, + metric_class=metric_class, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + check_dist_sync_on_step=check_dist_sync_on_step, + check_batch=check_batch, + atol=self.atol, + ) diff --git a/tests/mnode_tests.txt b/tests/mnode_tests.txt new file mode 100644 index 00000000000000..77a3ed58db526f --- /dev/null +++ b/tests/mnode_tests.txt @@ -0,0 +1,2 @@ +./tests/backends/test_multi_nodes_gpu.py::test_logging_sync_dist_true_ddp +./tests/backends/test_multi_nodes_gpu.py::test__validation_step__log diff --git a/tests/models/conf/config.yaml b/tests/models/conf/config.yaml new file mode 100644 index 00000000000000..faf751c24f6cb3 --- /dev/null +++ b/tests/models/conf/config.yaml @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +defaults: + - training: default + +log: ${training.log} diff --git a/tests/models/conf/training/default.yaml b/tests/models/conf/training/default.yaml new file mode 100644 index 00000000000000..2c35b223654203 --- /dev/null +++ b/tests/models/conf/training/default.yaml @@ -0,0 +1,2 @@ +# @package training +log: "Something" diff --git a/tests/models/data/__init__.py b/tests/models/data/__init__.py new file mode 100644 index 00000000000000..6b0006f6b25003 --- /dev/null +++ b/tests/models/data/__init__.py @@ -0,0 +1 @@ +# this is needed only for mypy==0.800 as it undestands only packages diff --git a/tests/models/data/horovod/__init__.py b/tests/models/data/horovod/__init__.py new file mode 100644 index 00000000000000..6b0006f6b25003 --- /dev/null +++ b/tests/models/data/horovod/__init__.py @@ -0,0 +1 @@ +# this is needed only for mypy==0.800 as it undestands only packages diff --git a/tests/models/data/horovod/test_train_script.py b/tests/models/data/horovod/test_train_script.py new file mode 100644 index 00000000000000..ee77efeeb8675e --- /dev/null +++ b/tests/models/data/horovod/test_train_script.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.models.data.horovod.train_default_model import run_test_from_config + + +def test_horovod_model_script(tmpdir): + """This just for testing/debugging horovod script without horovod...""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + deterministic=True, + ) + run_test_from_config(trainer_options, check_size=False, on_gpu=False) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 5b9a08c6bd420a..46ab64afccb038 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -21,41 +21,81 @@ import os import sys -import horovod.torch as hvd +import torch -PATH_HERE = os.path.abspath(os.path.dirname(__file__)) -PATH_ROOT = os.path.join(PATH_HERE, '..', '..', '..', '..') -sys.path.insert(0, os.path.abspath(PATH_ROOT)) +# this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do +PYTHONPATH = os.getenv('PYTHONPATH', '') +if ':' in PYTHONPATH: + sys.path = PYTHONPATH.split(':') + sys.path from pytorch_lightning import Trainer # noqa: E402 from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402 -from tests.base import EvalModelTemplate # noqa: E402 -from tests.base.utils import set_random_master_port, get_default_hparams, run_model_test # noqa: E402 +from pytorch_lightning.trainer.states import TrainerState # noqa: E402 +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE # noqa: E402 +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd # noqa: E402 +else: + print('You requested to import Horovod which is missing or not supported for your OS.') + +from tests.helpers import BoringModel # noqa: E402 +from tests.helpers.utils import reset_seed, set_random_master_port # noqa: E402 parser = argparse.ArgumentParser() parser.add_argument('--trainer-options', required=True) parser.add_argument('--on-gpu', action='store_true', default=False) -def run_test_from_config(trainer_options): +def run_test_from_config(trainer_options, on_gpu, check_size=True): """Trains the default model with the given config.""" set_random_master_port() + reset_seed() + + ckpt_path = trainer_options['weights_save_path'] + trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)]) - ckpt_path = trainer_options['default_root_dir'] - trainer_options.update(checkpoint_callback=ModelCheckpoint(ckpt_path)) + class TestModel(BoringModel): - model = EvalModelTemplate(get_default_hparams()) - run_model_test(trainer_options, model, on_gpu=args.on_gpu, version=0, with_hpc=False) + def training_epoch_end(self, outputs) -> None: + res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") + assert res.sum() == self.trainer.training_type_plugin.world_size + + model = TestModel() + trainer = Trainer(**trainer_options) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # Horovod should be initialized following training. If not, this will raise an exception. - assert hvd.size() == 2 + if check_size: + assert hvd.size() == 2 + + if trainer.global_rank > 0: + return + + # test model loading + pretrained_model = BoringModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # test new model accuracy + test_loaders = model.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + for dataloader in test_loaders: + batch = next(iter(dataloader)) + pretrained_model(batch) + + # test HPC saving + trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) + # test HPC loading + checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) + trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) - if args.on_gpu: + if on_gpu: + trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1) # Test the root_gpu property - assert Trainer(gpus=1, distributed_backend='horovod', max_epochs=1).root_gpu == hvd.local_rank() + assert trainer.root_gpu == hvd.local_rank() if __name__ == "__main__": args = parser.parse_args() - run_test_from_config(json.loads(args.trainer_options)) + run_test_from_config(json.loads(args.trainer_options), args.on_gpu) diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 52fb90f135bae0..0b9d6776c1aaaa 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -1,68 +1,161 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os +from unittest import mock import pytest import torch +from torch import optim +from torch.utils.data import DataLoader -import tests.base.utils as tutils +import tests.helpers.utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import SLURMEnvironment +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf -@pytest.mark.spawn -@pytest.mark.parametrize("backend", ['dp', 'ddp']) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_amp_single_gpu(tmpdir, backend): +class AMPTestModel(BoringModel): + + def _step(self, batch, batch_idx): + assert torch.is_autocast_enabled() + output = self(batch) + assert output.dtype == torch.float16 + loss = self.loss(batch, output) + return loss + + def training_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"loss": output} + + def validation_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"x": output} + + def test_step(self, batch, batch_idx): + output = self._step(batch, batch_idx) + return {"y": output} + + def predict(self, batch, batch_idx, dataloader_idx=None): + assert torch.is_autocast_enabled() + output = self(batch) + assert output.dtype == torch.float16 + return output + + +@pytest.mark.skip(reason='dp + amp not supported currently') # TODO +@RunIf(min_gpus=1) +def test_amp_single_gpu_dp(tmpdir): """Make sure DP/DDP + AMP work.""" tutils.reset_seed() + trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, gpus=1, - distributed_backend=backend, - precision=16 + accelerator='dp', + precision=16, ) - model = EvalModelTemplate(tutils.get_default_hparams()) + model = AMPTestModel() # tutils.run_model_test(trainer_options, model) - result = trainer.fit(model) + trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) - assert result == 1 + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -@pytest.mark.spawn -@pytest.mark.parametrize("backend", ['dp', 'ddp']) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_amp_multi_gpu(tmpdir, backend): +@RunIf(min_gpus=1) +def test_amp_single_gpu_ddp_spawn(tmpdir): """Make sure DP/DDP + AMP work.""" - tutils.set_random_master_port() + tutils.reset_seed() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + gpus=1, + accelerator='ddp_spawn', + precision=16, + ) - model = EvalModelTemplate(tutils.get_default_hparams()) + model = AMPTestModel() + # tutils.run_model_test(trainer_options, model) + trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - trainer_options = dict( + +@pytest.mark.skip(reason='dp + amp not supported currently') # TODO +@RunIf(min_gpus=1) +def test_amp_multi_gpu_dp(tmpdir): + """Make sure DP/DDP + AMP work.""" + tutils.reset_seed() + + trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - # gpus=2, - gpus='0, 1', # test init with gpu string - distributed_backend=backend, - precision=16 + gpus=2, + accelerator='dp', + precision=16, ) + model = AMPTestModel() # tutils.run_model_test(trainer_options, model) - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@RunIf(min_gpus=2) +def test_amp_multi_gpu_ddp_spawn(tmpdir): + """Make sure DP/DDP + AMP work.""" + tutils.reset_seed() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + gpus=2, + accelerator='ddp_spawn', + precision=16, + ) -@pytest.mark.spawn -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") + model = AMPTestModel() + # tutils.run_model_test(trainer_options, model) + trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@RunIf(min_gpus=2) +@mock.patch.dict( + os.environ, { + "SLURM_NTASKS": "1", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" + } +) def test_amp_gpu_ddp_slurm_managed(tmpdir): """Make sure DDP + AMP work.""" # simulate setting slurm flags tutils.set_random_master_port() - os.environ['SLURM_LOCALID'] = str(0) - model = EvalModelTemplate(tutils.get_default_hparams()) + model = AMPTestModel() # exp file to get meta logger = tutils.get_default_logger(tmpdir) @@ -72,38 +165,88 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): # fit model trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, gpus=[0], - distributed_backend='ddp', + accelerator='ddp_spawn', precision=16, - checkpoint_callback=checkpoint, + callbacks=[checkpoint], logger=logger, ) - trainer.is_slurm_managing_tasks = True - result = trainer.fit(model) + _ = trainer.fit(model) # correct result and ok accuracy - assert result == 1, 'amp + ddp model failed to complete' + assert trainer.state == TrainerState.FINISHED, 'amp + ddp model failed to complete' # test root model address - assert trainer.resolve_root_node_address('abc') == 'abc' - assert trainer.resolve_root_node_address('abc[23]') == 'abc23' - assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23' - assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23' + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc') == 'abc' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23]') == 'abc23' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24]') == 'abc23' + generated = trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24, 45-40, 40]') + assert generated == 'abc23' +@pytest.mark.skipif(torch.cuda.is_available(), reason="test is restricted only on CPU") def test_cpu_model_with_amp(tmpdir): """Make sure model trains on CPU.""" - trainer_options = dict( + with pytest.raises(MisconfigurationException, match="AMP is only available on GPU"): + Trainer(precision=16) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_amp_without_apex(tmpdir): + """Check that even with apex amp type without requesting precision=16 the amp backend is void.""" + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + amp_backend='native', + ) + assert trainer.amp_backend is None + + trainer = Trainer( default_root_dir=tmpdir, - progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.4, - precision=16 + amp_backend='apex', ) + assert trainer.amp_backend is None + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.dev_debugger.count_events('AMP') == 0 + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=1, amp_apex=True) +def test_amp_with_apex(tmpdir): + """Check calling apex scaling in training.""" + + class CustomModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=0.01) + optimizer2 = optim.SGD(self.parameters(), lr=0.01) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] - model = EvalModelTemplate(tutils.get_default_hparams()) + model = CustomModel() + model.training_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=5, + precision=16, + amp_backend='apex', + gpus=1, + ) + assert str(trainer.amp_backend) == "AMPType.APEX" + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.dev_debugger.count_events('AMP') == 10 - with pytest.raises((MisconfigurationException, ModuleNotFoundError)): - tutils.run_model_test(trainer_options, model, on_gpu=False) + assert isinstance(trainer.lr_schedulers[0]['scheduler'].optimizer, optim.Adam) + assert isinstance(trainer.lr_schedulers[1]['scheduler'].optimizer, optim.SGD) diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 13120c01756c10..98355a2d10c606 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -1,43 +1,131 @@ -from collections import namedtuple -import platform +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os -import pytest import torch -from packaging.version import parse as version_parse -import tests.base.utils as tutils +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import EarlyStopping -from tests.base import EvalModelTemplate +from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel + + +def test_cpu_slurm_save_load(tmpdir): + """Verify model save/load/checkpoint on CPU.""" + model = BoringModel() + + # logger file to get meta + logger = tutils.get_default_logger(tmpdir) + version = logger.version + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=logger, + limit_train_batches=0.2, + limit_val_batches=0.2, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + ) + trainer.fit(model) + real_global_step = trainer.global_step + + # traning complete + assert trainer.state == TrainerState.FINISHED, 'cpu model failed to complete' + + # predict with trained model before saving + # make a prediction + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: + for batch in dataloader: + break + + model.eval() + pred_before_saving = model(batch) + + # test HPC saving + # simulate snapshot on slurm + saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger) + assert os.path.exists(saved_filepath) + + # new logger file to get meta + logger = tutils.get_default_logger(tmpdir, version=version) + + model = BoringModel() + + class _StartCallback(Callback): + # set the epoch start hook so we can predict before the model does the full training + def on_train_epoch_start(self, trainer, model): + assert trainer.global_step == real_global_step and trainer.global_step > 0 + # predict with loaded model to make sure answers are the same + mode = model.training + model.eval() + new_pred = model(batch) + assert torch.eq(pred_before_saving, new_pred).all() + model.train(mode) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=logger, + callbacks=[_StartCallback(), ModelCheckpoint(dirpath=tmpdir)], + ) + # by calling fit again, we trigger training, loading weights from the cluster + # and our hook to predict using current model before any more weight updates + trainer.fit(model) def test_early_stopping_cpu_model(tmpdir): - """Test each of the trainer options.""" - stopping = EarlyStopping(monitor='val_loss', min_delta=0.1) + + class ModelTrainVal(BoringModel): + + def validation_step(self, *args, **kwargs): + output = super().validation_step(*args, **kwargs) + self.log('val_loss', output['x']) + return output + + tutils.reset_seed() + stopping = EarlyStopping(monitor="val_loss", min_delta=0.1) trainer_options = dict( + callbacks=[stopping], default_root_dir=tmpdir, - early_stop_callback=stopping, gradient_clip_val=1.0, - overfit_pct=0.20, + overfit_batches=0.20, track_grad_norm=2, - train_percent_check=0.1, - val_percent_check=0.1, + progress_bar_refresh_rate=0, + accumulate_grad_batches=2, + limit_train_batches=0.1, + limit_val_batches=0.1, ) - model = EvalModelTemplate(tutils.get_default_hparams()) - tutils.run_model_test(trainer_options, model, on_gpu=False) + model = ModelTrainVal() + tpipes.run_model_test(trainer_options, model, on_gpu=False) # test freeze on cpu model.freeze() model.unfreeze() -@pytest.mark.spawn -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif((platform.system() == "Darwin" and - version_parse(torch.__version__) < version_parse("1.3.0")), - reason="Distributed training is not supported on MacOS before Torch 1.3.0") +@RunIf(skip_windows=True) def test_multi_cpu_model_ddp(tmpdir): """Make sure DDP works.""" tutils.set_random_master_port() @@ -46,34 +134,40 @@ def test_multi_cpu_model_ddp(tmpdir): default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, + limit_train_batches=0.4, + limit_val_batches=0.2, gpus=None, num_processes=2, - distributed_backend='ddp_cpu' + accelerator='ddp_cpu', ) - model = EvalModelTemplate(tutils.get_default_hparams()) - tutils.run_model_test(trainer_options, model, on_gpu=False) + dm = ClassifDataModule() + model = ClassificationModel() + tpipes.run_model_test(trainer_options, model, data=dm, on_gpu=False) def test_lbfgs_cpu_model(tmpdir): - """Test each of the trainer options.""" + """Test each of the trainer options. Testing LBFGS optimizer""" + + class ModelSpecifiedOptimizer(BoringModel): + + def __init__(self, optimizer_name, learning_rate): + super().__init__() + self.optimizer_name = optimizer_name + self.learning_rate = learning_rate + self.save_hyperparameters() + trainer_options = dict( default_root_dir=tmpdir, - max_epochs=2, + max_epochs=1, progress_bar_refresh_rate=0, - weights_summary='top', - train_percent_check=1.0, - val_percent_check=0.2, + weights_summary="top", + limit_train_batches=0.2, + limit_val_batches=0.2, ) - hparams = tutils.get_default_hparams() - setattr(hparams, 'optimizer_name', 'lbfgs') - setattr(hparams, 'learning_rate', 0.002) - model = EvalModelTemplate(hparams) - model.configure_optimizers = model.configure_optimizers__lbfgs - tutils.run_model_test_without_loggers(trainer_options, model, min_acc=0.5) + model = ModelSpecifiedOptimizer(optimizer_name="LBFGS", learning_rate=0.004) + tpipes.run_model_test_without_loggers(trainer_options, model, min_acc=0.01) def test_default_logger_callbacks_cpu_model(tmpdir): @@ -82,14 +176,14 @@ def test_default_logger_callbacks_cpu_model(tmpdir): default_root_dir=tmpdir, max_epochs=1, gradient_clip_val=1.0, - overfit_pct=0.20, + overfit_batches=0.20, progress_bar_refresh_rate=0, - train_percent_check=0.01, - val_percent_check=0.01, + limit_train_batches=0.01, + limit_val_batches=0.01, ) - model = EvalModelTemplate(tutils.get_default_hparams()) - tutils.run_model_test_without_loggers(trainer_options, model) + model = BoringModel() + tpipes.run_model_test_without_loggers(trainer_options, model, min_acc=0.01) # test freeze on cpu model.freeze() @@ -98,7 +192,20 @@ def test_default_logger_callbacks_cpu_model(tmpdir): def test_running_test_after_fitting(tmpdir): """Verify test() on fitted model.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + + class ModelTrainValTest(BoringModel): + + def validation_step(self, *args, **kwargs): + output = super().validation_step(*args, **kwargs) + self.log('val_loss', output['x']) + return output + + def test_step(self, *args, **kwargs): + output = super().test_step(*args, **kwargs) + self.log('test_loss', output['y']) + return output + + model = ModelTrainValTest() # logger file to get meta logger = tutils.get_default_logger(tmpdir) @@ -110,26 +217,38 @@ def test_running_test_after_fitting(tmpdir): trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=0, - max_epochs=8, - train_percent_check=0.4, - val_percent_check=0.2, - test_percent_check=0.2, - checkpoint_callback=checkpoint, - logger=logger + max_epochs=2, + limit_train_batches=0.4, + limit_val_batches=0.2, + limit_test_batches=0.2, + callbacks=[checkpoint], + logger=logger, ) - result = trainer.fit(model) + trainer.fit(model) - assert result == 1, 'training failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test() # test we have good test accuracy - tutils.assert_ok_model_acc(trainer, thr=0.5) + tutils.assert_ok_model_acc(trainer, key='test_loss', thr=0.5) def test_running_test_no_val(tmpdir): - """Verify `test()` works on a model with no `val_loader`.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + """Verify `test()` works on a model with no `val_dataloader`. It performs + train and test only""" + + class ModelTrainTest(BoringModel): + + def val_dataloader(self): + pass + + def test_step(self, *args, **kwargs): + output = super().test_step(*args, **kwargs) + self.log('test_loss', output['y']) + return output + + model = ModelTrainTest() # logger file to get meta logger = tutils.get_default_logger(tmpdir) @@ -139,102 +258,50 @@ def test_running_test_no_val(tmpdir): # fit model trainer = Trainer( + default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - test_percent_check=0.2, - checkpoint_callback=checkpoint, + limit_train_batches=0.4, + limit_val_batches=0.2, + limit_test_batches=0.2, + callbacks=[checkpoint], logger=logger, - early_stop_callback=False ) - result = trainer.fit(model) + trainer.fit(model) - assert result == 1, 'training failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test() # test we have good test accuracy - tutils.assert_ok_model_acc(trainer) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_single_gpu_batch_parse(): - trainer = Trainer() - - # batch is just a tensor - batch = torch.rand(2, 3) - batch = trainer.transfer_batch_to_gpu(batch, 0) - assert batch.device.index == 0 and batch.type() == 'torch.cuda.FloatTensor' - - # tensor list - batch = [torch.rand(2, 3), torch.rand(2, 3)] - batch = trainer.transfer_batch_to_gpu(batch, 0) - assert batch[0].device.index == 0 and batch[0].type() == 'torch.cuda.FloatTensor' - assert batch[1].device.index == 0 and batch[1].type() == 'torch.cuda.FloatTensor' - - # tensor list of lists - batch = [[torch.rand(2, 3), torch.rand(2, 3)]] - batch = trainer.transfer_batch_to_gpu(batch, 0) - assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor' - assert batch[0][1].device.index == 0 and batch[0][1].type() == 'torch.cuda.FloatTensor' - - # tensor dict - batch = [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}] - batch = trainer.transfer_batch_to_gpu(batch, 0) - assert batch[0]['a'].device.index == 0 and batch[0]['a'].type() == 'torch.cuda.FloatTensor' - assert batch[0]['b'].device.index == 0 and batch[0]['b'].type() == 'torch.cuda.FloatTensor' - - # tuple of tensor list and list of tensor dict - batch = ([torch.rand(2, 3) for _ in range(2)], - [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)]) - batch = trainer.transfer_batch_to_gpu(batch, 0) - assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor' - - assert batch[1][0]['a'].device.index == 0 - assert batch[1][0]['a'].type() == 'torch.cuda.FloatTensor' - - assert batch[1][0]['b'].device.index == 0 - assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor' - - # namedtuple of tensor - BatchType = namedtuple('BatchType', ['a', 'b']) - batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] - batch = trainer.transfer_batch_to_gpu(batch, 0) - assert batch[0].a.device.index == 0 - assert batch[0].a.type() == 'torch.cuda.FloatTensor' + tutils.assert_ok_model_acc(trainer, key='test_loss') def test_simple_cpu(tmpdir): """Verify continue training session on CPU.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = BoringModel() # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.1, + limit_val_batches=0.1, + limit_train_batches=20, ) - result = trainer.fit(model) + trainer.fit(model) # traning complete - assert result == 1, 'amp + ddp model failed to complete' + assert trainer.state == TrainerState.FINISHED, 'amp + ddp model failed to complete' def test_cpu_model(tmpdir): """Make sure model trains on CPU.""" trainer_options = dict( - default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.4 + default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=1, limit_train_batches=4, limit_val_batches=4 ) - model = EvalModelTemplate(tutils.get_default_hparams()) - - tutils.run_model_test(trainer_options, model, on_gpu=False) + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False) def test_all_features_cpu_model(tmpdir): @@ -242,17 +309,18 @@ def test_all_features_cpu_model(tmpdir): trainer_options = dict( default_root_dir=tmpdir, gradient_clip_val=1.0, - overfit_pct=0.20, + overfit_batches=0.20, track_grad_norm=2, progress_bar_refresh_rate=0, accumulate_grad_batches=2, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.4 + limit_train_batches=0.4, + limit_val_batches=0.4, ) - model = EvalModelTemplate(tutils.get_default_hparams()) - tutils.run_model_test(trainer_options, model, on_gpu=False) + model = BoringModel() + + tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.01) def test_tbptt_cpu_model(tmpdir): @@ -265,16 +333,20 @@ def test_tbptt_cpu_model(tmpdir): y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() class MockSeq2SeqDataset(torch.utils.data.Dataset): + def __getitem__(self, i): return x_seq, y_seq_list def __len__(self): return 1 - class BpttTestModel(EvalModelTemplate): - def __init__(self, hparams): - super().__init__(hparams) + class BpttTestModel(BoringModel): + + def __init__(self, batch_size, in_features, out_features, *args, **kwargs): + super().__init__(*args, **kwargs) self.test_hidden = None + self.batch_size = batch_size + self.layer = torch.nn.Linear(in_features, out_features) def training_step(self, batch, batch_idx, hiddens): assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" @@ -287,13 +359,18 @@ def training_step(self, batch, batch_idx, hiddens): assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) - loss_val = torch.nn.functional.mse_loss( - pred, y_tensor.view(batch_size, truncated_bptt_steps)) + loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps)) return { - 'loss': loss_val, - 'hiddens': self.test_hidden, + "loss": loss_val, + "hiddens": self.test_hidden, } + def training_epoch_end(self, training_step_outputs): + training_step_outputs = training_step_outputs[0] + assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps) + loss = torch.stack([x["loss"] for x in training_step_outputs]).mean() + self.log("train_loss", loss) + def train_dataloader(self): return torch.utils.data.DataLoader( dataset=MockSeq2SeqDataset(), @@ -302,39 +379,17 @@ def train_dataloader(self): sampler=None, ) - hparams = tutils.get_default_hparams() - hparams.batch_size = batch_size - hparams.in_features = truncated_bptt_steps - hparams.hidden_dim = truncated_bptt_steps - hparams.out_features = truncated_bptt_steps - - model = BpttTestModel(hparams) + model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps) + model.example_input_array = torch.randn(5, truncated_bptt_steps) # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, truncated_bptt_steps=truncated_bptt_steps, - val_percent_check=0, + limit_val_batches=0, weights_summary=None, - early_stop_callback=False - ) - result = trainer.fit(model) - - assert result == 1, 'training failed to complete' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_single_gpu_model(tmpdir): - """Make sure single GPU works (DP mode).""" - trainer_options = dict( - default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=1, - train_percent_check=0.1, - val_percent_check=0.1, - gpus=1 ) + trainer.fit(model) - model = EvalModelTemplate(tutils.get_default_hparams()) - tutils.run_model_test(trainer_options, model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 5bdb603e145188..7764754594a099 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -1,162 +1,87 @@ -import os +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from unittest.mock import patch import pytest import torch -import tests.base.utils as tutils +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.core import memory -from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids, determine_root_gpu_device +from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel PRETEND_N_OF_GPUS = 16 -@pytest.mark.spawn -@pytest.mark.parametrize("backend", ['dp', 'ddp', 'ddp2']) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_model(tmpdir, backend): - """Make sure DDP works.""" +@RunIf(min_gpus=2) +def test_multi_gpu_none_backend(tmpdir): + """Make sure when using multiple GPUs the user can't use `distributed_backend = None`.""" tutils.set_random_master_port() - trainer_options = dict( default_root_dir=tmpdir, + progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - gpus=[0, 1], - distributed_backend=backend, - ) - - model = EvalModelTemplate(tutils.get_default_hparams()) - # tutils.run_model_test(trainer_options, model) - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result - - # test memory helper functions - memory.get_memory_profile('min_max') - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_ddp_all_dataloaders_passed_to_fit(tmpdir): - """Make sure DDP works with dataloaders passed to fit()""" - tutils.set_random_master_port() - - trainer_options = dict(default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - gpus=[0, 1], - distributed_backend='ddp') - - model = EvalModelTemplate(tutils.get_default_hparams()) - fit_options = dict(train_dataloader=model.train_dataloader(), - val_dataloaders=model.val_dataloader()) - - trainer = Trainer(**trainer_options) - result = trainer.fit(model, **fit_options) - assert result == 1, "DDP doesn't work with dataloaders passed to fit()." - - -def test_cpu_slurm_save_load(tmpdir): - """Verify model save/load/checkpoint on CPU.""" - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - - # logger file to get meta - logger = tutils.get_default_logger(tmpdir) - version = logger.version - - # fit model - trainer = Trainer( - max_epochs=1, - logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) - ) - result = trainer.fit(model) - real_global_step = trainer.global_step - - # traning complete - assert result == 1, 'cpu model failed to complete' - - # predict with trained model before saving - # make a prediction - dataloaders = model.test_dataloader() - if not isinstance(dataloaders, list): - dataloaders = [dataloaders] - - for dataloader in dataloaders: - for batch in dataloader: - break - - x, y = batch - x = x.view(x.size(0), -1) - - model.eval() - pred_before_saving = model(x) - - # test HPC saving - # simulate snapshot on slurm - saved_filepath = trainer.hpc_save(tmpdir, logger) - assert os.path.exists(saved_filepath) - - # new logger file to get meta - logger = tutils.get_default_logger(tmpdir, version=version) - - trainer = Trainer( - max_epochs=1, - logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir), + limit_train_batches=0.2, + limit_val_batches=0.2, + gpus=2, ) - model = EvalModelTemplate(hparams) - # set the epoch start hook so we can predict before the model does the full training - def assert_pred_same(): - assert trainer.global_step == real_global_step and trainer.global_step > 0 + dm = ClassifDataModule() + model = ClassificationModel() + tpipes.run_model_test(trainer_options, model, dm) - # predict with loaded model to make sure answers are the same - trainer.model.eval() - new_pred = trainer.model(x) - assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 - model.on_epoch_start = assert_pred_same - - # by calling fit again, we trigger training, loading weights from the cluster - # and our hook to predict using current model before any more weight updates - trainer.fit(model) - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_none_backend(tmpdir): - """Make sure when using multiple GPUs the user can't use `distributed_backend = None`.""" +@RunIf(min_gpus=2) +@pytest.mark.parametrize('gpus', [1, [0], [1]]) +def test_single_gpu_model(tmpdir, gpus): + """Make sure single GPU works (DP mode).""" trainer_options = dict( default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.1, - val_percent_check=0.1, - gpus='-1' + limit_train_batches=0.1, + limit_val_batches=0.1, + gpus=gpus ) - model = EvalModelTemplate(tutils.get_default_hparams()) - with pytest.warns(UserWarning): - tutils.run_model_test(trainer_options, model) + model = BoringModel() + tpipes.run_model_test(trainer_options, model) @pytest.fixture def mocked_device_count(monkeypatch): + def device_count(): return PRETEND_N_OF_GPUS + def is_available(): + return True + + monkeypatch.setattr(torch.cuda, 'is_available', is_available) monkeypatch.setattr(torch.cuda, 'device_count', device_count) @pytest.fixture def mocked_device_count_0(monkeypatch): + def device_count(): return 0 @@ -173,7 +98,7 @@ def device_count(): pytest.param(3, 3, "ddp", id="3rd gpu - 1 gpu to use (backend:ddp)") ]) def test_trainer_gpu_parse(mocked_device_count, gpus, expected_num_gpus, distributed_backend): - assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus + assert Trainer(gpus=gpus, accelerator=distributed_backend).num_gpus == expected_num_gpus @pytest.mark.gpus_param_tests @@ -182,7 +107,7 @@ def test_trainer_gpu_parse(mocked_device_count, gpus, expected_num_gpus, distrib pytest.param(None, 0, "ddp", id="None - expect 0 gpu to use."), ]) def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distributed_backend): - assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus + assert Trainer(gpus=gpus, accelerator=distributed_backend).num_gpus == expected_num_gpus @pytest.mark.gpus_param_tests @@ -195,7 +120,7 @@ def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distr pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)") ]) def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend): - assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu + assert Trainer(gpus=gpus, accelerator=distributed_backend).root_gpu == expected_root_gpu @pytest.mark.gpus_param_tests @@ -205,7 +130,7 @@ def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distrib pytest.param(0, None, "ddp", id="None is None"), ]) def test_root_gpu_property_0_passing(mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): - assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu + assert Trainer(gpus=gpus, accelerator=distributed_backend).root_gpu == expected_root_gpu # Asking for a gpu when non are available will result in a MisconfigurationException @@ -221,7 +146,7 @@ def test_root_gpu_property_0_passing(mocked_device_count_0, gpus, expected_root_ ]) def test_root_gpu_property_0_raising(mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): with pytest.raises(MisconfigurationException): - Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu + Trainer(gpus=gpus, accelerator=distributed_backend) @pytest.mark.gpus_param_tests @@ -233,7 +158,7 @@ def test_root_gpu_property_0_raising(mocked_device_count_0, gpus, expected_root_ pytest.param([1, 2], 1, id="[1, 2] gpus, expect gpu root device to be 1."), ]) def test_determine_root_gpu_device(gpus, expected_root_gpu): - assert determine_root_gpu_device(gpus) == expected_root_gpu + assert device_parser.determine_root_gpu_device(gpus) == expected_root_gpu @pytest.mark.gpus_param_tests @@ -245,6 +170,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu): pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"), pytest.param([0], [0]), pytest.param([1, 3], [1, 3]), + pytest.param((1, 3), [1, 3]), pytest.param('0', [0]), pytest.param('3', [3]), pytest.param('1, 3', [1, 3]), @@ -252,7 +178,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu): pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"), ]) def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): - assert parse_gpu_ids(gpus) == expected_gpu_ids + assert device_parser.parse_gpu_ids(gpus) == expected_gpu_ids @pytest.mark.gpus_param_tests @@ -264,28 +190,139 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): pytest.param([-1]), pytest.param([None]), pytest.param(['0']), - pytest.param((0, 1)), ]) def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus): with pytest.raises(MisconfigurationException): - parse_gpu_ids(gpus) + device_parser.parse_gpu_ids(gpus) @pytest.mark.gpus_param_tests @pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1']) -def test_parse_gpu_fail_on_non_existant_id(mocked_device_count_0, gpus): +def test_parse_gpu_fail_on_non_existent_id(mocked_device_count_0, gpus): with pytest.raises(MisconfigurationException): - parse_gpu_ids(gpus) + device_parser.parse_gpu_ids(gpus) @pytest.mark.gpus_param_tests -def test_parse_gpu_fail_on_non_existant_id_2(mocked_device_count): +def test_parse_gpu_fail_on_non_existent_id_2(mocked_device_count): with pytest.raises(MisconfigurationException): - parse_gpu_ids([1, 2, 19]) + device_parser.parse_gpu_ids([1, 2, 19]) @pytest.mark.gpus_param_tests @pytest.mark.parametrize("gpus", [-1, '-1']) -def test_parse_gpu_returns_None_when_no_devices_are_available(mocked_device_count_0, gpus): +def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_count_0, gpus): with pytest.raises(MisconfigurationException): - parse_gpu_ids(gpus) + device_parser.parse_gpu_ids(gpus) + + +@RunIf(min_gpus=1) +def test_single_gpu_batch_parse(): + trainer = Trainer(gpus=1) + + # non-transferrable types + primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] + for batch in primitive_objects: + data = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert data == batch + + # batch is just a tensor + batch = torch.rand(2, 3) + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert batch.device.index == 0 and batch.type() == 'torch.cuda.FloatTensor' + + # tensor list + batch = [torch.rand(2, 3), torch.rand(2, 3)] + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert batch[0].device.index == 0 and batch[0].type() == 'torch.cuda.FloatTensor' + assert batch[1].device.index == 0 and batch[1].type() == 'torch.cuda.FloatTensor' + + # tensor list of lists + batch = [[torch.rand(2, 3), torch.rand(2, 3)]] + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor' + assert batch[0][1].device.index == 0 and batch[0][1].type() == 'torch.cuda.FloatTensor' + + # tensor dict + batch = [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}] + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert batch[0]['a'].device.index == 0 and batch[0]['a'].type() == 'torch.cuda.FloatTensor' + assert batch[0]['b'].device.index == 0 and batch[0]['b'].type() == 'torch.cuda.FloatTensor' + + # tuple of tensor list and list of tensor dict + batch = ([torch.rand(2, 3) for _ in range(2)], [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)]) + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor' + + assert batch[1][0]['a'].device.index == 0 + assert batch[1][0]['a'].type() == 'torch.cuda.FloatTensor' + + assert batch[1][0]['b'].device.index == 0 + assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor' + + # namedtuple of tensor + BatchType = namedtuple('BatchType', ['a', 'b']) + batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + assert batch[0].a.device.index == 0 + assert batch[0].a.type() == 'torch.cuda.FloatTensor' + + # non-Tensor that has `.to()` defined + class CustomBatchType: + + def __init__(self): + self.a = torch.rand(2, 2) + + def to(self, *args, **kwargs): + self.a = self.a.to(*args, **kwargs) + return self + + batch = trainer.accelerator.batch_to_device(CustomBatchType(), torch.device('cuda:0')) + assert batch.a.type() == 'torch.cuda.FloatTensor' + + # torchtext.data.Batch + samples = [{ + 'text': 'PyTorch Lightning is awesome!', + 'label': 0 + }, { + 'text': 'Please make it work with torchtext', + 'label': 1 + }] + + text_field = Field() + label_field = LabelField() + fields = {'text': ('text', text_field), 'label': ('label', label_field)} + + examples = [Example.fromdict(sample, fields) for sample in samples] + dataset = Dataset(examples=examples, fields=fields.values()) + + # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first + text_field.build_vocab(dataset) + label_field.build_vocab(dataset) + + batch = Batch(data=examples, dataset=dataset) + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + + assert batch.text.type() == 'torch.cuda.LongTensor' + assert batch.label.type() == 'torch.cuda.LongTensor' + + +@RunIf(min_gpus=1) +def test_non_blocking(): + """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """ + trainer = Trainer() + + batch = torch.zeros(2, 3) + with patch.object(batch, 'to', wraps=batch.to) as mocked: + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True) + + class BatchObject(object): + + def to(self, *args, **kwargs): + pass + + batch = BatchObject() + with patch.object(batch, 'to', wraps=batch.to) as mocked: + batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0')) + mocked.assert_called_with(torch.device('cuda', 0)) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py new file mode 100644 index 00000000000000..4d04911ffadc9b --- /dev/null +++ b/tests/models/test_grad_norm.py @@ -0,0 +1,116 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock +from unittest.mock import patch + +import numpy as np +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel +from tests.helpers.utils import reset_seed + + +class ModelWithManualGradTracker(BoringModel): + + def __init__(self, norm_type, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stored_grad_norms, self.norm_type = [], float(norm_type) + + # validation spoils logger's metrics with `val_loss` records + validation_step = None + val_dataloader = None + + def training_step(self, batch, batch_idx, optimizer_idx=None): + # just return a loss, no log or progress bar meta + output = self(batch) + loss = self.loss(batch, output) + return {'loss': loss} + + def on_after_backward(self): + out, norms = {}, [] + prefix = f'grad_{self.norm_type}_norm_' + for name, p in self.named_parameters(): + if p.grad is None: + continue + + # `np.linalg.norm` implementation likely uses fp64 intermediates + flat = p.grad.data.cpu().numpy().ravel() + norm = np.linalg.norm(flat, self.norm_type) + norms.append(norm) + + out[prefix + name] = round(norm, 4) + + # handle total norm + norm = np.linalg.norm(norms, self.norm_type) + out[prefix + 'total'] = round(norm, 4) + self.stored_grad_norms.append(out) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize("norm_type", [1., 1.25, 2, 3, 5, 10, 'inf']) +def test_grad_tracking(tmpdir, norm_type, rtol=5e-3): + # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above + + reset_seed() + + # use a custom grad tracking module and a list logger + model = ModelWithManualGradTracker(norm_type) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + track_grad_norm=norm_type, + log_every_n_steps=1, # request grad_norms every batch + ) + trainer.fit(model) + + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + logged_metrics = trainer.dev_debugger.logged_metrics + assert len(logged_metrics) == len(model.stored_grad_norms) + + # compare the logged metrics against tracked norms on `.backward` + for mod, log in zip(model.stored_grad_norms, logged_metrics): + common = mod.keys() & log.keys() + + log, mod = [log[k] for k in common], [mod[k] for k in common] + + assert np.allclose(log, mod, rtol=rtol) + + +@pytest.mark.parametrize("log_every_n_steps", [1, 2, 3]) +def test_grad_tracking_interval(tmpdir, log_every_n_steps): + """ Test that gradient norms get tracked in the right interval and that everytime the same keys get logged. """ + trainer = Trainer( + default_root_dir=tmpdir, + track_grad_norm=2, + log_every_n_steps=log_every_n_steps, + max_steps=10, + ) + + with patch.object(trainer.logger, "log_metrics") as mocked: + model = BoringModel() + trainer.fit(model) + expected = trainer.global_step // log_every_n_steps + grad_norm_dicts = [] + for _, kwargs in mocked.call_args_list: + metrics = kwargs.get("metrics", {}) + grad_norm_dict = {k: v for k, v in metrics.items() if k.startswith("grad_")} + if grad_norm_dict: + grad_norm_dicts.append(grad_norm_dict) + + assert len(grad_norm_dicts) == expected + assert all(grad_norm_dicts[0].keys() == g.keys() for g in grad_norm_dicts) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 00147ef2bc089e..1d55d4a5a63b7b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -1,24 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock +from unittest.mock import PropertyMock + import pytest +import torch -import tests.base.utils as tutils -from pytorch_lightning import Trainer -from tests.base import EvalModelTemplate +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel, RandomDataset, BoringDataModule +from tests.helpers.runif import RunIf @pytest.mark.parametrize('max_steps', [1, 2, 3]) -def test_on_before_zero_grad_called(max_steps): +def test_on_before_zero_grad_called(tmpdir, max_steps): - class CurrentTestModel(EvalModelTemplate): + class CurrentTestModel(BoringModel): on_before_zero_grad_called = 0 def on_before_zero_grad(self, optimizer): self.on_before_zero_grad_called += 1 - model = CurrentTestModel(tutils.get_default_hparams()) + model = CurrentTestModel() trainer = Trainer( + default_root_dir=tmpdir, max_steps=max_steps, - num_sanity_val_steps=5, + max_epochs=2, ) assert 0 == model.on_before_zero_grad_called trainer.fit(model) @@ -27,3 +46,579 @@ def on_before_zero_grad(self, optimizer): model.on_before_zero_grad_called = 0 trainer.test(model) assert 0 == model.on_before_zero_grad_called + + +def test_training_epoch_end_metrics_collection(tmpdir): + """ Test that progress bar metrics also get collected at the end of an epoch. """ + num_epochs = 3 + + class CurrentModel(BoringModel): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + self.log_dict({'step_metric': torch.tensor(-1), 'shared_metric': 100}, logger=False, prog_bar=True) + return output + + def training_epoch_end(self, outputs): + epoch = self.current_epoch + # both scalar tensors and Python numbers are accepted + self.log_dict( + { + f'epoch_metric_{epoch}': torch.tensor(epoch), + 'shared_metric': 111 + }, + logger=False, + prog_bar=True, + ) + + model = CurrentModel() + trainer = Trainer( + max_epochs=num_epochs, + default_root_dir=tmpdir, + overfit_batches=2, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + metrics = trainer.progress_bar_dict + + # metrics added in training step should be unchanged by epoch end method + assert metrics['step_metric'] == -1 + # a metric shared in both methods gets overwritten by epoch_end + assert metrics['shared_metric'] == 111 + # metrics are kept after each epoch + for i in range(num_epochs): + assert metrics[f'epoch_metric_{i}'] == i + + +def test_training_epoch_end_metrics_collection_on_override(tmpdir): + """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """ + + class LoggingCallback(Callback): + + def on_train_epoch_start(self, trainer, pl_module): + self.len_outputs = 0 + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.len_outputs = len(outputs[0]) + + class OverriddenModel(BoringModel): + + def on_train_epoch_start(self): + self.num_train_batches = 0 + + def training_epoch_end(self, outputs): # Overridden + return + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + class NotOverriddenModel(BoringModel): + + def on_train_epoch_start(self): + self.num_train_batches = 0 + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + overridden_model = OverriddenModel() + not_overridden_model = NotOverriddenModel() + not_overridden_model.training_epoch_end = None + + callback = LoggingCallback() + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + overfit_batches=2, + callbacks=[callback], + ) + + trainer.fit(overridden_model) + # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook + # if training_epoch_end is overridden + assert callback.len_outputs == overridden_model.num_train_batches + + trainer.fit(not_overridden_model) + # outputs from on_train_batch_end should be empty + assert callback.len_outputs == 0 + + +@RunIf(min_gpus=1) +@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) +def test_apply_batch_transfer_handler(model_getter_mock): + expected_device = torch.device('cuda', 0) + + class CustomBatch: + + def __init__(self, data): + self.samples = data[0] + self.targets = data[1] + + class CurrentTestModel(BoringModel): + rank = 0 + transfer_batch_to_device_hook_rank = None + on_before_batch_transfer_hook_rank = None + on_after_batch_transfer_hook_rank = None + + def on_before_batch_transfer(self, batch, dataloader_idx): + self.on_before_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.samples += 1 + return batch + + def on_after_batch_transfer(self, batch, dataloader_idx): + assert batch.samples.device == batch.targets.device == expected_device + self.on_after_batch_transfer_hook_rank = self.rank + self.rank += 1 + batch.targets *= 2 + return batch + + def transfer_batch_to_device(self, batch, device): + self.transfer_batch_to_device_hook_rank = self.rank + self.rank += 1 + batch.samples = batch.samples.to(device) + batch.targets = batch.targets.to(device) + return batch + + model = CurrentTestModel() + batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) + + trainer = Trainer(gpus=1) + # running .fit() would require us to implement custom data loaders, we mock the model reference instead + + model_getter_mock.return_value = model + batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device) + + assert model.on_before_batch_transfer_hook_rank == 0 + assert model.transfer_batch_to_device_hook_rank == 1 + assert model.on_after_batch_transfer_hook_rank == 2 + assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device + assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32)) + assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) + + +@RunIf(min_gpus=2, special=True) +def test_transfer_batch_hook_ddp(tmpdir): + """ + Test custom data are properly moved to the right device using ddp + """ + + class CustomBatch: + + def __init__(self, data): + self.samples = data[0] + + def to(self, device, **kwargs): + self.samples = self.samples.to(device, **kwargs) + return self + + def collate_fn(batch): + return CustomBatch(batch) + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + assert batch.samples.device == self.device + assert isinstance(batch_idx, int) + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn) + + model = TestModel() + model.validation_step = None + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=0, + max_epochs=1, + weights_summary=None, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + +@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)]) +def test_on_train_batch_start_hook(max_epochs, batch_idx_): + + class CurrentModel(BoringModel): + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + if batch_idx == batch_idx_: + return -1 + + model = CurrentModel() + trainer = Trainer(max_epochs=max_epochs) + trainer.fit(model) + if batch_idx_ > len(model.val_dataloader()) - 1: + assert trainer.batch_idx == len(model.val_dataloader()) - 1 + assert trainer.global_step == len(model.val_dataloader()) * max_epochs + else: + assert trainer.batch_idx == batch_idx_ + assert trainer.global_step == (batch_idx_ + 1) * max_epochs + + +def test_trainer_model_hook_system(tmpdir): + """Test the LightningModule hook system.""" + + class HookedModel(BoringModel): + + def __init__(self): + super().__init__() + self.called = [] + + def on_after_backward(self): + self.called.append("on_after_backward") + super().on_after_backward() + + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) + + def on_epoch_start(self): + self.called.append("on_epoch_start") + super().on_epoch_start() + + def on_epoch_end(self): + self.called.append("on_epoch_end") + super().on_epoch_end() + + def on_fit_start(self): + self.called.append("on_fit_start") + super().on_fit_start() + + def on_fit_end(self): + self.called.append("on_fit_end") + super().on_fit_end() + + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) + + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) + + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) + + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) + + def on_pretrain_routine_start(self): + self.called.append("on_pretrain_routine_start") + super().on_pretrain_routine_start() + + def on_pretrain_routine_end(self): + self.called.append("on_pretrain_routine_end") + super().on_pretrain_routine_end() + + def on_train_start(self): + self.called.append("on_train_start") + super().on_train_start() + + def on_train_end(self): + self.called.append("on_train_end") + super().on_train_end() + + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) + + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) + + def on_train_epoch_start(self): + self.called.append("on_train_epoch_start") + super().on_train_epoch_start() + + def on_train_epoch_end(self, outputs): + self.called.append("on_train_epoch_end") + super().on_train_epoch_end(outputs) + + def on_validation_start(self): + self.called.append("on_validation_start") + super().on_validation_start() + + def on_validation_end(self): + self.called.append("on_validation_end") + super().on_validation_end() + + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) + + def on_validation_batch_end(self, *args, **kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) + + def on_validation_epoch_start(self): + self.called.append("on_validation_epoch_start") + super().on_validation_epoch_start() + + def on_validation_epoch_end(self, *args, **kwargs): + self.called.append("on_validation_epoch_end") + super().on_validation_epoch_end(*args, **kwargs) + + def on_test_start(self): + self.called.append("on_test_start") + super().on_test_start() + + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) + + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) + + def on_test_epoch_start(self): + self.called.append("on_test_epoch_start") + super().on_test_epoch_start() + + def on_test_epoch_end(self, *args, **kwargs): + self.called.append("on_test_epoch_end") + super().on_test_epoch_end(*args, **kwargs) + + def on_validation_model_eval(self): + self.called.append("on_validation_model_eval") + super().on_validation_model_eval() + + def on_validation_model_train(self): + self.called.append("on_validation_model_train") + super().on_validation_model_train() + + def on_test_model_eval(self): + self.called.append("on_test_model_eval") + super().on_test_model_eval() + + def on_test_model_train(self): + self.called.append("on_test_model_train") + super().on_test_model_train() + + def on_test_end(self): + self.called.append("on_test_end") + super().on_test_end() + + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage) + + model = HookedModel() + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + + assert model.called == [] + + trainer.fit(model) + expected = [ + 'setup_fit', + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'on_validation_model_eval', + 'on_validation_start', + 'on_epoch_start', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_epoch_end', + 'on_validation_end', + 'on_validation_model_train', + 'on_train_start', + 'on_epoch_start', + 'on_train_epoch_start', + 'on_train_batch_start', + 'on_before_zero_grad', + 'on_after_backward', + 'on_train_batch_end', + 'on_train_batch_start', + 'on_before_zero_grad', + 'on_after_backward', + 'on_train_batch_end', + 'on_train_epoch_end', + 'on_epoch_end', + 'on_validation_model_eval', + 'on_validation_start', + 'on_epoch_start', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_epoch_end', + 'on_save_checkpoint', + 'on_validation_end', + 'on_validation_model_train', + 'on_train_end', + 'on_fit_end', + 'teardown_fit', + ] + assert model.called == expected + + model = HookedModel() + + trainer.validate(model, verbose=False) + expected = [ + 'setup_validate', + 'on_validation_model_eval', + 'on_validation_start', + 'on_epoch_start', + 'on_validation_epoch_start', + 'on_validation_batch_start', + 'on_validation_batch_end', + 'on_validation_epoch_end', + 'on_epoch_end', + 'on_validation_end', + 'on_validation_model_train', + 'teardown_validate', + ] + assert model.called == expected + + model = HookedModel() + trainer.test(model, verbose=False) + + expected = [ + 'setup_test', + 'on_test_model_eval', + 'on_test_start', + 'on_epoch_start', + 'on_test_epoch_start', + 'on_test_batch_start', + 'on_test_batch_end', + 'on_test_epoch_end', + 'on_epoch_end', + 'on_test_end', + 'on_test_model_train', + 'teardown_test', + ] + assert model.called == expected + + +def test_trainer_datamodule_hook_system(tmpdir): + """Test the LightningDataModule hook system.""" + + class HookedDataModule(BoringDataModule): + def __init__(self): + super().__init__() + self.called = [] + + def prepare_data(self): + self.called.append("prepare_data") + super().prepare_data() + + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage=stage) + + def train_dataloader(self): + self.called.append("train_dataloader") + return super().train_dataloader() + + def test_dataloader(self): + self.called.append("test_dataloader") + return super().test_dataloader() + + def val_dataloader(self): + self.called.append("val_dataloader") + return super().val_dataloader() + + def predict_dataloader(self): + self.called.append("predict_dataloader") + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) + + model = BoringModel() + dm = HookedDataModule() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + reload_dataloaders_every_epoch=True, + ) + trainer.fit(model, datamodule=dm) + + expected = [ + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.validate(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_validate', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_validate' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.test(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test' + ] + assert dm.called == expected diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 14644aee6649d3..5b4a700babd1da 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -1,50 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os -import platform import shlex import subprocess import sys +from unittest.mock import patch +import numpy as np import pytest import torch +from sklearn.metrics import accuracy_score +from torch import optim +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.accelerators import CPUAccelerator +from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from tests.helpers import BoringModel +from tests.helpers.advanced_models import BasicGAN +from tests.helpers.runif import RunIf -import tests.base.utils as tutils -from tests.base import EvalModelTemplate -from tests.base.models import TestGAN - -try: - from horovod.common.util import nccl_built -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - +if _HOROVOD_AVAILABLE: + import horovod + import horovod.torch as hvd # This script will run the actual test model training in parallel TEST_SCRIPT = os.path.join(os.path.dirname(__file__), 'data', 'horovod', 'train_default_model.py') -def _nccl_available(): - if not HOROVOD_AVAILABLE: - return False - - try: - return nccl_built() - except AttributeError: - # Horovod 0.19.1 nccl_built() does not yet work with Python 3.8: - # See: https://github.com/horovod/horovod/issues/1891 - return False - - def _run_horovod(trainer_options, on_gpu=False): """Execute the training script across multiple workers in parallel.""" + num_processes = trainer_options.get('gpus', 2) + # for Horovod, we interpret `gpus` to be set per worker + trainer_options.update(gpus=1 if on_gpu else None) + tutils.reset_seed() + # todo: Find why coverage breaks CI. + # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265 + # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265 cmdline = [ - 'horovodrun', - '-np', '2', - sys.executable, TEST_SCRIPT, - '--trainer-options', shlex.quote(json.dumps(trainer_options)) + 'horovodrun', '-np', + str(num_processes), sys.executable, TEST_SCRIPT, '--trainer-options', + shlex.quote(json.dumps(trainer_options)) ] if on_gpu: cmdline += ['--on-gpu'] @@ -52,109 +63,162 @@ def _run_horovod(trainer_options, on_gpu=False): assert exit_code == 0 -@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") -@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@RunIf(skip_windows=True, horovod=True) def test_horovod_cpu(tmpdir): """Test Horovod running multi-process on CPU.""" trainer_options = dict( default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), gradient_clip_val=1.0, progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - distributed_backend='horovod' + limit_train_batches=0.4, + limit_val_batches=0.2, + accelerator='horovod', + deterministic=True, ) _run_horovod(trainer_options) -@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") -@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@RunIf(skip_windows=True, horovod=True) def test_horovod_cpu_implicit(tmpdir): """Test Horovod without specifying a backend, inferring from env set by `horovodrun`.""" trainer_options = dict( default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), gradient_clip_val=1.0, progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, + limit_train_batches=0.4, + limit_val_batches=0.2, + deterministic=True, ) _run_horovod(trainer_options) -@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") -@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True) def test_horovod_multi_gpu(tmpdir): """Test Horovod with multi-GPU support.""" trainer_options = dict( default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), gradient_clip_val=1.0, progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - gpus=1, - distributed_backend='horovod' + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + ) + _run_horovod(trainer_options, on_gpu=True) + + +# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 +# Check with (tgaddair) on Horovod issues if this feature is needed +@pytest.mark.skip(reason="Horovod currently doesn't work with Apex") # todo +@RunIf(min_gpus=2, skip_windows=True, amp_apex=True, horovod_nccl=True) +def test_horovod_apex(tmpdir): + """Test Horovod with multi-GPU support using apex amp.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + amp_backend='apex', + precision=16, + ) + _run_horovod(trainer_options, on_gpu=True) + + +@RunIf(min_gpus=2, skip_windows=True, amp_native=True, horovod_nccl=True) +def test_horovod_amp(tmpdir): + """Test Horovod with multi-GPU support using native amp.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + amp_backend='native', + precision=16, ) _run_horovod(trainer_options, on_gpu=True) -@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") -@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True) +def test_horovod_gather(tmpdir): + """Test Horovod with multi-GPU support using native amp.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + ) + _run_horovod(trainer_options, on_gpu=True) + + +@RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True) def test_horovod_transfer_batch_to_gpu(tmpdir): - class TestTrainingStepModel(EvalModelTemplate): + class TestTrainingStepModel(BoringModel): + def training_step(self, batch, *args, **kwargs): - x, y = batch - assert str(x.device) != 'cpu' - assert str(y.device) != 'cpu' + assert str(batch.device) != 'cpu' return super(TestTrainingStepModel, self).training_step(batch, *args, **kwargs) def validation_step(self, batch, *args, **kwargs): - x, y = batch - assert str(x.device) != 'cpu' - assert str(y.device) != 'cpu' + assert str(batch.device) != 'cpu' return super(TestTrainingStepModel, self).validation_step(batch, *args, **kwargs) - hparams = tutils.get_default_hparams() - model = TestTrainingStepModel(hparams) + model = TestTrainingStepModel() trainer_options = dict( default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, + limit_train_batches=0.4, + limit_val_batches=0.2, gpus=1, - distributed_backend='horovod' + deterministic=True, + accelerator='horovod', ) - tutils.run_model_test_without_loggers(trainer_options, model) + tpipes.run_model_test_without_loggers(trainer_options, model) -@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") -@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@RunIf(skip_windows=True, horovod=True) def test_horovod_multi_optimizer(tmpdir): - hparams = tutils.get_default_hparams() - model = TestGAN(hparams) + model = BasicGAN() - trainer_options = dict( + # fit model + trainer = Trainer( default_root_dir=str(tmpdir), progress_bar_refresh_rate=0, max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - distributed_backend='horovod' + limit_train_batches=0.4, + limit_val_batches=0.2, + deterministic=True, + accelerator='horovod', ) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result == 1, 'model failed to complete' + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert len(trainer.optimizers) == 2 for i, optimizer in enumerate(trainer.optimizers): @@ -169,3 +233,146 @@ def get_optimizer_params(optimizer): assert get_model_params(model.generator) != get_model_params(model.discriminator) assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0]) assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1]) + + +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") +@RunIf(skip_windows=True, horovod=True) +def test_result_reduce_horovod(tmpdir): + """Make sure result logging works with Horovod. + + This test mirrors tests/core/test_results.py::_ddp_test_fn + """ + tutils.reset_seed() + tutils.set_random_master_port() + + def hvd_test_fn(): + path_here = os.path.abspath(os.path.dirname(__file__)) + path_root = os.path.abspath(os.path.join(path_here, '..', '..')) + sys.path.insert(0, os.path.abspath(path_root)) + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.training_step_called = True + + tensor = torch.tensor([1.0]) + self.log("test_tensor", tensor, sync_dist=True, sync_dist_op='sum', on_step=True, on_epoch=True) + + res = self._results + + # Check that `tensor` is summed across all ranks automatically + assert res["test_tensor"].item() == hvd.size(), \ + "Result-Log does not work properly with Horovod and Tensors" + + def training_epoch_end(self, outputs) -> None: + assert len(outputs) == 0 + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + logger=False + ) + + trainer.fit(model) + + horovod.run(hvd_test_fn, np=2) + + +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") +@RunIf(skip_windows=True, horovod=True, num_gpus=2) +def test_accuracy_metric_horovod(): + num_batches = 10 + batch_size = 16 + threshold = 0.5 + + def sk_metric(preds, target): + sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_target = target.view(-1).numpy() + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + + preds = torch.rand(num_batches, batch_size) + target = torch.randint(high=2, size=(num_batches, batch_size)) + + def _compute_batch(): + trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False) + + assert isinstance(trainer.accelerator, CPUAccelerator) + # TODO: test that we selected the correct training_type_plugin based on horovod flags + + metric = Accuracy( + compute_on_step=True, + dist_sync_on_step=True, + dist_sync_fn=trainer.training_type_plugin.all_gather, + threshold=threshold + ) + + for i in range(hvd.rank(), num_batches, hvd.size()): + batch_result = metric(preds[i], target[i]) + if hvd.rank() == 0: + dist_preds = torch.stack([preds[i + r] for r in range(hvd.size())]) + dist_target = torch.stack([target[i + r] for r in range(hvd.size())]) + sk_batch_result = sk_metric(dist_preds, dist_target) + assert np.allclose(batch_result.numpy(), sk_batch_result) + + # check on all batches on all ranks + result = metric.compute() + assert isinstance(result, torch.Tensor) + + total_preds = torch.stack([preds[i] for i in range(num_batches)]) + total_target = torch.stack([target[i] for i in range(num_batches)]) + sk_result = sk_metric(total_preds, total_target) + + assert np.allclose(result.numpy(), sk_result) + + horovod.run(_compute_batch, np=2) + + +@RunIf(skip_windows=True, horovod=True) +def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=0.1) + optimizer2 = optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + model = TestModel() + model.training_epoch_end = None + + num_workers = 8 + init_lr = 0.1 * num_workers + + with patch('horovod.torch.size', return_value=8): + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.5, + limit_train_batches=0.2, + accelerator='horovod' + ) + results = trainer.fit(model) + assert results == 1 + + adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] + adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] + + # Called ones after end of epoch with gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr1 + + # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr2 diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py new file mode 100644 index 00000000000000..7506fb06e35eff --- /dev/null +++ b/tests/models/test_hparams.py @@ -0,0 +1,699 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import os +import pickle +from argparse import Namespace + +import cloudpickle +import pytest +import torch +from fsspec.implementations.local import LocalFileSystem +from omegaconf import Container, OmegaConf +from torch.utils.data import DataLoader + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml +from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable +from tests.helpers import BoringModel, RandomDataset + +if _HYDRA_EXPERIMENTAL_AVAILABLE: + from hydra.experimental import compose, initialize + + +class SaveHparamsModel(BoringModel): + """ Tests that a model can take an object """ + + def __init__(self, hparams): + super().__init__() + self.save_hyperparameters(hparams) + + +def decorate(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +class SaveHparamsDecoratedModel(BoringModel): + """ Tests that a model can take an object """ + + @decorate + @decorate + def __init__(self, hparams, *my_args, **my_kwargs): + super().__init__() + self.save_hyperparameters(hparams) + + +# ------------------------- +# STANDARD TESTS +# ------------------------- +def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False): + """ + Tests for the existence of an arg 'test_arg=14' + """ + hparam_type = type(model.hparams) + # test proper property assignments + assert model.hparams.test_arg == 14 + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14 + + # verify that model loads correctly + model2 = cls.load_from_checkpoint(raw_checkpoint_path) + assert model2.hparams.test_arg == 14 + + assert isinstance(model2.hparams, hparam_type) + + if try_overwrite: + # verify that we can overwrite the property + model3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + assert model3.hparams.test_arg == 78 + + return raw_checkpoint_path + + +@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) +def test_namespace_hparams(tmpdir, cls): + # init model + model = cls(hparams=Namespace(test_arg=14)) + + # run standard test suite + _run_standard_hparams_test(tmpdir, model, cls) + + +@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) +def test_dict_hparams(tmpdir, cls): + # init model + model = cls(hparams={'test_arg': 14}) + + # run standard test suite + _run_standard_hparams_test(tmpdir, model, cls) + + +@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) +def test_omega_conf_hparams(tmpdir, cls): + # init model + conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)])) + model = cls(hparams=conf) + assert isinstance(model.hparams, Container) + + # run standard test suite + raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls) + model2 = cls.load_from_checkpoint(raw_checkpoint_path) + assert isinstance(model2.hparams, Container) + + # config specific tests + assert model2.hparams.test_arg == 14 + assert model2.hparams.mylist[0] == 15.4 + + +def test_explicit_args_hparams(tmpdir): + """ + Tests that a model can take implicit args and assign + """ + + # define model + class LocalModel(BoringModel): + + def __init__(self, test_arg, test_arg2): + super().__init__() + self.save_hyperparameters('test_arg', 'test_arg2') + + model = LocalModel(test_arg=14, test_arg2=90) + + # run standard test suite + raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel) + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120) + + # config specific tests + assert model.hparams.test_arg2 == 120 + + +def test_implicit_args_hparams(tmpdir): + """ + Tests that a model can take regular args and assign + """ + + # define model + class LocalModel(BoringModel): + + def __init__(self, test_arg, test_arg2): + super().__init__() + self.save_hyperparameters() + + model = LocalModel(test_arg=14, test_arg2=90) + + # run standard test suite + raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel) + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120) + + # config specific tests + assert model.hparams.test_arg2 == 120 + + +def test_explicit_missing_args_hparams(tmpdir): + """ + Tests that a model can take regular args and assign + """ + + # define model + class LocalModel(BoringModel): + + def __init__(self, test_arg, test_arg2): + super().__init__() + self.save_hyperparameters('test_arg') + + model = LocalModel(test_arg=14, test_arg2=90) + + # test proper property assignments + assert model.hparams.test_arg == 14 + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14 + + # verify that model loads correctly + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123) + assert model.hparams.test_arg == 14 + assert 'test_arg2' not in model.hparams # test_arg2 is not registered in class init + + return raw_checkpoint_path + + +# ------------------------- +# SPECIFIC TESTS +# ------------------------- + + +def test_class_nesting(): + + class MyModule(LightningModule): + + def forward(self): + ... + + # make sure PL modules are always nn.Module + a = MyModule() + assert isinstance(a, torch.nn.Module) + + def test_outside(): + a = MyModule() + _ = a.hparams + + class A: + + def test(self): + a = MyModule() + _ = a.hparams + + def test2(self): + test_outside() + + test_outside() + A().test2() + A().test() + + +class CustomBoringModel(BoringModel): + + def __init__(self, batch_size=64): + super().__init__() + self.save_hyperparameters() + + +class SubClassBoringModel(CustomBoringModel): + any_other_loss = torch.nn.CrossEntropyLoss() + + def __init__(self, *args, subclass_arg=1200, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + +class SubSubClassBoringModel(SubClassBoringModel): + pass + + +class AggSubClassBoringModel(SubClassBoringModel): + + def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + +class UnconventionalArgsBoringModel(CustomBoringModel): + """ A model that has unconventional names for "self", "*args" and "**kwargs". """ + + def __init__(obj, *more_args, other_arg=300, **more_kwargs): + # intentionally named obj + super().__init__(*more_args, **more_kwargs) + obj.save_hyperparameters() + + +class DictConfSubClassBoringModel(SubClassBoringModel): + + def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + +@pytest.mark.parametrize( + "cls", [ + CustomBoringModel, + SubClassBoringModel, + SubSubClassBoringModel, + AggSubClassBoringModel, + UnconventionalArgsBoringModel, + DictConfSubClassBoringModel, + ] +) +def test_collect_init_arguments(tmpdir, cls): + """ Test that the model automatically saves the arguments passed into the constructor """ + extra_args = {} + if cls is AggSubClassBoringModel: + extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss()) + elif cls is DictConfSubClassBoringModel: + extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything'))) + + model = cls(**extra_args) + assert model.hparams.batch_size == 64 + model = cls(batch_size=179, **extra_args) + assert model.hparams.batch_size == 179 + + if isinstance(model, SubClassBoringModel): + assert model.hparams.subclass_arg == 1200 + + if isinstance(model, AggSubClassBoringModel): + assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss) + + # verify that the checkpoint saved the correct values + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) + trainer.fit(model) + + raw_checkpoint_path = _raw_checkpoint_path(trainer) + + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179 + + # verify that model loads correctly + model = cls.load_from_checkpoint(raw_checkpoint_path) + assert model.hparams.batch_size == 179 + + if isinstance(model, AggSubClassBoringModel): + assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss) + + if isinstance(model, DictConfSubClassBoringModel): + assert isinstance(model.hparams.dict_conf, Container) + assert model.hparams.dict_conf['my_param'] == 'anything' + + # verify that we can overwrite whatever we want + model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) + assert model.hparams.batch_size == 99 + + +def _raw_checkpoint_path(trainer) -> str: + raw_checkpoint_paths = os.listdir(trainer.checkpoint_callback.dirpath) + raw_checkpoint_paths = [x for x in raw_checkpoint_paths if '.ckpt' in x] + assert raw_checkpoint_paths + raw_checkpoint_path = raw_checkpoint_paths[0] + raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path) + return raw_checkpoint_path + + +class LocalVariableModelSuperLast(BoringModel): + """ This model has the super().__init__() call at the end. """ + + def __init__(self, arg1, arg2, *args, **kwargs): + self.argument1 = arg1 # arg2 intentionally not set + arg1 = 'overwritten' # noqa: F841 + local_var = 1234 # noqa: F841 + super().__init__(*args, **kwargs) # this is intentionally here at the end + + +class LocalVariableModelSuperFirst(BoringModel): + """ This model has the _auto_collect_arguments() call at the end. """ + + def __init__(self, arg1, arg2, *args, **kwargs): + super().__init__(*args, **kwargs) + self.argument1 = arg1 # arg2 intentionally not set + arg1 = 'overwritten' # noqa: F841 + local_var = 1234 # noqa: F841 + self.save_hyperparameters() # this is intentionally here at the end + + +@pytest.mark.parametrize( + "cls", + [ + LocalVariableModelSuperFirst, + # LocalVariableModelSuperLast, + ] +) +def test_collect_init_arguments_with_local_vars(cls): + """ Tests that only the arguments are collected and not local variables. """ + model = cls(arg1=1, arg2=2) + assert 'local_var' not in model.hparams + assert model.hparams['arg1'] == 'overwritten' + assert model.hparams['arg2'] == 2 + + +# @pytest.mark.parametrize("cls,config", [ +# (SaveHparamsModel, Namespace(my_arg=42)), +# (SaveHparamsModel, dict(my_arg=42)), +# (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))), +# (AssignHparamsModel, Namespace(my_arg=42)), +# (AssignHparamsModel, dict(my_arg=42)), +# (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))), +# ]) +# def test_single_config_models(tmpdir, cls, config): +# """ Test that the model automatically saves the arguments passed into the constructor """ +# model = cls(config) +# +# # no matter how you do it, it should be assigned +# assert model.hparams.my_arg == 42 +# +# # verify that the checkpoint saved the correct values +# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) +# trainer.fit(model) +# +# # verify that model loads correctly +# raw_checkpoint_path = _raw_checkpoint_path(trainer) +# model = cls.load_from_checkpoint(raw_checkpoint_path) +# assert model.hparams.my_arg == 42 + + +class AnotherArgModel(BoringModel): + + def __init__(self, arg1): + super().__init__() + self.save_hyperparameters(arg1) + + +class OtherArgsModel(BoringModel): + + def __init__(self, arg1, arg2): + + super().__init__() + self.save_hyperparameters(arg1, arg2) + + +@pytest.mark.parametrize( + "cls,config", [ + (AnotherArgModel, dict(arg1=42)), + (OtherArgsModel, dict(arg1=3.14, arg2='abc')), + ] +) +def test_single_config_models_fail(tmpdir, cls, config): + """ Test fail on passing unsupported config type. """ + with pytest.raises(ValueError): + _ = cls(**config) + + +@pytest.mark.parametrize("past_key", ['module_arguments']) +def test_load_past_checkpoint(tmpdir, past_key): + model = CustomBoringModel() + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + raw_checkpoint['hparams_type'] = 'Namespace' + raw_checkpoint[past_key]['batch_size'] = -17 + del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + # save back the checkpoint + torch.save(raw_checkpoint, raw_checkpoint_path) + + # verify that model loads correctly + model2 = CustomBoringModel.load_from_checkpoint(raw_checkpoint_path) + assert model2.hparams.batch_size == -17 + + +def test_hparams_pickle(tmpdir): + ad = AttributeDict({'key1': 1, 'key2': 'abc'}) + pkl = pickle.dumps(ad) + assert ad == pickle.loads(pkl) + pkl = cloudpickle.dumps(ad) + assert ad == pickle.loads(pkl) + + +class UnpickleableArgsBoringModel(BoringModel): + """ A model that has an attribute that cannot be pickled. """ + + def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs): + super().__init__(**kwargs) + assert not is_picklable(pickle_me) + self.save_hyperparameters() + + +def test_hparams_pickle_warning(tmpdir): + model = UnpickleableArgsBoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_steps=1) + with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"): + trainer.fit(model) + assert 'pickle_me' not in model.hparams + + +def test_hparams_save_yaml(tmpdir): + hparams = dict( + batch_size=32, learning_rate=0.001, data_root='./any/path/here', nasted=dict(any_num=123, anystr='abcd') + ) + path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml') + + save_hparams_to_yaml(path_yaml, hparams) + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams + + save_hparams_to_yaml(path_yaml, Namespace(**hparams)) + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams + + save_hparams_to_yaml(path_yaml, AttributeDict(hparams)) + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams + + save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) + assert load_hparams_from_yaml(path_yaml) == hparams + + +class NoArgsSubClassBoringModel(CustomBoringModel): + + def __init__(self): + super().__init__() + + +@pytest.mark.parametrize("cls", [ + BoringModel, + NoArgsSubClassBoringModel, +]) +def test_model_nohparams_train_test(tmpdir, cls): + """Test models that do not tae any argument in init.""" + + model = cls() + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + ) + + train_loader = DataLoader(RandomDataset(32, 64), batch_size=32) + trainer.fit(model, train_loader) + + test_loader = DataLoader(RandomDataset(32, 64), batch_size=32) + trainer.test(test_dataloaders=test_loader) + + +def test_model_ignores_non_exist_kwargument(tmpdir): + """Test that the model takes only valid class arguments.""" + + class LocalModel(BoringModel): + + def __init__(self, batch_size=15): + super().__init__() + self.save_hyperparameters() + + model = LocalModel() + assert model.hparams.batch_size == 15 + + # verify that the checkpoint saved the correct values + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.fit(model) + + # verify that we can overwrite whatever we want + raw_checkpoint_path = _raw_checkpoint_path(trainer) + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99) + assert 'non_exist_kwarg' not in model.hparams + + +class SuperClassPositionalArgs(BoringModel): + + def __init__(self, hparams): + super().__init__() + self._hparams = hparams # pretend BoringModel did not call self.save_hyperparameters() + + +class SubClassVarArgs(SuperClassPositionalArgs): + """ Loading this model should accept hparams and init in the super class """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def test_args(tmpdir): + """ Test for inheritance: super class takes positional arg, subclass takes varargs. """ + hparams = dict(test=1) + model = SubClassVarArgs(hparams) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.fit(model) + + raw_checkpoint_path = _raw_checkpoint_path(trainer) + with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'test'"): + SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path) + + +class RuntimeParamChangeModelSaving(BoringModel): + + def __init__(self, **kwargs): + super().__init__() + self.save_hyperparameters() + + +@pytest.mark.parametrize("cls", [RuntimeParamChangeModelSaving]) +def test_init_arg_with_runtime_change(tmpdir, cls): + """Test that we save/export only the initial hparams, no other runtime change allowed""" + model = cls(running_arg=123) + assert model.hparams.running_arg == 123 + model.hparams.running_arg = -1 + assert model.hparams.running_arg == -1 + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE) + hparams = load_hparams_from_yaml(path_yaml) + assert hparams.get('running_arg') == 123 + + +class UnsafeParamModel(BoringModel): + + def __init__(self, my_path, any_param=123): + super().__init__() + self.save_hyperparameters() + + +def test_model_with_fsspec_as_parameter(tmpdir): + model = UnsafeParamModel(LocalFileSystem(tmpdir)) + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=1, + ) + trainer.fit(model) + trainer.test() + + +@pytest.mark.skipif(not _HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available") +def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): + """ + This test relies on configuration saved under tests/models/conf/config.yaml + """ + + class TestHydraModel(BoringModel): + + def __init__(self, args_0, args_1, args_2, kwarg_1=None): + self.save_hyperparameters() + assert self.hparams.args_0.log == "Something" + assert self.hparams.args_1['cfg'].log == "Something" + assert self.hparams.args_2[0].log == "Something" + assert self.hparams.kwarg_1['cfg'][0].log == "Something" + super().__init__() + + with initialize(config_path="conf"): + args_0 = compose(config_name="config") + args_1 = {"cfg": compose(config_name="config")} + args_2 = [compose(config_name="config")] + kwarg_1 = {"cfg": [compose(config_name="config")]} + model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1) + epochs = 2 + checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + limit_train_batches=10, + limit_val_batches=10, + max_epochs=epochs, + logger=False, + ) + trainer.fit(model) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + + +@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3"))) +def test_ignore_args_list_hparams(tmpdir, ignore): + """ + Tests that args can be ignored in save_hyperparameters + """ + + class LocalModel(BoringModel): + + def __init__(self, arg1, arg2, arg3): + super().__init__() + self.save_hyperparameters(ignore=ignore) + + model = LocalModel(arg1=14, arg2=90, arg3=50) + + # test proper property assignments + assert model.hparams.arg1 == 14 + for arg in ignore: + assert arg not in model.hparams + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14 + + # verify that model loads correctly + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100) + assert model.hparams.arg1 == 14 + for arg in ignore: + assert arg not in model.hparams diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py new file mode 100644 index 00000000000000..5ba224ef775195 --- /dev/null +++ b/tests/models/test_onnx.py @@ -0,0 +1,154 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import onnxruntime +import pytest +import torch + +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +def test_model_saves_with_input_sample(tmpdir): + """Test that ONNX model saves with input sample and size is greater than 3 MB""" + model = BoringModel() + trainer = Trainer(fast_dev_run=True) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onnx") + input_sample = torch.randn((1, 32)) + model.to_onnx(file_path, input_sample) + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + + +@RunIf(min_gpus=1) +def test_model_saves_on_gpu(tmpdir): + """Test that model saves on gpu""" + model = BoringModel() + trainer = Trainer(gpus=1, fast_dev_run=True) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onnx") + input_sample = torch.randn((1, 32)) + model.to_onnx(file_path, input_sample) + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + + +def test_model_saves_with_example_output(tmpdir): + """Test that ONNX model saves when provided with example output""" + model = BoringModel() + trainer = Trainer(fast_dev_run=True) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onnx") + input_sample = torch.randn((1, 32)) + model.eval() + example_outputs = model.forward(input_sample) + model.to_onnx(file_path, input_sample, example_outputs=example_outputs) + assert os.path.exists(file_path) is True + + +def test_model_saves_with_example_input_array(tmpdir): + """Test that ONNX model saves with_example_input_array and size is greater than 3 MB""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + file_path = os.path.join(tmpdir, "model.onnx") + model.to_onnx(file_path) + assert os.path.exists(file_path) is True + assert os.path.getsize(file_path) > 4e2 + + +@RunIf(min_gpus=2) +def test_model_saves_on_multi_gpu(tmpdir): + """Test that ONNX model saves on a distributed backend""" + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + accelerator='ddp_spawn', + progress_bar_refresh_rate=0, + ) + + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + tpipes.run_model_test(trainer_options, model, min_acc=0.08) + + file_path = os.path.join(tmpdir, "model.onnx") + model.to_onnx(file_path) + assert os.path.exists(file_path) is True + + +def test_verbose_param(tmpdir, capsys): + """Test that output is present when verbose parameter is set""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + file_path = os.path.join(tmpdir, "model.onnx") + model.to_onnx(file_path, verbose=True) + captured = capsys.readouterr() + assert "graph(%" in captured.out + + +def test_error_if_no_input(tmpdir): + """Test that an error is thrown when there is no input tensor""" + model = BoringModel() + model.example_input_array = None + file_path = os.path.join(tmpdir, "model.onnx") + with pytest.raises( + ValueError, + match=r'Could not export to ONNX since neither `input_sample` nor' + r' `model.example_input_array` attribute is set.' + ): + model.to_onnx(file_path) + + +def test_if_inference_output_is_valid(tmpdir): + """Test that the output inferred from ONNX model is same as from PyTorch""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + trainer = Trainer(fast_dev_run=True) + trainer.fit(model) + + model.eval() + with torch.no_grad(): + torch_out = model(model.example_input_array) + + file_path = os.path.join(tmpdir, "model.onnx") + model.to_onnx(file_path, model.example_input_array, export_params=True) + + ort_session = onnxruntime.InferenceSession(file_path) + + def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + + # compute ONNX Runtime output prediction + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)} + ort_outs = ort_session.run(None, ort_inputs) + + # compare ONNX Runtime and PyTorch results + assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 0a927a3a94e0ad..7a43b2d0832f95 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -1,25 +1,271 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import glob import logging as log import os +import pickle +from copy import deepcopy +from typing import Generic, TypeVar +import cloudpickle import pytest import torch +import torch.nn.functional as F -import tests.base.utils as tutils -from pytorch_lightning import Trainer +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils +from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate +from pytorch_lightning.trainer.states import RunningStage, TrainerState +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel -@pytest.mark.spawn -@pytest.mark.parametrize("backend", ['dp', 'ddp']) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_running_test_pretrained_model_distrib(tmpdir, backend): +class ModelTrainerPropertyParity(Callback): + + def _check_properties(self, trainer, pl_module): + assert trainer.global_step == pl_module.global_step + assert trainer.current_epoch == pl_module.current_epoch + + def on_train_start(self, trainer, pl_module): + self._check_properties(trainer, pl_module) + + def on_train_batch_start(self, trainer, pl_module, *args, **kwargs): + self._check_properties(trainer, pl_module) + + def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): + self._check_properties(trainer, pl_module) + + def on_epoch_end(self, trainer, pl_module): + self._check_properties(trainer, pl_module) + + def on_train_end(self, trainer, pl_module): + self._check_properties(trainer, pl_module) + + +class ValTestLossBoringModel(BoringModel): + + def __init__(self, batch_size=4): + super().__init__() + self.save_hyperparameters() + + def validation_step(self, batch, batch_idx): + out = super().validation_step(batch, batch_idx) + self.log('val_loss', out['x']) + return out + + def test_step(self, batch, batch_idx): + out = super().test_step(batch, batch_idx) + self.log('test_loss', out['y']) + return out + + +T = TypeVar('T') + + +class GenericParentValTestLossBoringModel(Generic[T], ValTestLossBoringModel): + + def __init__(self, batch_size: int = 4): + super().__init__(batch_size=batch_size) + + +class GenericValTestLossBoringModel(GenericParentValTestLossBoringModel[int]): + pass + + +class CustomClassificationModelDP(ClassificationModel): + + def _step(self, batch, batch_idx): + x, y = batch + logits = self(x) + return {'logits': logits, 'y': y} + + def training_step(self, batch, batch_idx): + out = self._step(batch, batch_idx) + loss = F.cross_entropy(out['logits'], out['y']) + return loss + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def validation_step_end(self, outputs): + self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y'])) + + +def test_model_properties_resume_from_checkpoint(tmpdir): + """ + Test that properties like `current_epoch` and `global_step` + in model and trainer are always the same. + """ + model = BoringModel() + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + trainer_args = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + logger=False, + callbacks=[checkpoint_callback, ModelTrainerPropertyParity()], # this performs the assertions + ) + trainer = Trainer(**trainer_args) + trainer.fit(model) + + trainer_args.update(max_epochs=2) + trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) + trainer.fit(model) + + +def test_try_resume_from_non_existing_checkpoint(tmpdir): + """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error.""" + dm = ClassifDataModule() + model = ClassificationModel() + checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=False, + callbacks=[checkpoint_cb], + limit_train_batches=2, + limit_val_batches=2, + ) + # Generate checkpoint `last.ckpt` with BoringModel + trainer.fit(model, datamodule=dm) + # `True` if resume/restore successfully else `False` + assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu) + assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu) + + +class CaptureCallbacksBeforeTraining(Callback): + callbacks = [] + + def on_train_start(self, trainer, pl_module): + self.callbacks = deepcopy(trainer.callbacks) + + +def test_callbacks_state_resume_from_checkpoint(tmpdir): + """ Test that resuming from a checkpoint restores callbacks that persist state. """ + dm = ClassifDataModule() + model = ClassificationModel() + callback_capture = CaptureCallbacksBeforeTraining() + + def get_trainer_args(): + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + trainer_args = dict( + default_root_dir=tmpdir, max_steps=1, logger=False, callbacks=[checkpoint, callback_capture] + ) + assert checkpoint.best_model_path == "" + assert checkpoint.best_model_score is None + return trainer_args + + # initial training + trainer = Trainer(**get_trainer_args()) + trainer.fit(model, datamodule=dm) + callbacks_before_resume = deepcopy(trainer.callbacks) + + # resumed training + trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt")) + trainer.fit(model, datamodule=dm) + + assert len(callbacks_before_resume) == len(callback_capture.callbacks) + + for before, after in zip(callbacks_before_resume, callback_capture.callbacks): + if isinstance(before, ModelCheckpoint): + assert before.best_model_path == after.best_model_path + assert before.best_model_score == after.best_model_score + + +def test_callbacks_references_resume_from_checkpoint(tmpdir): + """ Test that resuming from a checkpoint sets references as expected. """ + dm = ClassifDataModule() + model = ClassificationModel() + args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False} + + # initial training + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + trainer = Trainer(**args, callbacks=[checkpoint]) + assert checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback + trainer.fit(model, datamodule=dm) + + # resumed training + new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + # pass in a new checkpoint object, which should take + # precedence over the one in the last.ckpt file + trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt")) + assert checkpoint is not new_checkpoint + assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback + trainer.fit(model, datamodule=dm) + + +@RunIf(min_gpus=2) +def test_running_test_pretrained_model_distrib_dp(tmpdir): """Verify `test()` on pretrained model.""" + tutils.set_random_master_port() - model = EvalModelTemplate(tutils.get_default_hparams()) + dm = ClassifDataModule() + model = CustomClassificationModelDP(lr=0.1) + + # exp file to get meta + logger = tutils.get_default_logger(tmpdir) + + # exp file to get weights + checkpoint = tutils.init_checkpoint_callback(logger) + + trainer_options = dict( + progress_bar_refresh_rate=0, + max_epochs=2, + limit_train_batches=5, + limit_val_batches=5, + callbacks=[checkpoint], + logger=logger, + gpus=[0, 1], + accelerator='dp', + default_root_dir=tmpdir, + ) + + # fit model + trainer = Trainer(**trainer_options) + trainer.fit(model, datamodule=dm) + + # correct result and ok accuracy + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # run test set + new_trainer = Trainer(**trainer_options) + new_trainer.test(pretrained_model) + pretrained_model.cpu() + + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: + tpipes.run_prediction_eval_model_template(pretrained_model, dataloader) + + +@RunIf(min_gpus=2) +def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir): + """Verify `test()` on pretrained model.""" + tutils.set_random_master_port() + dm = ClassifDataModule() + model = ClassificationModel() # exp file to get meta logger = tutils.get_default_logger(tmpdir) @@ -30,44 +276,43 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend): trainer_options = dict( progress_bar_refresh_rate=0, max_epochs=2, - train_percent_check=0.4, - val_percent_check=0.2, - checkpoint_callback=checkpoint, + limit_train_batches=2, + limit_val_batches=2, + callbacks=[checkpoint], logger=logger, gpus=[0, 1], - distributed_backend=backend, + accelerator='ddp_spawn', + default_root_dir=tmpdir, ) # fit model trainer = Trainer(**trainer_options) - result = trainer.fit(model) + trainer.fit(model, datamodule=dm) log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir))) # correct result and ok accuracy - assert result == 1, 'training failed to complete' - pretrained_model = tutils.load_model(logger, - trainer.checkpoint_callback.dirpath, - module_class=EvalModelTemplate) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # run test set new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) + pretrained_model.cpu() - # test we have good test accuracy - tutils.assert_ok_model_acc(new_trainer) - - dataloaders = model.test_dataloader() + dataloaders = dm.test_dataloader() if not isinstance(dataloaders, list): dataloaders = [dataloaders] for dataloader in dataloaders: - tutils.run_prediction(dataloader, pretrained_model) + tpipes.run_prediction_eval_model_template(pretrained_model, dataloader, min_acc=0.1) def test_running_test_pretrained_model_cpu(tmpdir): """Verify test() on pretrained model.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + tutils.reset_seed() + dm = ClassifDataModule() + model = ClassificationModel() # logger file to get meta logger = tutils.get_default_logger(tmpdir) @@ -77,82 +322,84 @@ def test_running_test_pretrained_model_cpu(tmpdir): trainer_options = dict( progress_bar_refresh_rate=0, - max_epochs=4, - train_percent_check=0.4, - val_percent_check=0.2, - checkpoint_callback=checkpoint, - logger=logger + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + callbacks=[checkpoint], + logger=logger, + default_root_dir=tmpdir, ) # fit model trainer = Trainer(**trainer_options) - result = trainer.fit(model) + trainer.fit(model, datamodule=dm) # correct result and ok accuracy - assert result == 1, 'training failed to complete' - pretrained_model = tutils.load_model( - logger, trainer.checkpoint_callback.dirpath, module_class=EvalModelTemplate - ) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) new_trainer = Trainer(**trainer_options) - new_trainer.test(pretrained_model) + new_trainer.test(pretrained_model, datamodule=dm) # test we have good test accuracy - tutils.assert_ok_model_acc(new_trainer) + tutils.assert_ok_model_acc(new_trainer, key='test_acc', thr=0.45) -def test_load_model_from_checkpoint(tmpdir): +@pytest.mark.parametrize('model_template', [ValTestLossBoringModel, GenericValTestLossBoringModel]) +def test_load_model_from_checkpoint(tmpdir, model_template): """Verify test() on pretrained model.""" - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) + tutils.reset_seed() + model = model_template() trainer_options = dict( progress_bar_refresh_rate=0, max_epochs=2, - train_percent_check=0.4, - val_percent_check=0.2, - checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='val_loss', save_top_k=-1)], default_root_dir=tmpdir, ) # fit model trainer = Trainer(**trainer_options) - result = trainer.fit(model) - trainer.test() + trainer.fit(model) + trainer.test(ckpt_path=None) # correct result and ok accuracy - assert result == 1, 'training failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # load last checkpoint last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1] - pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint) + + # Since `BoringModel` has `_save_hparams = True` by default, check that ckpt has hparams + ckpt = torch.load(last_checkpoint) + assert model_template.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), 'hyper_parameters missing from checkpoints' + + # Ensure that model can be correctly restored from checkpoint + pretrained_model = model_template.load_from_checkpoint(last_checkpoint) # test that hparams loaded correctly - for k, v in vars(hparams).items(): + for k, v in model.hparams.items(): assert getattr(pretrained_model.hparams, k) == v # assert weights are the same for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()): assert torch.all(torch.eq(old_p, new_p)), 'loaded weights are not the same as the saved weights' + # Check `test` on pretrained model: new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - # test we have good test accuracy - tutils.assert_ok_model_acc(new_trainer) - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@RunIf(min_gpus=2) def test_dp_resume(tmpdir): """Make sure DP continues training correctly.""" - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) + model = CustomClassificationModelDP(lr=0.1) + dm = ClassifDataModule() - trainer_options = dict( - max_epochs=1, - gpus=2, - distributed_backend='dp', - ) + trainer_options = dict(max_epochs=1, gpus=2, accelerator='dp', default_root_dir=tmpdir) # get logger logger = tutils.get_default_logger(tmpdir) @@ -163,52 +410,58 @@ def test_dp_resume(tmpdir): # add these to the trainer options trainer_options['logger'] = logger - trainer_options['checkpoint_callback'] = checkpoint + trainer_options['callbacks'] = [checkpoint] # fit model trainer = Trainer(**trainer_options) trainer.is_slurm_managing_tasks = True - result = trainer.fit(model) + trainer.fit(model, datamodule=dm) # track epoch before saving. Increment since we finished the current epoch, don't want to rerun real_global_epoch = trainer.current_epoch + 1 # correct result and ok accuracy - assert result == 1, 'amp + dp model failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # --------------------------- # HPC LOAD/SAVE # --------------------------- # save - trainer.hpc_save(tmpdir, logger) + trainer.checkpoint_connector.hpc_save(tmpdir, logger) # init new trainer new_logger = tutils.get_default_logger(tmpdir, version=logger.version) trainer_options['logger'] = new_logger - trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir) - trainer_options['train_percent_check'] = 0.5 - trainer_options['val_percent_check'] = 0.2 + trainer_options['callbacks'] = [ModelCheckpoint(dirpath=tmpdir)] + trainer_options['limit_train_batches'] = 0.5 + trainer_options['limit_val_batches'] = 0.2 trainer_options['max_epochs'] = 1 new_trainer = Trainer(**trainer_options) - # set the epoch start hook so we can predict before the model does the full training - def assert_good_acc(): - assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0 + class CustomModel(CustomClassificationModelDP): - # if model and state loaded correctly, predictions will be good even though we - # haven't trained with the new loaded model - dp_model = new_trainer.model - dp_model.eval() + def __init__(self): + super().__init__() + self.on_train_start_called = False + + # set the epoch start hook so we can predict before the model does the full training + def on_train_start(self): + assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0 + + # if model and state loaded correctly, predictions will be good even though we + # haven't trained with the new loaded model + new_trainer._running_stage = RunningStage.VALIDATING - dataloader = trainer.train_dataloader - tutils.run_prediction(dataloader, dp_model, dp=True) + dataloader = self.train_dataloader() + tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) + self.on_train_start_called = True # new model - model = EvalModelTemplate(hparams) - model.on_train_start = assert_good_acc + model = CustomModel() # fit new model which should load hpc weights - new_trainer.fit(model) + new_trainer.fit(model, datamodule=dm) + assert model.on_train_start_called # test freeze on gpu model.freeze() @@ -217,97 +470,160 @@ def assert_good_acc(): def test_model_saving_loading(tmpdir): """Tests use case where trainer saves the model, and user loads it from tags independently.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = BoringModel() # logger file to get meta logger = tutils.get_default_logger(tmpdir) - trainer_options = dict( + # fit model + trainer = Trainer( max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + default_root_dir=tmpdir, ) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) + trainer.fit(model) # traning complete - assert result == 1, 'amp + ddp model failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # make a prediction dataloaders = model.test_dataloader() if not isinstance(dataloaders, list): dataloaders = [dataloaders] - for dataloader in dataloaders: - for batch in dataloader: - break - - x, y = batch - x = x.view(x.size(0), -1) + batch = next(iter(dataloaders[0])) # generate preds before saving model model.eval() - pred_before_saving = model(x) + pred_before_saving = model(batch) # save model new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') trainer.save_checkpoint(new_weights_path) # load new model - tags_path = tutils.get_data_path(logger, path_dir=tmpdir) - tags_path = os.path.join(tags_path, 'meta_tags.csv') - model_2 = EvalModelTemplate.load_from_checkpoint( + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, 'hparams.yaml') + model_2 = BoringModel.load_from_checkpoint( checkpoint_path=new_weights_path, - tags_csv=tags_path + hparams_file=hparams_path, ) model_2.eval() # make prediction # assert that both predictions are the same - new_pred = model_2(x) + new_pred = model_2(batch) assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 -def test_load_model_with_missing_hparams(tmpdir): - trainer_options = dict( - progress_bar_refresh_rate=0, - max_epochs=1, - checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), - logger=False, +@pytest.mark.parametrize('url_ckpt', [True, False]) +def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt): + """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv('TORCH_HOME', tmpdir) + + model = BoringModel() + # Extra layer + model.c_d3 = torch.nn.Linear(32, 32) + + # logger file to get meta + logger = tutils.get_default_logger(tmpdir) + + # fit model + trainer = Trainer( default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + logger=logger, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], ) + trainer.fit(model) + + # traning complete + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + # save model + new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') + trainer.save_checkpoint(new_weights_path) + + # load new model + hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml') + hparams_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' + ckpt_path = hparams_url if url_ckpt else new_weights_path + + BoringModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + hparams_file=hparams_path, + strict=False, + ) + + with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'): + BoringModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + hparams_file=hparams_path, + strict=True, + ) + + +@pytest.mark.parametrize('url_ckpt', [True, False]) +def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt): + """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv('TORCH_HOME', tmpdir) + + model = BoringModel() + + # logger file to get meta + logger = tutils.get_default_logger(tmpdir) # fit model - trainer = Trainer(**trainer_options) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + logger=logger, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + ) + trainer.fit(model) + + # traning complete + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + # save model + new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') + trainer.save_checkpoint(new_weights_path) + + # load new model + hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml') + hparams_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' + ckpt_path = hparams_url if url_ckpt else new_weights_path + + class CurrentModel(BoringModel): - class CurrentModelWithoutHparams(EvalModelTemplate): def __init__(self): - hparams = tutils.get_default_hparams() - super().__init__(hparams) + super().__init__() + self.c_d3 = torch.nn.Linear(7, 7) - class CurrentModelUnusedHparams(EvalModelTemplate): - def __init__(self, hparams): - hparams = tutils.get_default_hparams() - super().__init__(hparams) + CurrentModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + hparams_file=hparams_path, + strict=False, + ) - model = CurrentModelWithoutHparams() - trainer.fit(model) - last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1] + with pytest.raises(RuntimeError, match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'): + CurrentModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + hparams_file=hparams_path, + strict=True, + ) - # try to load a checkpoint that has hparams but model is missing hparams arg - with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"): - CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint) - # create a checkpoint without hyperparameters - # if the model does not take a hparams argument, it should not throw an error - ckpt = torch.load(last_checkpoint) - del(ckpt['hparams']) - torch.save(ckpt, last_checkpoint) - CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint) - - # load checkpoint without hparams again - # warn if user's model has hparams argument - with pytest.warns(UserWarning, match=r".*Will pass in an empty Namespace instead."): - CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint) +def test_model_pickle(tmpdir): + model = BoringModel() + pickle.dumps(model) + cloudpickle.dumps(model) diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py new file mode 100644 index 00000000000000..5750bb66a75b63 --- /dev/null +++ b/tests/models/test_sync_batchnorm.py @@ -0,0 +1,130 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_lightning import LightningModule, seed_everything, Trainer +from pytorch_lightning.plugins import DDPSpawnPlugin +from pytorch_lightning.plugins.environments import LightningEnvironment +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import FLOAT16_EPSILON +from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf +from tests.helpers.utils import set_random_master_port + + +class SyncBNModule(LightningModule): + + def __init__(self, gpu_count=1, **kwargs): + super().__init__() + + self.gpu_count = gpu_count + self.bn_targets = None + if 'bn_targets' in kwargs: + self.bn_targets = kwargs['bn_targets'] + + self.linear = nn.Linear(28 * 28, 10) + self.bn_layer = nn.BatchNorm1d(28 * 28) + + def forward(self, x, batch_idx): + with torch.no_grad(): + out_bn = self.bn_layer(x.view(x.size(0), -1)) + + if self.bn_targets: + bn_target = self.bn_targets[batch_idx] + + # executes on both GPUs + bn_target = bn_target[self.trainer.local_rank::self.gpu_count] + bn_target = bn_target.to(out_bn.device) + assert torch.sum(torch.abs(bn_target - out_bn)) < FLOAT16_EPSILON + + out = self.linear(out_bn) + + return out, out_bn + + def training_step(self, batch, batch_idx): + x, y = batch + + y_hat, _ = self(x, batch_idx) + loss = F.cross_entropy(y_hat, y) + + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.linear.parameters(), lr=0.02) + + +# TODO: Fatal Python error: Bus error +@pytest.mark.skip(reason="Fatal Python error: Bus error") +@RunIf(min_gpus=2, special=True) +def test_sync_batchnorm_ddp(tmpdir): + seed_everything(234) + set_random_master_port() + + # define datamodule and dataloader + dm = MNISTDataModule() + dm.prepare_data() + dm.setup(stage=None) + + train_dataloader = dm.train_dataloader() + model = SyncBNModule() + + bn_outputs = [] + + # shuffle is false by default + for batch_idx, batch in enumerate(train_dataloader): + x, _ = batch + + _, out_bn = model.forward(x, batch_idx) + bn_outputs.append(out_bn) + + # get 3 steps + if batch_idx == 2: + break + + bn_outputs = [x.cuda() for x in bn_outputs] + + # reset datamodule + # batch-size = 16 because 2 GPUs in DDP + dm = MNISTDataModule(batch_size=16, dist_sampler=True) + dm.prepare_data() + dm.setup(stage=None) + + model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs) + ddp = DDPSpawnPlugin( + parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)], + num_nodes=1, + sync_batchnorm=True, + cluster_environment=LightningEnvironment(), + find_unused_parameters=True + ) + + trainer = Trainer( + default_root_dir=tmpdir, + gpus=2, + num_nodes=1, + accelerator='ddp_spawn', + max_epochs=1, + max_steps=3, + sync_batchnorm=True, + num_sanity_val_steps=0, + replace_sampler_ddp=False, + plugins=[ddp] + ) + + trainer.fit(model, dm) + assert trainer.state == TrainerState.FINISHED, "Sync batchnorm failing with DDP" diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py new file mode 100644 index 00000000000000..b03ed0806d8001 --- /dev/null +++ b/tests/models/test_torchscript.py @@ -0,0 +1,157 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from tests.helpers import BoringModel +from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN +from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("modelclass", [ + BoringModel, + ParityModuleRNN, + BasicGAN, +]) +def test_torchscript_input_output(modelclass): + """ Test that scripted LightningModule forward works. """ + model = modelclass() + + if isinstance(model, BoringModel): + model.example_input_array = torch.randn(5, 32) + + script = model.to_torchscript() + assert isinstance(script, torch.jit.ScriptModule) + + model.eval() + with torch.no_grad(): + model_output = model(model.example_input_array) + + script_output = script(model.example_input_array) + assert torch.allclose(script_output, model_output) + + +@pytest.mark.parametrize("modelclass", [ + BoringModel, + ParityModuleRNN, + BasicGAN, +]) +def test_torchscript_example_input_output_trace(modelclass): + """ Test that traced LightningModule forward works with example_input_array """ + model = modelclass() + + if isinstance(model, BoringModel): + model.example_input_array = torch.randn(5, 32) + + script = model.to_torchscript(method='trace') + assert isinstance(script, torch.jit.ScriptModule) + + model.eval() + with torch.no_grad(): + model_output = model(model.example_input_array) + + script_output = script(model.example_input_array) + assert torch.allclose(script_output, model_output) + + +def test_torchscript_input_output_trace(): + """ Test that traced LightningModule forward works with example_inputs """ + model = BoringModel() + example_inputs = torch.randn(1, 32) + script = model.to_torchscript(example_inputs=example_inputs, method='trace') + assert isinstance(script, torch.jit.ScriptModule) + + model.eval() + with torch.no_grad(): + model_output = model(example_inputs) + + script_output = script(example_inputs) + assert torch.allclose(script_output, model_output) + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda", 0)]) +def test_torchscript_device(device): + """ Test that scripted module is on the correct device. """ + model = BoringModel().to(device) + model.example_input_array = torch.randn(5, 32) + + script = model.to_torchscript() + assert next(script.parameters()).device == device + script_output = script(model.example_input_array.to(device)) + assert script_output.device == device + + +def test_torchscript_retain_training_state(): + """ Test that torchscript export does not alter the training mode of original model. """ + model = BoringModel() + model.train(True) + script = model.to_torchscript() + assert model.training + assert not script.training + model.train(False) + _ = model.to_torchscript() + assert not model.training + assert not script.training + + +@pytest.mark.parametrize("modelclass", [ + BoringModel, + ParityModuleRNN, + BasicGAN, +]) +def test_torchscript_properties(tmpdir, modelclass): + """ Test that scripted LightningModule has unnecessary methods removed. """ + model = modelclass() + model.datamodule = MNISTDataModule(tmpdir) + script = model.to_torchscript() + assert not hasattr(script, "datamodule") + assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") + assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") + assert not callable(getattr(script, "training_step", None)) + + +@pytest.mark.parametrize("modelclass", [ + BoringModel, + ParityModuleRNN, + BasicGAN, +]) +@RunIf(min_torch="1.5.0") +def test_torchscript_save_load(tmpdir, modelclass): + """ Test that scripted LightningModule is correctly saved and can be loaded. """ + model = modelclass() + output_file = str(tmpdir / "model.pt") + script = model.to_torchscript(file_path=output_file) + loaded_script = torch.jit.load(output_file) + assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) + + +def test_torchcript_invalid_method(tmpdir): + """Test that an error is thrown with invalid torchscript method""" + model = BoringModel() + model.train(True) + + with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): + model.to_torchscript(method='temp') + + +def test_torchscript_with_no_input(tmpdir): + """Test that an error is thrown when there is no input tensor""" + model = BoringModel() + model.example_input_array = None + + with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'): + model.to_torchscript(method='trace') diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py new file mode 100644 index 00000000000000..b2ed0db87d8d54 --- /dev/null +++ b/tests/models/test_tpu.py @@ -0,0 +1,398 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from argparse import ArgumentParser +from unittest import mock + +import pytest +from torch.utils.data import DataLoader + +import tests.helpers.pipelines as tpipes +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators import TPUAccelerator +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.plugins import TPUSpawnPlugin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.distributed import ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf +from tests.helpers.utils import pl_multi_process_test + +if _TPU_AVAILABLE: + import torch_xla + import torch_xla.distributed.xla_multiprocessing as xmp + SERIAL_EXEC = xmp.MpSerialExecutor() + +_LARGER_DATASET = RandomDataset(32, 2000) + + +# 8 cores needs a big dataset +def _serial_train_loader(): + return DataLoader(_LARGER_DATASET, batch_size=32) + + +class SerialLoaderBoringModel(BoringModel): + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 2000), batch_size=32) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 2000), batch_size=32) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_tpu_cores_1(tmpdir): + """Make sure model trains on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=2, + tpu_cores=1, + limit_train_batches=4, + limit_val_batches=4, + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + +@pytest.mark.parametrize('tpu_core', [1, 5]) +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_tpu_index(tmpdir, tpu_core): + """Make sure model trains on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=2, + tpu_cores=[tpu_core], + limit_train_batches=4, + limit_val_batches=4, + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}' + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_tpu_cores_8(tmpdir): + """Make sure model trains on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=8, + limit_train_batches=4, + limit_val_batches=4, + ) + + # 8 cores needs a big dataset + model = SerialLoaderBoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_16bit_tpu_cores_1(tmpdir): + """Make sure model trains on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + precision=16, + progress_bar_refresh_rate=0, + max_epochs=2, + tpu_cores=1, + limit_train_batches=8, + limit_val_batches=2, + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False) + assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" + + +@pytest.mark.parametrize('tpu_core', [1, 5]) +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_16bit_tpu_index(tmpdir, tpu_core): + """Make sure model trains on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + precision=16, + progress_bar_refresh_rate=0, + max_epochs=2, + tpu_cores=[tpu_core], + limit_train_batches=4, + limit_val_batches=2, + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False) + assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}' + assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_16bit_tpu_cores_8(tmpdir): + """Make sure model trains on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + precision=16, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=8, + limit_train_batches=4, + limit_val_batches=4, + ) + + # 8 cores needs a big dataset + model = SerialLoaderBoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_model_tpu_early_stop(tmpdir): + """Test if single TPU core training works""" + + class CustomBoringModel(BoringModel): + + def validation_step(self, *args, **kwargs): + out = super().validation_step(*args, **kwargs) + self.log('val_loss', out['x']) + return out + + tutils.reset_seed() + model = CustomBoringModel() + trainer = Trainer( + callbacks=[EarlyStopping(monitor='val_loss')], + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + tpu_cores=8, + ) + trainer.fit(model) + trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32)) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_grad_norm(tmpdir): + """Test if grad_norm works on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=4, + tpu_cores=1, + limit_train_batches=4, + limit_val_batches=4, + gradient_clip_val=0.5, + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_dataloaders_passed_to_fit(tmpdir): + """Test if dataloaders passed to trainer works on TPU""" + tutils.reset_seed() + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + tpu_cores=8, + ) + trainer.fit( + model, + train_dataloader=model.train_dataloader(), + val_dataloaders=model.val_dataloader(), + ) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@pytest.mark.parametrize( + ['tpu_cores', 'expected_tpu_id'], + [pytest.param(1, None), pytest.param(8, None), + pytest.param([1], 1), pytest.param([8], 8)], +) +@RunIf(tpu=True) +def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id): + """Test if trainer.tpu_id is set as expected""" + assert Trainer(tpu_cores=tpu_cores).accelerator_connector.tpu_id == expected_tpu_id + + +def test_tpu_misconfiguration(): + """Test if trainer.tpu_id is set as expected""" + with pytest.raises(MisconfigurationException, match="`tpu_cores` can only be"): + Trainer(tpu_cores=[1, 8]) + + +@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU") +def test_exception_when_no_tpu_found(tmpdir): + """Test if exception is thrown when xla devices are not available""" + + with pytest.raises(MisconfigurationException, match='No TPU devices were found.'): + Trainer(tpu_cores=8) + + +@pytest.mark.parametrize('tpu_cores', [1, 8, [1]]) +@RunIf(tpu=True) +def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores): + """Test if distributed_backend is set to `tpu` when tpu_cores is not None""" + assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu" + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_broadcast_on_tpu(): + """ Checks if an object from the master process is broadcasted to other processes correctly""" + + def test_broadcast(rank): + trainer = Trainer(tpu_cores=8) + assert isinstance(trainer.accelerator, TPUAccelerator) + assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) + obj = ("ver_0.5", "logger_name", rank) + result = trainer.training_type_plugin.broadcast(obj) + assert result == ("ver_0.5", "logger_name", 0) + + xmp.spawn(test_broadcast, nprocs=8, start_method='fork') + + +@pytest.mark.parametrize( + ["tpu_cores", "expected_tpu_id", "error_expected"], + [ + pytest.param(1, None, False), + pytest.param(8, None, False), + pytest.param([1], 1, False), + pytest.param([8], 8, False), + pytest.param("1,", 1, False), + pytest.param("1", None, False), + pytest.param("9, ", 9, True), + pytest.param([9], 9, True), + pytest.param([0], 0, True), + pytest.param(2, None, True), + pytest.param(10, None, True), + ], +) +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): + if error_expected: + with pytest.raises(MisconfigurationException, match=r".*tpu_cores` can only be 1, 8 or [<1-8>]*"): + Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores) + else: + trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores) + assert trainer.accelerator_connector.tpu_id == expected_tpu_id + + +@pytest.mark.parametrize( + ['cli_args', 'expected'], + [pytest.param('--tpu_cores=8', {'tpu_cores': 8}), + pytest.param("--tpu_cores=1,", {'tpu_cores': '1,'})] +) +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_cores_with_argparse(cli_args, expected): + """Test passing tpu_cores in command line""" + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parent_parser=parser) + args = Trainer.parse_argparser(parser) + + for k, v in expected.items(): + assert getattr(args, k) == v + assert Trainer.from_argparse_args(args) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_reduce(): + """Test tpu spawn reduce operation """ + + def test_reduce(rank): + trainer = Trainer(tpu_cores=8) + # faster this way + reduce_ops = ["mean", "AVG", "undefined", "sum", ReduceOp.SUM, ReduceOp.MAX] + for reduce_op in reduce_ops: + if reduce_op == "undefined" or reduce_op == ReduceOp.MAX: + with pytest.raises(MisconfigurationException, match="TPUSpawn TrainingTypePlugin only support"): + result = trainer.training_type_plugin.reduce(1, reduce_op) + else: + result = trainer.training_type_plugin.reduce(1, reduce_op) + if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"): + assert result.item() == 1 + else: + assert result.item() == 8 + + xmp.spawn(test_reduce, nprocs=8, start_method='fork') + + +@RunIf(tpu=True) +@pl_multi_process_test +@pytest.mark.parametrize("clip_val", [10]) +@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_") +def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): + """ + Ensure that clip gradients is only called if the value is greater than 0. + TODO: Fix (test fails with parametrize) + """ + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=1, + precision=16, + limit_train_batches=4, + limit_val_batches=4, + gradient_clip_val=clip_val, + ) + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + if clip_val > 0: + mock_clip_grad_norm.assert_called() + else: + mock_clip_grad_norm.assert_not_called() + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_if_test_works_with_checkpoint_false(tmpdir): + """Ensure that model trains properly when `checkpoint_callback` is set to False.""" + + # Train a model on TPU + model = BoringModel() + trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/overrides/__init__.py b/tests/overrides/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py new file mode 100644 index 00000000000000..aaf47c82d5f087 --- /dev/null +++ b/tests/overrides/test_data_parallel.py @@ -0,0 +1,125 @@ +from unittest.mock import MagicMock, Mock + +import pytest +import torch +from torch.nn import DataParallel + +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.data_parallel import ( + LightningParallelModule, + python_scalar_to_tensor, + unsqueeze_scalar_tensor, +) +from pytorch_lightning.trainer.states import RunningStage +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("wrapper_class", [ + LightningParallelModule, + LightningDistributedModule, +]) +@pytest.mark.parametrize( + "stage", [ + ("training", "training_step"), + ("testing", "test_step"), + ("validating", "validation_step"), + ("predicting", "predict_step"), + ] +) +def test_lightning_wrapper_module_methods(wrapper_class, stage): + """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. """ + pl_module = MagicMock() + wrapped_module = wrapper_class(pl_module) + + batch = torch.rand(5) + batch_idx = 3 + + prop, step = stage + pl_module.trainer.sanity_checking = False + + for p in ("training", "testing", "validating", "predicting"): + setattr(pl_module.trainer, p, p == prop) + + wrapped_module(batch, batch_idx) + getattr(pl_module, step).assert_called_with(batch, batch_idx) + + +@pytest.mark.parametrize( + "inp,expected", [ + [torch.tensor(1.0), torch.tensor([1.0])], + [torch.tensor([2.0]), torch.tensor([2.0])], + [torch.ones(3, 4, 5), torch.ones(3, 4, 5)], + ] +) +def test_unsqueeze_scalar_tensor(inp, expected): + """ Test that the utility function unsqueezes only scalar tensors. """ + assert torch.all(unsqueeze_scalar_tensor(inp).eq(expected)) + + +@RunIf(min_gpus=2) +def test_lightning_parallel_module_unsqueeze_scalar(): + """ Test that LightningParallelModule takes care of un-squeezeing 0-dim tensors. """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + loss = output["loss"] + loss = loss.squeeze() + assert loss.dim() == 0 + # PyTorch usually warns about 0-dim tensors returned in DP + return {"loss": loss} + + model = TestModel() + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING + batch = torch.rand(2, 32).cuda() + batch_idx = 0 + + wrapped_model = LightningParallelModule(model).cuda() + dp_module = DataParallel(wrapped_model, device_ids=[0, 1]) + + output = wrapped_model(batch, batch_idx) + assert output["loss"].dim() == 1 + + with pytest.warns(None) as record: + output = dp_module(batch, batch_idx) + + assert output["loss"].dim() == 1 + assert not record + + +@pytest.mark.parametrize( + "inp,expected", [ + [1.0, torch.tensor([1.0])], + [2, torch.tensor([2.0])], + [True, torch.tensor([True])], + ] +) +def test_python_scalar_to_tensor(inp, expected): + assert torch.all(python_scalar_to_tensor(inp).eq(expected)) + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda", 0)]) +def test_lightning_parallel_module_python_scalar_conversion(device): + """ Test that LightningParallelModule can convert Python scalars to tensors. """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + # PyTorch DP does not support Python scalars, Lightning converts them to tensors + output.update({"python scalar": 12.3}) + return output + + model = TestModel().to(device) + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING + batch = torch.rand(2, 32).to(device) + batch_idx = 0 + + wrapped_model = LightningParallelModule(model) + output = wrapped_model(batch, batch_idx) + assert output["python scalar"] == torch.tensor([12.3], device=device) diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/plugins/environments/__init__.py b/tests/plugins/environments/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py new file mode 100644 index 00000000000000..83d26cb0fcf91c --- /dev/null +++ b/tests/plugins/environments/test_lightning_environment.py @@ -0,0 +1,52 @@ +import os +from unittest import mock + +from pytorch_lightning.plugins.environments import LightningEnvironment + + +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """ Test the default attributes when no environment variables are set. """ + env = LightningEnvironment() + assert not env.creates_children() + assert env.master_address() == "127.0.0.1" + assert isinstance(env.master_port(), int) + assert env.world_size() is None + assert env.local_rank() == 0 + assert env.node_rank() == 0 + + +@mock.patch.dict(os.environ, { + "MASTER_ADDR": "1.2.3.4", + "MASTER_PORT": "500", + "LOCAL_RANK": "2", + "NODE_RANK": "3", +}) +def test_attributes_from_environment_variables(): + """ Test that the default cluster environment takes the attributes from the environment variables. """ + env = LightningEnvironment() + assert env.master_address() == "1.2.3.4" + assert env.master_port() == 500 + assert env.world_size() is None + assert env.local_rank() == 2 + assert env.node_rank() == 3 + + +@mock.patch.dict(os.environ, { + "GROUP_RANK": "1", +}) +def test_node_rank_from_group_rank(): + """ Test that the GROUP_RANK substitutes NODE_RANK. """ + env = LightningEnvironment() + assert "NODE_RANK" not in os.environ + assert env.node_rank() == 1 + + +@mock.patch.dict(os.environ, {}) +def test_random_master_port(): + """ Test randomly chosen master port when no master port was given by user. """ + env = LightningEnvironment() + port = env.master_port() + assert isinstance(port, int) + # repeated calls do not generate a new port number + assert env.master_port() == port diff --git a/tests/plugins/environments/test_slurm_environment.py b/tests/plugins/environments/test_slurm_environment.py new file mode 100644 index 00000000000000..8e82434846e68b --- /dev/null +++ b/tests/plugins/environments/test_slurm_environment.py @@ -0,0 +1,55 @@ +import os +from unittest import mock + +import pytest + +from pytorch_lightning.plugins.environments import SLURMEnvironment + + +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """ Test the default attributes when no environment variables are set. """ + env = SLURMEnvironment() + assert env.creates_children() + assert env.master_address() == "127.0.0.1" + assert env.master_port() == 12910 + assert env.world_size() is None + with pytest.raises(KeyError): + # local rank is required to be passed as env variable + env.local_rank() + with pytest.raises(KeyError): + # node_rank is required to be passed as env variable + env.node_rank() + + +@mock.patch.dict( + os.environ, { + "SLURM_NODELIST": "1.1.1.1, 1.1.1.2", + "SLURM_JOB_ID": "0001234", + "WORLD_SIZE": "20", + "SLURM_LOCALID": "2", + "SLURM_NODEID": "3", + } +) +def test_attributes_from_environment_variables(): + """ Test that the SLURM cluster environment takes the attributes from the environment variables. """ + env = SLURMEnvironment() + assert env.master_address() == "1.1.1.1" + assert env.master_port() == 15000 + 1234 + assert env.world_size() is None + assert env.local_rank() == 2 + assert env.node_rank() == 3 + + +@pytest.mark.parametrize( + "slurm_node_list,expected", [ + ("alpha,beta,gamma", "alpha"), + ("alpha beta gamma", "alpha"), + ("1.2.3.[100-110]", "1.2.3.100"), + ] +) +def test_master_address_from_slurm_node_list(slurm_node_list, expected): + """ Test extracting the master node from different formats for the SLURM_NODELIST. """ + with mock.patch.dict(os.environ, {"SLURM_NODELIST": slurm_node_list}): + env = SLURMEnvironment() + assert env.master_address() == expected diff --git a/tests/plugins/environments/test_torchelastic_environment.py b/tests/plugins/environments/test_torchelastic_environment.py new file mode 100644 index 00000000000000..55cfc25adde3c9 --- /dev/null +++ b/tests/plugins/environments/test_torchelastic_environment.py @@ -0,0 +1,39 @@ +import os +from unittest import mock + +import pytest + +from pytorch_lightning.plugins.environments import TorchElasticEnvironment + + +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """ Test the default attributes when no environment variables are set. """ + env = TorchElasticEnvironment() + assert env.creates_children() + assert env.master_address() == "127.0.0.1" + assert env.master_port() == 12910 + assert env.world_size() is None + with pytest.raises(KeyError): + # local rank is required to be passed as env variable + env.local_rank() + assert env.node_rank() == 0 + + +@mock.patch.dict( + os.environ, { + "MASTER_ADDR": "1.2.3.4", + "MASTER_PORT": "500", + "WORLD_SIZE": "20", + "LOCAL_RANK": "2", + "GROUP_RANK": "3", + } +) +def test_attributes_from_environment_variables(): + """ Test that the torchelastic cluster environment takes the attributes from the environment variables. """ + env = TorchElasticEnvironment() + assert env.master_address() == "1.2.3.4" + assert env.master_port() == 500 + assert env.world_size() == 20 + assert env.local_rank() == 2 + assert env.node_rank() == 3 diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py new file mode 100644 index 00000000000000..fc3cd543272883 --- /dev/null +++ b/tests/plugins/test_amp_plugins.py @@ -0,0 +1,84 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class MyNativeAMP(NativeMixedPrecisionPlugin): + pass + + +class MyApexPlugin(ApexMixedPrecisionPlugin): + pass + + +@mock.patch.dict( + os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + } +) +@mock.patch('torch.cuda.device_count', return_value=2) +@pytest.mark.parametrize('ddp_backend,gpus', [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)]) +@pytest.mark.parametrize( + 'amp,custom_plugin,plugin_cls', [ + pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)), + pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)), + pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)), + pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True)) + ] +) +def test_amp_apex_ddp( + mocked_device_count, ddp_backend: str, gpus: int, amp: str, custom_plugin: bool, plugin_cls: MixedPrecisionPlugin +): + + trainer = Trainer( + fast_dev_run=True, + precision=16, + amp_backend=amp, + gpus=gpus, + accelerator=ddp_backend, + plugins=[plugin_cls()] if custom_plugin else None, + ) + assert isinstance(trainer.precision_plugin, plugin_cls) + + +class GradientUnscaleBoringModel(BoringModel): + + def on_after_backward(self): + norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) + if not (torch.isinf(norm) or torch.isnan(norm)): + assert norm.item() < 15. + + +@RunIf(min_gpus=2, amp_native=True) +@pytest.mark.parametrize('accum', [1, 2]) +def test_amp_gradient_unscale(tmpdir, accum: int): + model = GradientUnscaleBoringModel() + + trainer = Trainer( + max_epochs=2, + default_root_dir=tmpdir, + limit_train_batches=2, + limit_test_batches=2, + limit_val_batches=2, + amp_backend='native', + accelerator='ddp_spawn', + gpus=2, + precision=16, + track_grad_norm=2, + log_every_n_steps=1, + accumulate_grad_batches=accum, + ) + trainer.fit(model) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py new file mode 100644 index 00000000000000..872b49ef486356 --- /dev/null +++ b/tests/plugins/test_custom_plugin.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DDPPlugin +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class CustomParallelPlugin(DDPPlugin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Set to None so it will be overwritten by the accelerator connector. + self.sync_batchnorm = None + + +@RunIf(skip_windows=True) +def test_sync_batchnorm_set(tmpdir): + """Tests if sync_batchnorm is automatically set for custom plugin.""" + model = BoringModel() + plugin = CustomParallelPlugin() + assert plugin.sync_batchnorm is None + trainer = Trainer( + max_epochs=1, + plugins=[plugin], + default_root_dir=tmpdir, + sync_batchnorm=True, + ) + trainer.fit(model) + assert plugin.sync_batchnorm is True diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py new file mode 100644 index 00000000000000..e6b15069f256af --- /dev/null +++ b/tests/plugins/test_deepspeed_plugin.py @@ -0,0 +1,357 @@ +import json +import os + +import pytest +import torch +from torch import Tensor +from torch.optim import Optimizer + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin +from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +def test_deepspeed_lightning_module(tmpdir): + """ + Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly. + """ + + model = BoringModel() + module = LightningDeepSpeedModule(model, precision=16) + + module.half() + assert module.dtype == torch.half + assert model.dtype == torch.half + + module.to(torch.double) + assert module.dtype == torch.double + assert model.dtype == torch.double + + +@RunIf(min_gpus=1) +def test_deepspeed_lightning_module_precision(tmpdir): + """ + Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16. + """ + + model = BoringModel() + module = LightningDeepSpeedModule(model, precision=16) + + module.cuda().half() + assert module.dtype == torch.half + assert model.dtype == torch.half + + x = torch.randn((1, 32), dtype=torch.float).cuda() + out = module(x) + + assert out.dtype == torch.half + + module.to(torch.double) + assert module.dtype == torch.double + assert model.dtype == torch.double + + +@pytest.fixture +def deepspeed_config(): + return { + "optimizer": { + "type": "SGD", + "params": { + "lr": 3e-5, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + } + } + + +@pytest.fixture +def deepspeed_zero_config(deepspeed_config): + return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}} + + +@RunIf(deepspeed=True) +@pytest.mark.parametrize("input", ("deepspeed", DeepSpeedPlugin)) +def test_deepspeed_plugin_string(tmpdir, input): + """ + Test to ensure that the plugin can be passed via string or instance, and parallel devices is correctly set. + """ + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins=input if isinstance(input, str) else input(), + ) + + assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator.training_type_plugin.parallel_devices == [torch.device('cpu')] + + +@RunIf(deepspeed=True) +def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): + """ + Test to ensure that the plugin can be passed via a string with an environment variable. + """ + config_path = os.path.join(tmpdir, 'temp.json') + with open(config_path, 'w') as f: + f.write(json.dumps(deepspeed_config)) + monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins='deepspeed', + ) + + plugin = trainer.accelerator.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert plugin.parallel_devices == [torch.device('cpu')] + assert plugin.config == deepspeed_config + + +@RunIf(amp_native=True, deepspeed=True) +@pytest.mark.parametrize( + "amp_backend", [ + pytest.param("native", marks=RunIf(amp_native=True)), + pytest.param("apex", marks=RunIf(amp_apex=True)), + ] +) +def test_deepspeed_precision_choice(amp_backend, tmpdir): + """ + Test to ensure precision plugin is also correctly chosen. + DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin + """ + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins='deepspeed', + amp_backend=amp_backend, + precision=16, + ) + + assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) + assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.accelerator.precision_plugin.precision == 16 + + +@RunIf(deepspeed=True) +def test_deepspeed_with_invalid_config_path(tmpdir): + """ + Test to ensure if we pass an invalid config path we throw an exception. + """ + + with pytest.raises( + MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" + ): + DeepSpeedPlugin(config='invalid_path.json') + + +@RunIf(deepspeed=True) +def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): + """ + Test to ensure if we pass an env variable, we load the config from the path. + """ + config_path = os.path.join(tmpdir, 'temp.json') + with open(config_path, 'w') as f: + f.write(json.dumps(deepspeed_config)) + monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) + plugin = DeepSpeedPlugin() + assert plugin.config == deepspeed_config + + +@RunIf(deepspeed=True) +def test_deepspeed_defaults(tmpdir): + """ + Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed. + """ + plugin = DeepSpeedPlugin() + assert plugin.config is not None + assert isinstance(plugin.config["zero_optimization"], dict) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_invalid_deepspeed_defaults_no_precision(tmpdir): + """Test to ensure that using defaults, if precision is not set to 16, we throw an exception.""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins='deepspeed', + ) + with pytest.raises( + MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' + ): + trainer.fit(model) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_warn_deepspeed_override_backward(tmpdir): + """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning.""" + + class TestModel(BoringModel): + + def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + return loss.backward() + + model = TestModel() + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + plugins=DeepSpeedPlugin(), + gpus=1, + precision=16, + ) + with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'): + trainer.fit(model) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_run_configure_optimizers(tmpdir): + """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), + whilst using configure_optimizers for optimizers and schedulers.""" + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer + + assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer) + assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD) + assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally + # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler + assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR) + + model = TestModel() + trainer = Trainer( + plugins=DeepSpeedPlugin(), # disable ZeRO so our optimizers are not wrapped + default_root_dir=tmpdir, + gpus=1, + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + + _assert_save_model_is_equal(model, tmpdir, trainer) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_config(tmpdir, deepspeed_zero_config): + """ + Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers + and saves the model weights to load correctly. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + from deepspeed.runtime.lr_schedules import WarmupLR + from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer + + assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer) + assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD) + assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally + # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler + assert isinstance(self.trainer.model.lr_scheduler, WarmupLR) + + model = TestModel() + trainer = Trainer( + plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], + default_root_dir=tmpdir, + gpus=1, + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + trainer.test(model) + + _assert_save_model_is_equal(model, tmpdir, trainer) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_custom_precision_params(tmpdir): + """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.""" + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.training_type_plugin.config['fp16']['loss_scale'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['initial_scale_power'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['loss_scale_window'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['hysteresis'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['min_loss_scale'] == 10 + raise SystemExit() + + model = TestModel() + ds = DeepSpeedPlugin(loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10) + trainer = Trainer(default_root_dir=tmpdir, plugins=[ds], precision=16, gpus=1) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(min_gpus=1, deepspeed=True, special=True) +def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config): + """Ensure if we use a config and turn off cpu_offload, that this is set to False within the config.""" + + deepspeed_zero_config['zero_optimization']['cpu_offload'] = False + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.training_type_plugin.config['zero_optimization']['cpu_offload'] is False + raise SystemExit() + + model = TestModel() + trainer = Trainer( + plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], + precision=16, + gpus=1, + default_root_dir=tmpdir, + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(min_gpus=2, special=True, deepspeed=True) +def test_deepspeed_multigpu(tmpdir, deepspeed_config): + """ + Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation. + """ + model = BoringModel() + trainer = Trainer( + plugins=[DeepSpeedPlugin()], + default_root_dir=tmpdir, + gpus=2, + fast_dev_run=True, + precision=16, + ) + trainer.fit(model) + trainer.test(model) + + _assert_save_model_is_equal(model, tmpdir, trainer) + + +def _assert_save_model_is_equal(model, tmpdir, trainer): + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + # carry out the check only on rank 0 + if trainer.global_rank == 0: + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + if model.dtype == torch.half: + saved_model = saved_model.half() # model is loaded in float32 as default, move it to float16 + model = model.cpu() + # Assert model parameters are identical after loading + for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(orig_param, trained_model_param) diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py new file mode 100644 index 00000000000000..f089b1c23149eb --- /dev/null +++ b/tests/plugins/test_double_plugin.py @@ -0,0 +1,129 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel, RandomDataset + + +class RandomFloatIntDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.float_data = torch.randn(length, size) + self.int_data = torch.randint(10, (length, 1)) + + def __getitem__(self, index): + return self.float_data[index], self.int_data[index] + + def __len__(self): + return self.len + + +class DoublePrecisionBoringModel(BoringModel): + + def training_step(self, batch, batch_idx): + float_data, int_data = batch + assert float_data.dtype == torch.float64 + output = self(float_data) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def test_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.dtype == torch.float64 + return self(batch) + + def on_fit_start(self): + assert self.layer.weight.dtype == torch.float64 + + def on_after_backward(self): + assert self.layer.weight.grad.dtype == torch.float64 + + def train_dataloader(self): + dataset = RandomFloatIntDataset(32, 64) + assert dataset.float_data.dtype == torch.float32 # Don't start with double data + return DataLoader(dataset) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class DoublePrecisionBoringModelNoForward(BoringModel): + + def training_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + loss = self.loss(batch, output) + return {"x": loss} + + def test_step(self, batch, batch_idx): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + loss = self.loss(batch, output) + return {"y": loss} + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.dtype == torch.float64 + output = self.layer(batch) + assert output.dtype == torch.float64 + return output + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +@pytest.mark.parametrize( + 'boring_model', + (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward) +) +def test_double_precision(tmpdir, boring_model): + model = boring_model() + original_training_step = model.training_step + + trainer = Trainer( + max_epochs=2, + default_root_dir=tmpdir, + fast_dev_run=2, + precision=64, + log_every_n_steps=1, + ) + trainer.fit(model) + trainer.test(model) + trainer.predict(model) + + assert model.training_step == original_training_step diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py new file mode 100644 index 00000000000000..9ecc93a9b50555 --- /dev/null +++ b/tests/plugins/test_rpc_plugin.py @@ -0,0 +1,88 @@ +import os +from typing import Optional +from unittest import mock + +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCPlugin +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp_spawn", 2, 0)], +) +@RunIf(rpc=True) +def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes): + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.training_type_plugin, RPCPlugin) + raise RuntimeError('finished plugin check') + + model = BoringModel() + trainer = Trainer( + default_root_dir=str(tmpdir), + fast_dev_run=True, + gpus=gpus, + num_processes=num_processes, + distributed_backend=ddp_backend, + callbacks=[CB()], + plugins=[RPCPlugin()] + ) + + with pytest.raises(RuntimeError, match='finished plugin check'): + trainer.fit(model) + + +class CustomRPCPlugin(RPCPlugin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.rpc_save_model_count = 0 + self.worker_optimizer_step_count = 0 + + def rpc_save_model(self, *_) -> None: + self.rpc_save_model_count += 1 + + def barrier(self, name: Optional[str] = None) -> None: + return + + +@RunIf(min_gpus=2, special=True, rpc=True) +def test_rpc_function_calls_ddp(tmpdir): + model = BoringModel() + plugin = CustomRPCPlugin() + max_epochs = 2 + limit_train_batches = 2 + trainer = Trainer( + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=max_epochs, + gpus=2, + distributed_backend='ddp', + plugins=[plugin], + default_root_dir=tmpdir, + ) + + trainer.fit(model) + if trainer.global_rank == 0: # Main process + assert plugin.rpc_save_model_count == max_epochs + else: # Worker process + assert plugin.rpc_save_model_count == max_epochs diff --git a/tests/plugins/test_rpc_sequential_plugin.py b/tests/plugins/test_rpc_sequential_plugin.py new file mode 100644 index 00000000000000..688424e0f74fb2 --- /dev/null +++ b/tests/plugins/test_rpc_sequential_plugin.py @@ -0,0 +1,187 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock + +import pytest +import torch +import torch.distributed as torch_distrib +from torch import nn + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import RandomDataset +from tests.helpers.runif import RunIf + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=2, special=True, fairscale_pipe=True) +def test_rpc_sequential_plugin_manual(tmpdir, args=None): + model = SequentialModelRPCManual() + trainer = Trainer( + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + gpus=2, + distributed_backend="ddp", + plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], + ) + + trainer.fit(model) + + if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0: + assert len(trainer.dev_debugger.pbar_added_metrics) > 0 + + if trainer.accelerator.rpc_enabled: + # Called at the end of trainer to ensure all processes are killed + trainer.accelerator.training_type_plugin.exit_rpc_process() + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=2, special=True, fairscale_pipe=True) +def test_rpc_sequential_plugin_manual_amp(tmpdir, args=None): + model = SequentialModelRPCManual() + trainer = Trainer( + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + gpus=2, + precision=16, + amp_backend="native", + distributed_backend="ddp", + plugins=[RPCSequentialPlugin(balance=[2, 1])], + ) + with pytest.raises( + MisconfigurationException, + match='`RPCSequentialPlugin` is currently not supported in Automatic Mixed Precision' + ): + trainer.fit(model) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=2, special=True, fairscale_pipe=True) +def test_rpc_sequential_plugin_automatic(tmpdir, args=None): + model = SequentialModelRPCAutomatic() + trainer = Trainer( + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + gpus=2, + distributed_backend="ddp", + plugins=[RPCSequentialPlugin(balance=[2, 1])], + ) + + trainer.fit(model) + + if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0: + assert len(trainer.dev_debugger.pbar_added_metrics) > 0 + + if trainer.accelerator.rpc_enabled: + # Called at the end of trainer to ensure all processes are killed + trainer.accelerator.training_type_plugin.exit_rpc_process() + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=2, special=True, fairscale_pipe=True) +def test_rpc_sequential_plugin_with_wrong_balance(tmpdir, args=None): + model = SequentialModelRPCAutomatic() + trainer = Trainer( + max_epochs=2, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + gpus=2, + distributed_backend="ddp", + plugins=[RPCSequentialPlugin(balance=[2, 2])], + ) + + with pytest.raises( + MisconfigurationException, match="The provided balance sum: 4 does not match your Sequential length: 3" + ): + trainer.fit(model) + + if trainer.accelerator.rpc_enabled: + # Called at the end of trainer to ensure all processes are killed + trainer.accelerator.training_type_plugin.exit_rpc_process() + + +class SequentialModelRPCManual(LightningModule): + + def __init__(self): + super().__init__() + self.sequential_module = nn.Sequential(torch.nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)) + self.automatic_optimization = False + + def forward(self, x): + return self.sequential_module(x) + + def loss(self, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + opt = self.optimizers() + output = self.sequential_module(batch) + loss = self.loss(output) + self.log("train_loss", loss, on_epoch=True, prog_bar=True) + self.manual_backward(loss, opt) + assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0 + opt.step() + opt.zero_grad() + assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0 + + def validation_step(self, batch, batch_idx): + output = self.sequential_module(batch) + loss = self.loss(output) + return loss + + def test_step(self, batch, batch_idx): + output = self.sequential_module(batch) + return self.loss(batch, output) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return torch.utils.data.DataLoader(RandomDataset(32, 64)) + + +class SequentialModelRPCAutomatic(SequentialModelRPCManual): + + def __init__(self): + super().__init__() + self.automatic_optimization = True + + def training_step(self, batch, batch_idx): + output = self.sequential_module(batch) + loss = self.loss(output) + self.log("train_loss", loss, on_epoch=True, prog_bar=True) + return loss diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py new file mode 100644 index 00000000000000..655e12f046e04d --- /dev/null +++ b/tests/plugins/test_sharded_plugin.py @@ -0,0 +1,280 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("clip_val", [0, 10]) +@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True) +@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm') +def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir): + """ + Ensure that clip gradients is only called if the value is greater than 0. + """ + model = BoringModel() + trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val) + trainer.fit(model) + if clip_val > 0: + mock_oss_clip_grad_norm.assert_called() + else: + mock_oss_clip_grad_norm.assert_not_called() + + +@RunIf(fairscale=True) +@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )]) +def test_sharded_ddp_choice(tmpdir, accelerator): + """ + Test to ensure that plugin is correctly chosen + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + if accelerator == 'ddp_sharded': + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + elif accelerator == 'ddp_sharded_spawn': + assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator=accelerator, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(amp_apex=True, fairscale=True) +def test_invalid_apex_sharded(tmpdir): + """ + Test to ensure that we raise an error when we try to use apex and sharded + """ + + model = BoringModel() + with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'): + trainer = Trainer( + fast_dev_run=True, + accelerator='ddp_sharded_spawn', + precision=16, + amp_backend='apex', + ) + + trainer.fit(model) + + +@RunIf(min_gpus=2, amp_native=True, fairscale=True) +@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )]) +def test_ddp_choice_sharded_amp(tmpdir, accelerator): + """ + Test to ensure that plugin native amp plugin is correctly chosen when using sharded + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + if accelerator == 'ddp_sharded': + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + elif accelerator == 'ddp_sharded_spawn': + assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=1, + precision=16, + accelerator=accelerator, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(skip_windows=True, fairscale=True) +def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(ddp_param.to("cpu"), shard_param) + + +@RunIf(min_gpus=2, skip_windows=True, fairscale=True) +def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using multiple GPUs + """ + model = BoringModel() + trainer = Trainer( + gpus=2, + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(ddp_param.to("cpu"), shard_param) + + +@RunIf(min_gpus=2, skip_windows=True, fairscale=True) +def test_ddp_sharded_plugin_finetune(tmpdir): + """ + Test to ensure that we can save and restart training (simulate fine-tuning) + """ + model = BoringModel() + trainer = Trainer( + gpus=2, + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + ) + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + trainer = Trainer(fast_dev_run=True, ) + trainer.fit(saved_model) + + +@RunIf(skip_windows=True, fairscale=True) +def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): + """ + Test to ensure that resuming from checkpoint works + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + model = BoringModel() + + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + resume_from_checkpoint=checkpoint_path, + ) + + trainer.fit(model) + + +@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") # todo +@pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") +@RunIf(min_gpus=2, skip_windows=True, fairscale=True) +def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): + """ + Test to ensure that resuming from checkpoint works when downsizing number of GPUS + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + gpus=2, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + model = BoringModel() + + trainer = Trainer( + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + gpus=1, + resume_from_checkpoint=checkpoint_path, + ) + + trainer.fit(model) + + +@RunIf(min_gpus=1, skip_windows=True, fairscale=True) +def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): + """ + Test to ensure that resuming from checkpoint works when going from GPUs- > CPU + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + gpus=1, + fast_dev_run=True, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + model = BoringModel() + + trainer = Trainer( + accelerator='ddp_sharded_spawn', + num_processes=2, + fast_dev_run=True, + resume_from_checkpoint=checkpoint_path, + ) + + trainer.fit(model) + + +@RunIf(skip_windows=True, special=True, fairscale=True) +@pytest.mark.parametrize( + "trainer_kwargs", ( + dict(num_processes=2), + pytest.param(dict(gpus=2), marks=RunIf(min_gpus=2)), + ) +) +def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): + """ + Test to ensure we can use validate and test without fit + """ + model = BoringModel() + trainer = Trainer( + accelerator='ddp_sharded_spawn', + fast_dev_run=True, + **trainer_kwargs, + ) + + trainer.validate(model) + trainer.test(model) diff --git a/tests/requirements-devel.txt b/tests/requirements-devel.txt deleted file mode 100644 index 2ff897d29e21b8..00000000000000 --- a/tests/requirements-devel.txt +++ /dev/null @@ -1,5 +0,0 @@ -# install all extra dependencies for full package testing --r ../requirements-extra.txt - -# extended list of dependencies dor development and run lint and tests --r ./requirements.txt \ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt deleted file mode 100644 index 79a10e3ea55688..00000000000000 --- a/tests/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -coverage -codecov -pytest>=3.0.5 -pytest-cov -pytest-flake8 -flake8 -check-manifest -twine==1.13.0 \ No newline at end of file diff --git a/tests/special_tests.sh b/tests/special_tests.sh new file mode 100755 index 00000000000000..aa5d65844a1c52 --- /dev/null +++ b/tests/special_tests.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +# this environment variable allows special tests to run +export PL_RUNNING_SPECIAL_TESTS=1 +# python arguments +defaults='-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no' + +# find tests marked as `@RunIf(special=True)` +grep_output=$(grep --recursive --line-number --word-regexp 'tests' 'benchmarks' --regexp 'special=True') +# file paths +files=$(echo "$grep_output" | cut -f1 -d:) +files_arr=($files) +# line numbers +linenos=$(echo "$grep_output" | cut -f2 -d:) +linenos_arr=($linenos) + +# tests to skip - space separated +blocklist='test_pytorch_profiler_nested_emit_nvtx' +report='' + +for i in "${!files_arr[@]}"; do + file=${files_arr[$i]} + lineno=${linenos_arr[$i]} + + # get code from `@RunIf(special=True)` line to EOF + test_code=$(tail -n +"$lineno" "$file") + + # read line by line + while read -r line; do + # if it's a test + if [[ $line == def\ test_* ]]; then + # get the name + test_name=$(echo $line | cut -c 5- | cut -f1 -d\() + + # check blocklist + if echo $blocklist | grep --word-regexp "$test_name" > /dev/null; then + report+="Skipped\t$file:$lineno::$test_name\n" + break + fi + + # run the test + report+="Ran\t$file:$lineno::$test_name\n" + python ${defaults} "${file}::${test_name}" + break + fi + done < <(echo "$test_code") +done + +nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx + +# echo test report +printf '=%.s' {1..80} +printf "\n$report" +printf '=%.s' {1..80} +printf '\n' diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py deleted file mode 100644 index 437e5f35ab77f3..00000000000000 --- a/tests/test_deprecated.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Test deprecated functionality which will be removed in vX.Y.Z""" -import sys - -import pytest - -from pytorch_lightning import Trainer - -import tests.base.utils as tutils -from tests.base import TestModelBase, LightTrainDataloader, LightEmptyTestStep - - -def _soft_unimport_module(str_module): - # once the module is imported e.g with parsing with pytest it lives in memory - if str_module in sys.modules: - del sys.modules[str_module] - - -def test_tbd_remove_in_v0_8_0_module_imports(): - _soft_unimport_module("pytorch_lightning.logging.comet_logger") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.logging.comet_logger import CometLogger # noqa: F811 - _soft_unimport_module("pytorch_lightning.logging.mlflow_logger") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.logging.mlflow_logger import MLFlowLogger # noqa: F811 - _soft_unimport_module("pytorch_lightning.logging.test_tube_logger") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.logging.test_tube_logger import TestTubeLogger # noqa: F811 - - _soft_unimport_module("pytorch_lightning.pt_overrides.override_data_parallel") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.pt_overrides.override_data_parallel import ( # noqa: F811 - LightningDataParallel, LightningDistributedDataParallel) - _soft_unimport_module("pytorch_lightning.overrides.override_data_parallel") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.overrides.override_data_parallel import ( # noqa: F811 - LightningDataParallel, LightningDistributedDataParallel) - - _soft_unimport_module("pytorch_lightning.core.model_saving") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.core.model_saving import ModelIO # noqa: F811 - _soft_unimport_module("pytorch_lightning.core.root_module") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.core.root_module import LightningModule # noqa: F811 - - _soft_unimport_module("pytorch_lightning.root_module.decorators") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.root_module.decorators import data_loader # noqa: F811 - _soft_unimport_module("pytorch_lightning.root_module.grads") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.root_module.grads import GradInformation # noqa: F811 - _soft_unimport_module("pytorch_lightning.root_module.hooks") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.root_module.hooks import ModelHooks # noqa: F811 - _soft_unimport_module("pytorch_lightning.root_module.memory") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.root_module.memory import ModelSummary # noqa: F811 - _soft_unimport_module("pytorch_lightning.root_module.model_saving") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.root_module.model_saving import ModelIO # noqa: F811 - _soft_unimport_module("pytorch_lightning.root_module.root_module") - with pytest.deprecated_call(match='v0.8.0'): - from pytorch_lightning.root_module.root_module import LightningModule # noqa: F811 - - -def test_tbd_remove_in_v0_8_0_trainer(): - mapping_old_new = { - 'gradient_clip': 'gradient_clip_val', - 'nb_gpu_nodes': 'num_nodes', - 'max_nb_epochs': 'max_epochs', - 'min_nb_epochs': 'min_epochs', - 'nb_sanity_val_steps': 'num_sanity_val_steps', - 'default_save_path': 'default_root_dir', - } - # skip 0 since it may be interested as False - kwargs = {k: (i + 1) for i, k in enumerate(mapping_old_new)} - - trainer = Trainer(**kwargs) - - for attr_old in mapping_old_new: - attr_new = mapping_old_new[attr_old] - with pytest.deprecated_call(match='v0.8.0'): - _ = getattr(trainer, attr_old) - assert kwargs[attr_old] == getattr(trainer, attr_old), \ - 'Missing deprecated attribute "%s"' % attr_old - assert kwargs[attr_old] == getattr(trainer, attr_new), \ - 'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new) - - -def test_tbd_remove_in_v0_9_0_trainer(): - # test show_progress_bar set by progress_bar_refresh_rate - with pytest.deprecated_call(match='v0.9.0'): - trainer = Trainer(progress_bar_refresh_rate=0, show_progress_bar=True) - assert not getattr(trainer, 'show_progress_bar') - - with pytest.deprecated_call(match='v0.9.0'): - trainer = Trainer(progress_bar_refresh_rate=50, show_progress_bar=False) - assert getattr(trainer, 'show_progress_bar') - - -def test_tbd_remove_in_v0_9_0_module_imports(): - _soft_unimport_module("pytorch_lightning.core.decorators") - with pytest.deprecated_call(match='v0.9.0'): - from pytorch_lightning.core.decorators import data_loader # noqa: F811 - data_loader(print) - - _soft_unimport_module("pytorch_lightning.logging.comet") - with pytest.deprecated_call(match='v0.9.0'): - from pytorch_lightning.logging.comet import CometLogger # noqa: F402 - _soft_unimport_module("pytorch_lightning.logging.mlflow") - with pytest.deprecated_call(match='v0.9.0'): - from pytorch_lightning.logging.mlflow import MLFlowLogger # noqa: F402 - _soft_unimport_module("pytorch_lightning.logging.neptune") - with pytest.deprecated_call(match='v0.9.0'): - from pytorch_lightning.logging.neptune import NeptuneLogger # noqa: F402 - _soft_unimport_module("pytorch_lightning.logging.test_tube") - with pytest.deprecated_call(match='v0.9.0'): - from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402 - _soft_unimport_module("pytorch_lightning.logging.wandb") - with pytest.deprecated_call(match='v0.9.0'): - from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402 - - -class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): - - # todo: this shall not be needed while evaluate asks for dataloader explicitly - def val_dataloader(self): - return self._dataloader(train=False) - - def validation_step(self, batch, batch_idx, *args, **kwargs): - return {'val_loss': 0.6} - - def validation_end(self, outputs): - return {'val_loss': 0.6} - - def test_dataloader(self): - return self._dataloader(train=False) - - def test_end(self, outputs): - return {'test_loss': 0.6} - - -class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase): - - # todo: this shall not be needed while evaluate asks for dataloader explicitly - def val_dataloader(self): - return self._dataloader(train=False) - - def validation_step(self, batch, batch_idx, *args, **kwargs): - return {'val_loss': 0.7} - - def validation_end(self, outputs): - return {'val_loss': 0.7} - - def test_dataloader(self): - return self._dataloader(train=False) - - def test_end(self, outputs): - return {'test_loss': 0.7} - - -def test_tbd_remove_in_v1_0_0_model_hooks(): - hparams = tutils.get_default_hparams() - - model = ModelVer0_6(hparams) - - with pytest.deprecated_call(match='v1.0'): - trainer = Trainer(logger=False) - trainer.test(model) - assert trainer.callback_metrics == {'test_loss': 0.6} - - with pytest.deprecated_call(match='v1.0'): - trainer = Trainer(logger=False) - # TODO: why `dataloder` is required if it is not used - result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) - assert result == {'val_loss': 0.6} - - model = ModelVer0_7(hparams) - - with pytest.deprecated_call(match='v1.0'): - trainer = Trainer(logger=False) - trainer.test(model) - assert trainer.callback_metrics == {'test_loss': 0.7} - - with pytest.deprecated_call(match='v1.0'): - trainer = Trainer(logger=False) - # TODO: why `dataloder` is required if it is not used - result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) - assert result == {'val_loss': 0.7} diff --git a/tests/test_profiler.py b/tests/test_profiler.py index fa9fc103f0b46e..a6e33b3366f33e 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,12 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging import os +import platform import time -from pathlib import Path +from copy import deepcopy +from distutils.version import LooseVersion import numpy as np import pytest -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler +import torch -PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.profiler.pytorch import RegisterRecordFunction +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + +PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 def _get_python_cprofile_total_duration(profile): @@ -25,22 +49,15 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - profiler = SimpleProfiler() - return profiler - - -@pytest.fixture -def advanced_profiler(tmpdir): - profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) - return profiler + return SimpleProfiler() @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), - pytest.param("c", [1]) + pytest.param("c", [1]), ]) -def test_simple_profiler_durations(simple_profiler, action, expected): +def test_simple_profiler_durations(simple_profiler, action: str, expected: list): """Ensure the reported durations are reasonably accurate.""" for duration in expected: @@ -49,17 +66,15 @@ def test_simple_profiler_durations(simple_profiler, action, expected): # different environments have different precision when it comes to time.sleep() # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796 - np.testing.assert_allclose( - simple_profiler.recorded_durations[action], expected, rtol=0.2 - ) + np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2) @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), - pytest.param("c", [1]) + pytest.param("c", [1]), ]) -def test_simple_profiler_iterable_durations(simple_profiler, action, expected): +def test_simple_profiler_iterable_durations(simple_profiler, action: str, expected: list): """Ensure the reported durations are reasonably accurate.""" iterable = _sleep_generator(expected) @@ -67,9 +82,7 @@ def test_simple_profiler_iterable_durations(simple_profiler, action, expected): pass # we exclude the last item in the recorded durations since that's when StopIteration is raised - np.testing.assert_allclose( - simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2 - ) + np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2) def test_simple_profiler_overhead(simple_profiler, n_iter=5): @@ -82,13 +95,6 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE) -def test_simple_profiler_describe(caplog, simple_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - simple_profiler.describe() - - assert "Profiler Report" in caplog.text - - def test_simple_profiler_value_errors(simple_profiler): """Ensure errors are raised where expected.""" @@ -104,12 +110,83 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) +def test_simple_profiler_deepcopy(tmpdir): + simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test") + simple_profiler.describe() + assert deepcopy(simple_profiler) + + +def test_simple_profiler_log_dir(tmpdir): + """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present""" + profiler = SimpleProfiler(filename="profiler") + assert profiler._log_dir is None + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + profiler=profiler, + ) + trainer.fit(model) + + expected = tmpdir / "lightning_logs" / "version_0" + assert trainer.log_dir == expected + assert profiler._log_dir == trainer.log_dir + assert expected.join("fit-profiler.txt").exists() + + +@RunIf(skip_windows=True) +def test_simple_profiler_distributed_files(tmpdir): + """Ensure the proper files are saved in distributed""" + profiler = SimpleProfiler(dirpath=tmpdir, filename='profiler') + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + accelerator="ddp_cpu", + num_processes=2, + profiler=profiler, + logger=False, + ) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + actual = set(os.listdir(profiler.dirpath)) + expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)} + assert actual == expected + + for f in profiler.dirpath.listdir(): + assert f.read_text('utf-8') + + +def test_simple_profiler_logs(tmpdir, caplog, simple_profiler): + """Ensure that the number of printed logs is correct""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + profiler=simple_profiler, + logger=False, + ) + with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"): + trainer.fit(model) + trainer.test(model) + + assert caplog.text.count("Profiler Report") == 2 + + +@pytest.fixture +def advanced_profiler(tmpdir): + return AdvancedProfiler(dirpath=tmpdir, filename="profiler") + + @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), - pytest.param("c", [1]) + pytest.param("c", [1]), ]) -def test_advanced_profiler_durations(advanced_profiler, action, expected): +def test_advanced_profiler_durations(advanced_profiler, action: str, expected: list): for duration in expected: with advanced_profiler.profile(action): @@ -117,34 +194,26 @@ def test_advanced_profiler_durations(advanced_profiler, action, expected): # different environments have different precision when it comes to time.sleep() # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796 - recored_total_duration = _get_python_cprofile_total_duration( - advanced_profiler.profiled_actions[action] - ) + recored_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action]) expected_total_duration = np.sum(expected) - np.testing.assert_allclose( - recored_total_duration, expected_total_duration, rtol=0.2 - ) + np.testing.assert_allclose(recored_total_duration, expected_total_duration, rtol=0.2) @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), - pytest.param("c", [1]) + pytest.param("c", [1]), ]) -def test_advanced_profiler_iterable_durations(advanced_profiler, action, expected): +def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, expected: list): """Ensure the reported durations are reasonably accurate.""" iterable = _sleep_generator(expected) for _ in advanced_profiler.profile_iterable(iterable, action): pass - recored_total_duration = _get_python_cprofile_total_duration( - advanced_profiler.profiled_actions[action] - ) + recored_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action]) expected_total_duration = np.sum(expected) - np.testing.assert_allclose( - recored_total_duration, expected_total_duration, rtol=0.2 - ) + np.testing.assert_allclose(recored_total_duration, expected_total_duration, rtol=0.2) def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): @@ -170,7 +239,8 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): pass # log to stdout and print to file advanced_profiler.describe() - data = Path(advanced_profiler.output_fname).read_text() + path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt" + data = path.read_text("utf-8") assert len(data) > 0 @@ -183,3 +253,259 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.start(action) advanced_profiler.stop(action) + + +def test_advanced_profiler_deepcopy(advanced_profiler): + advanced_profiler.describe() + assert deepcopy(advanced_profiler) + + +@pytest.fixture +def pytorch_profiler(tmpdir): + return PyTorchProfiler(dirpath=tmpdir, filename="profiler") + + +@RunIf(max_torch="1.8.1") +def test_pytorch_profiler_describe(pytorch_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with pytorch_profiler.profile("on_test_start"): + torch.tensor(0) + + # log to stdout and print to file + pytorch_profiler.describe() + path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt" + data = path.read_text("utf-8") + assert len(data) > 0 + + +def test_pytorch_profiler_raises(pytorch_profiler): + """Ensure errors are raised where expected.""" + with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"): + PyTorchProfiler(profiled_functions=["a"], record_functions=["b"]) + + +@RunIf(min_torch="1.6.0") +def test_advanced_profiler_cprofile_deepcopy(tmpdir): + """Checks for pickle issue reported in #6522""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + profiler="advanced", + stochastic_weight_avg=True, + ) + trainer.fit(model) + + +@RunIf(min_gpus=2, special=True) +def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): + """Ensure that the profiler can be given to the training and default step are properly recorded. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=5, + profiler=pytorch_profiler, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + expected = {'validation_step'} + if not _KINETO_AVAILABLE: + expected |= {'training_step_and_backward', 'training_step', 'backward'} + for name in expected: + assert sum(e.name == name for e in pytorch_profiler.function_events), name + + files = set(os.listdir(pytorch_profiler.dirpath)) + expected = f"fit-profiler-{trainer.local_rank}.txt" + assert expected in files + + path = pytorch_profiler.dirpath / expected + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = os.listdir(pytorch_profiler.dirpath) + files = [file for file in files if file.endswith('.json')] + assert len(files) == 2, files + local_rank = trainer.local_rank + assert any(f'training_step_{local_rank}' in f for f in files) + assert any(f'validation_step_{local_rank}' in f for f in files) + + +def test_pytorch_profiler_trainer_test(tmpdir): + """Ensure that the profiler can be given to the trainer and test step are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=2, + profiler=pytorch_profiler, + ) + trainer.test(model) + + assert sum(e.name == 'test_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = sorted([file for file in os.listdir(tmpdir) if file.endswith('.json')]) + assert any(f'test_step_{trainer.local_rank}' in f for f in files) + + +def test_pytorch_profiler_trainer_predict(tmpdir): + """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + model.predict_dataloader = model.train_dataloader + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_predict_batches=2, + profiler=pytorch_profiler, + ) + trainer.predict(model) + + assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) + path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_trainer_validate(tmpdir): + """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=2, + profiler=pytorch_profiler, + ) + trainer.validate(model) + + assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_nested(tmpdir): + """Ensure that the profiler handles nested context""" + + pytorch_profiler = PyTorchProfiler( + record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None + ) + + with pytorch_profiler.profile("a"): + a = torch.ones(42) + with pytorch_profiler.profile("b"): + b = torch.zeros(42) + with pytorch_profiler.profile("c"): + _ = a + b + + pytorch_profiler.describe() + + events_name = {e.name for e in pytorch_profiler.function_events} + + if platform.system() == "Windows": + expected = {'a', 'add', 'b', 'c', 'profiler::_record_function_enter', 'profiler::_record_function_exit'} + else: + expected = { + 'signed char', 'add', 'profiler::_record_function_exit', 'bool', 'char', 'profiler::_record_function_enter' + } + + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + expected = {'add', 'zeros', 'ones', 'zero_', 'b', 'fill_', 'c', 'a', 'empty'} + + if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"): + expected = { + 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' + } + + assert events_name == expected, (events_name, torch.__version__, platform.system()) + + +@RunIf(min_gpus=1, special=True) +def test_pytorch_profiler_nested_emit_nvtx(tmpdir): + """ + This test check emit_nvtx is correctly supported + """ + profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + gpus=1, + ) + trainer.fit(model) + + +@RunIf(min_torch="1.5.0") +def test_register_record_function(tmpdir): + + use_cuda = torch.cuda.is_available() + pytorch_profiler = PyTorchProfiler( + export_to_chrome=False, + record_functions={"a"}, + use_cuda=use_cuda, + dirpath=tmpdir, + filename="profiler", + schedule=None, + on_trace_ready=None, + ) + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1)) + + model = TestModel() + input = torch.rand((1, 1)) + + if use_cuda: + model = model.cuda() + input = input.cuda() + + with pytorch_profiler.profile("a"): + with RegisterRecordFunction(model): + model(input) + + pytorch_profiler.describe() + event_names = [e.name for e in pytorch_profiler.function_events] + assert 'torch.nn.modules.container.Sequential: layer' in event_names + assert 'torch.nn.modules.linear.Linear: layer.0' in event_names + assert 'torch.nn.modules.activation.ReLU: layer.1' in event_names + assert 'torch.nn.modules.linear.Linear: layer.2' in event_names + + +@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_profiler_teardown(tmpdir, cls): + """ + This test checks if profiler teardown method is called when trainer is exiting. + """ + + class TestCallback(Callback): + + def on_fit_end(self, trainer, *args, **kwargs) -> None: + # describe sets it to None + assert trainer.profiler._output_file is None + + profiler = cls(dirpath=tmpdir, filename="profiler") + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) + trainer.fit(model) + + assert profiler._output_file is None + + +def test_pytorch_profiler_deepcopy(tmpdir): + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None) + pytorch_profiler.start("on_train_start") + torch.tensor(1) + pytorch_profiler.describe() + assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/connectors/__init__.py b/tests/trainer/connectors/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py new file mode 100644 index 00000000000000..34149e2231bf53 --- /dev/null +++ b/tests/trainer/connectors/test_callback_connector.py @@ -0,0 +1,141 @@ +import logging +from unittest.mock import Mock + +import torch + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ( + EarlyStopping, + GradientAccumulationScheduler, + LearningRateMonitor, + ModelCheckpoint, + ProgressBar, +) +from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector +from tests.helpers import BoringModel + + +def test_checkpoint_callbacks_are_last(tmpdir): + """ Test that checkpoint callbacks always get moved to the end of the list, with preserved order. """ + checkpoint1 = ModelCheckpoint(tmpdir) + checkpoint2 = ModelCheckpoint(tmpdir) + early_stopping = EarlyStopping() + lr_monitor = LearningRateMonitor() + progress_bar = ProgressBar() + + # no model callbacks + model = Mock() + model.configure_callbacks.return_value = [] + trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) + cb_connector = CallbackConnector(trainer) + cb_connector._attach_model_callbacks(model, trainer) + assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] + + # with model-specific callbacks that substitute ones in Trainer + model = Mock() + model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2] + trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) + cb_connector = CallbackConnector(trainer) + cb_connector._attach_model_callbacks(model, trainer) + assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2] + + +class StatefulCallback0(Callback): + + def on_save_checkpoint(self, *args): + return {"content0": 0} + + +class StatefulCallback1(Callback): + + def on_save_checkpoint(self, *args): + return {"content1": 1} + + +def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): + """ Test that all callback states get saved even if the ModelCheckpoint is not given as last. """ + + callback0 = StatefulCallback0() + callback1 = StatefulCallback1() + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states") + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + limit_val_batches=1, + callbacks=[callback0, checkpoint_callback, callback1] + ) + trainer.fit(model) + + ckpt = torch.load(str(tmpdir / "all_states.ckpt")) + state0 = ckpt["callbacks"][type(callback0)] + state1 = ckpt["callbacks"][type(callback1)] + assert "content0" in state0 and state0["content0"] == 0 + assert "content1" in state1 and state1["content1"] == 1 + assert type(checkpoint_callback) in ckpt["callbacks"] + + +def test_attach_model_callbacks(): + """ Test that the callbacks defined in the model and through Trainer get merged correctly. """ + + def assert_composition(trainer_callbacks, model_callbacks, expected): + model = Mock() + model.configure_callbacks.return_value = model_callbacks + trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) + cb_connector = CallbackConnector(trainer) + cb_connector._attach_model_callbacks(model, trainer) + assert trainer.callbacks == expected + + early_stopping = EarlyStopping() + progress_bar = ProgressBar() + lr_monitor = LearningRateMonitor() + grad_accumulation = GradientAccumulationScheduler({1: 1}) + + # no callbacks + assert_composition(trainer_callbacks=[], model_callbacks=[], expected=[]) + + # callbacks of different types + assert_composition( + trainer_callbacks=[early_stopping], model_callbacks=[progress_bar], expected=[early_stopping, progress_bar] + ) + + # same callback type twice, different instance + assert_composition( + trainer_callbacks=[progress_bar, EarlyStopping()], + model_callbacks=[early_stopping], + expected=[progress_bar, early_stopping] + ) + + # multiple callbacks of the same type in trainer + assert_composition( + trainer_callbacks=[LearningRateMonitor(), + EarlyStopping(), + LearningRateMonitor(), + EarlyStopping()], + model_callbacks=[early_stopping, lr_monitor], + expected=[early_stopping, lr_monitor] + ) + + # multiple callbacks of the same type, in both trainer and model + assert_composition( + trainer_callbacks=[ + LearningRateMonitor(), progress_bar, + EarlyStopping(), + LearningRateMonitor(), + EarlyStopping() + ], + model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping], + expected=[progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping] + ) + + +def test_attach_model_callbacks_override_info(caplog): + """ Test that the logs contain the info about overriding callbacks returned by configure_callbacks. """ + model = Mock() + model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()] + trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) + cb_connector = CallbackConnector(trainer) + with caplog.at_level(logging.INFO): + cb_connector._attach_model_callbacks(model, trainer) + + assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text diff --git a/tests/trainer/data_flow/__init__.py b/tests/trainer/data_flow/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/data_flow/test_eval_loop_flow_1_0.py b/tests/trainer/data_flow/test_eval_loop_flow_1_0.py new file mode 100644 index 00000000000000..a6de667bf8c196 --- /dev/null +++ b/tests/trainer/data_flow/test_eval_loop_flow_1_0.py @@ -0,0 +1,245 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import LightningModule +from tests.helpers.deterministic_model import DeterministicModel + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__eval_step__flow(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + if batch_idx == 0: + out = ['1', 2, torch.tensor(2)] + if batch_idx > 0: + out = {'something': 'random'} + return out + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.validation_step_called + assert not model.validation_step_end_called + assert not model.validation_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__eval_step__eval_step_end__flow(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + if batch_idx == 0: + out = ['1', 2, torch.tensor(2)] + if batch_idx > 0: + out = {'something': 'random'} + self.last_out = out + return out + + def validation_step_end(self, out): + self.validation_step_end_called = True + assert self.last_out == out + return out + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.validation_step_called + assert model.validation_step_end_called + assert not model.validation_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__eval_step__epoch_end__flow(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + if batch_idx == 0: + out = ['1', 2, torch.tensor(2)] + self.out_a = out + if batch_idx > 0: + out = {'something': 'random'} + self.out_b = out + return out + + def validation_epoch_end(self, outputs): + self.validation_epoch_end_called = True + assert len(outputs) == 2 + + out_a = outputs[0] + out_b = outputs[1] + + assert out_a == self.out_a + assert out_b == self.out_b + + return {'no returns needed'} + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.validation_step_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + + with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"): + trainer.fit(model) + + # make sure correct steps were called + assert model.validation_step_called + assert not model.validation_step_end_called + assert model.validation_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__validation_step__step_end__epoch_end__flow(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + if batch_idx == 0: + out = ['1', 2, torch.tensor(2)] + self.out_a = out + if batch_idx > 0: + out = {'something': 'random'} + self.out_b = out + self.last_out = out + return out + + def validation_step_end(self, out): + self.validation_step_end_called = True + assert self.last_out == out + return out + + def validation_epoch_end(self, outputs): + self.validation_epoch_end_called = True + assert len(outputs) == 2 + + out_a = outputs[0] + out_b = outputs[1] + + assert out_a == self.out_a + assert out_b == self.out_b + + return {'no returns needed'} + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + + with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"): + trainer.fit(model) + + # make sure correct steps were called + assert model.validation_step_called + assert model.validation_step_end_called + assert model.validation_epoch_end_called diff --git a/tests/trainer/data_flow/test_flow_warnings.py b/tests/trainer/data_flow/test_flow_warnings.py new file mode 100644 index 00000000000000..d3280b8eb6a86f --- /dev/null +++ b/tests/trainer/data_flow/test_flow_warnings.py @@ -0,0 +1,51 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import warnings +from unittest import mock + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + return acc + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_no_depre_without_epoch_end(tmpdir): + """ + Tests that only training_step can be used + """ + + model = TestModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + + with warnings.catch_warnings(record=True) as w: + trainer.fit(model) + + for msg in w: + assert 'should not return anything ' not in str(msg) diff --git a/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py new file mode 100644 index 00000000000000..f38dda9c530caf --- /dev/null +++ b/tests/trainer/data_flow/test_train_loop_flow_dict_1_0.py @@ -0,0 +1,207 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" +import os +from unittest import mock + +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import LightningModule +from tests.helpers.deterministic_model import DeterministicModel + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__flow_dict(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + return {'loss': acc, 'random_things': [1, 'a', torch.tensor(2)]} + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__tr_step_end__flow_dict(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + self.out = {'loss': acc, 'random_things': [1, 'a', torch.tensor(2)]} + return self.out + + def training_step_end(self, tr_step_output): + assert tr_step_output == self.out + assert self.count_num_graphs(tr_step_output) == 1 + self.training_step_end_called = True + return tr_step_output + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__epoch_end__flow_dict(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + + self.training_step_called = True + out = {'loss': acc, 'random_things': [1, 'a', torch.tensor(2)]} + return out + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + + # verify we saw the current num of batches + assert len(outputs) == 2 + + for b in outputs: + assert isinstance(b, dict) + assert self.count_num_graphs(b) == 0 + assert {'random_things', 'loss'} == set(b.keys()) + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__step_end__epoch_end__flow_dict(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + + self.training_step_called = True + self.out = {'loss': acc, 'random_things': [1, 'a', torch.tensor(2)]} + return self.out + + def training_step_end(self, tr_step_output): + assert tr_step_output == self.out + assert self.count_num_graphs(tr_step_output) == 1 + self.training_step_end_called = True + return tr_step_output + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + + # verify we saw the current num of batches + assert len(outputs) == 2 + + for b in outputs: + assert isinstance(b, dict) + assert self.count_num_graphs(b) == 0 + assert {'random_things', 'loss'} == set(b.keys()) + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert model.training_epoch_end_called diff --git a/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py new file mode 100644 index 00000000000000..d5a4da79942ed6 --- /dev/null +++ b/tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py @@ -0,0 +1,289 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import LightningModule +from tests.helpers.boring_model import BoringModel +from tests.helpers.deterministic_model import DeterministicModel +from tests.helpers.utils import no_warning_call + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__flow_scalar(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + return acc + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__tr_step_end__flow_scalar(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.training_step_called = True + self.out = acc + return acc + + def training_step_end(self, tr_step_output): + assert self.out == tr_step_output + assert self.count_num_graphs({'loss': tr_step_output}) == 1 + self.training_step_end_called = True + return tr_step_output + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__epoch_end__flow_scalar(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + + self.training_step_called = True + return acc + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + + # verify we saw the current num of batches + assert len(outputs) == 2 + + for b in outputs: + # time = 1 + assert len(b) == 1 + assert 'loss' in b + assert isinstance(b, dict) + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__step_end__epoch_end__flow_scalar(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + + self.training_step_called = True + return acc + + def training_step_end(self, tr_step_output): + assert isinstance(tr_step_output, torch.Tensor) + assert self.count_num_graphs({'loss': tr_step_output}) == 1 + self.training_step_end_called = True + return tr_step_output + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + + # verify we saw the current num of batches + assert len(outputs) == 2 + + for b in outputs: + # time = 1 + assert len(b) == 1 + assert 'loss' in b + assert isinstance(b, dict) + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert model.training_epoch_end_called + + +def test_train_step_no_return(tmpdir): + """ + Tests that only training_step raises a warning when + nothing is returned in case of automatic_optimization + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.training_step_called = True + loss = self.step(batch[0]) + self.log('a', loss, on_step=True, on_epoch=True) + + def training_epoch_end(self, outputs) -> None: + assert len(outputs) == 0 + + def validation_step(self, batch, batch_idx): + self.validation_step_called = True + + def validation_epoch_end(self, outputs): + assert len(outputs) == 0 + + model = TestModel() + trainer_args = dict( + default_root_dir=tmpdir, + fast_dev_run=2, + ) + + trainer = Trainer(**trainer_args) + + with pytest.warns(UserWarning, match=r'training_step returned None .*'): + trainer.fit(model) + + assert model.training_step_called + assert model.validation_step_called + + model = TestModel() + model.automatic_optimization = False + trainer = Trainer(**trainer_args) + + with no_warning_call(UserWarning, match=r'training_step returned None .*'): + trainer.fit(model) + + +def test_training_step_no_return_when_even(tmpdir): + """ + Tests correctness when some training steps have been skipped + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.training_step_called = True + loss = self.step(batch[0]) + self.log('a', loss, on_step=True, on_epoch=True) + return loss if batch_idx % 2 else None + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=4, + limit_val_batches=1, + max_epochs=4, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + + with pytest.warns(UserWarning, match=r'.*training_step returned None.*'): + trainer.fit(model) + + # manually check a few batches + for batch_idx, batch in enumerate(model.train_dataloader()): + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + if not batch_idx % 2: + assert out.training_step_output_for_epoch_end == [[]] + assert out.signal == 0 diff --git a/tests/trainer/dynamic_args/__init__.py b/tests/trainer/dynamic_args/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py new file mode 100644 index 00000000000000..9a532cfe1ce472 --- /dev/null +++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -0,0 +1,174 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch.utils.data import Dataset + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +class RandomDatasetA(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return torch.zeros(1) + + def __len__(self): + return self.len + + +class RandomDatasetB(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return torch.ones(1) + + def __len__(self): + return self.len + + +def test_multiple_eval_dataloaders_tuple(tmpdir): + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx, dataloader_idx): + if dataloader_idx == 0: + assert batch.sum() == 0 + elif dataloader_idx == 1: + assert batch.sum() == 11 + else: + raise Exception('should only have two dataloaders') + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def val_dataloader(self): + dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11) + dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11) + return [dl1, dl2] + + model = TestModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_multiple_eval_dataloaders_list(tmpdir): + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx, dataloader_idx): + if dataloader_idx == 0: + assert batch.sum() == 0 + elif dataloader_idx == 1: + assert batch.sum() == 11 + else: + raise Exception('should only have two dataloaders') + + def val_dataloader(self): + dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11) + dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11) + return dl1, dl2 + + model = TestModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_multiple_optimizers_multiple_dataloaders(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def on_train_epoch_start(self) -> None: + self.opt_0_seen = False + self.opt_1_seen = False + + def training_step(self, batch, batch_idx, optimizer_idx): + if optimizer_idx == 0: + self.opt_0_seen = True + elif optimizer_idx == 1: + self.opt_1_seen = True + else: + raise Exception('should only have two optimizers') + + self.training_step_called = True + loss = self.step(batch[0]) + return loss + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def validation_step(self, batch, batch_idx, dataloader_idx): + if dataloader_idx == 0: + assert batch.sum() == 0 + elif dataloader_idx == 1: + assert batch.sum() == 11 + else: + raise Exception('should only have two dataloaders') + + def val_dataloader(self): + dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11) + dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11) + return dl1, dl2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + assert model.opt_0_seen + assert model.opt_1_seen diff --git a/tests/trainer/flags/__init__.py b/tests/trainer/flags/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py new file mode 100644 index 00000000000000..65b251a6633b5d --- /dev/null +++ b/tests/trainer/flags/test_env_vars.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock + +from pytorch_lightning import Trainer + + +def test_passing_no_env_variables(): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.logger is not None + assert trainer.max_steps is None + trainer = Trainer(False, max_steps=42) + assert trainer.logger is None + assert trainer.max_steps == 42 + + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_only(): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.logger is None + assert trainer.max_steps == 7 + + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_defaults(): + """Testing overwriting trainer arguments """ + trainer = Trainer(False, max_steps=42) + assert trainer.logger is None + assert trainer.max_steps == 42 + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.gpus == 2 + trainer = Trainer(gpus=1) + assert trainer.gpus == 1 diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py new file mode 100644 index 00000000000000..9160d8d0f3d617 --- /dev/null +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -0,0 +1,129 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers.base import DummyLogger +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel + + +@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder']) +def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): + """ Test that tuner algorithms are skipped if fast dev run is enabled """ + + model = BoringModel() + model.lr = 0.1 # avoid no-lr-found exception + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + auto_scale_batch_size=(tuner_alg == 'batch size scaler'), + auto_lr_find=(tuner_alg == 'learning rate finder'), + fast_dev_run=True + ) + expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' + with pytest.warns(UserWarning, match=expected_message): + trainer.tune(model) + + +@pytest.mark.parametrize('fast_dev_run', [1, 4]) +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): + """ + Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run + """ + + class FastDevRunModel(BoringModel): + + def __init__(self): + super().__init__() + self.training_step_call_count = 0 + self.training_epoch_end_call_count = 0 + self.validation_step_call_count = 0 + self.validation_epoch_end_call_count = 0 + self.test_step_call_count = 0 + + def training_step(self, batch, batch_idx): + self.log('some_metric', torch.tensor(7.)) + self.logger.experiment.dummy_log('some_distribution', torch.randn(7) + batch_idx) + self.training_step_call_count += 1 + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs): + self.training_epoch_end_call_count += 1 + super().training_epoch_end(outputs) + + def validation_step(self, batch, batch_idx): + self.validation_step_call_count += 1 + return super().validation_step(batch, batch_idx) + + def validation_epoch_end(self, outputs): + self.validation_epoch_end_call_count += 1 + super().validation_epoch_end(outputs) + + def test_step(self, batch, batch_idx): + self.test_step_call_count += 1 + return super().test_step(batch, batch_idx) + + checkpoint_callback = ModelCheckpoint() + early_stopping_callback = EarlyStopping() + trainer_config = dict( + default_root_dir=tmpdir, + fast_dev_run=fast_dev_run, + val_check_interval=2, + logger=True, + log_every_n_steps=1, + callbacks=[checkpoint_callback, early_stopping_callback], + ) + + def _make_fast_dev_run_assertions(trainer, model): + # check the call count for train/val/test step/epoch + assert model.training_step_call_count == fast_dev_run + assert model.training_epoch_end_call_count == 1 + assert model.validation_step_call_count == 0 if model.validation_step is None else fast_dev_run + assert model.validation_epoch_end_call_count == 0 if model.validation_step is None else 1 + assert model.test_step_call_count == fast_dev_run + + # check trainer arguments + assert trainer.max_steps == fast_dev_run + assert trainer.num_sanity_val_steps == 0 + assert trainer.max_epochs == 1 + assert trainer.val_check_interval == 1.0 + assert trainer.check_val_every_n_epoch == 1 + + # there should be no logger with fast_dev_run + assert isinstance(trainer.logger, DummyLogger) + assert len(trainer.dev_debugger.logged_metrics) == fast_dev_run + + # checkpoint callback should not have been called with fast_dev_run + assert trainer.checkpoint_callback == checkpoint_callback + assert not os.path.exists(checkpoint_callback.dirpath) + assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + + # early stopping should not have been called with fast_dev_run + assert trainer.early_stopping_callback == early_stopping_callback + assert len(trainer.dev_debugger.early_stopping_history) == 0 + + train_val_step_model = FastDevRunModel() + trainer = Trainer(**trainer_config) + trainer.fit(train_val_step_model) + trainer.test(ckpt_path=None) + + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + _make_fast_dev_run_assertions(trainer, train_val_step_model) + + # ----------------------- + # also called once with no val step + # ----------------------- + train_step_only_model = FastDevRunModel() + train_step_only_model.validation_step = None + + trainer = Trainer(**trainer_config) + trainer.fit(train_step_only_model) + trainer.test(ckpt_path=None) + + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + _make_fast_dev_run_assertions(trainer, train_step_only_model) diff --git a/tests/trainer/flags/test_min_max_epochs.py b/tests/trainer/flags/test_min_max_epochs.py new file mode 100644 index 00000000000000..31b9a19960a145 --- /dev/null +++ b/tests/trainer/flags/test_min_max_epochs.py @@ -0,0 +1,39 @@ +import pytest + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +@pytest.mark.parametrize( + ["min_epochs", "max_epochs", "min_steps", "max_steps"], + [ + (None, 3, None, None), + (None, None, None, 20), + (None, 3, None, 20), + (None, None, 10, 20), + (1, 3, None, None), + (1, None, None, 20), + (None, 3, 10, None), + ], +) +def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps): + """ + Tests that max_steps can be used without max_epochs + """ + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + min_epochs=min_epochs, + max_epochs=max_epochs, + min_steps=min_steps, + max_steps=max_steps, + weights_summary=None, + ) + + result = trainer.fit(model) + assert result == 1, "Training did not complete" + + # check training stopped at max_epochs or max_steps + if trainer.max_steps and not trainer.max_epochs: + assert trainer.global_step == trainer.max_steps diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py new file mode 100644 index 00000000000000..ba11ccba7fc12c --- /dev/null +++ b/tests/trainer/flags/test_overfit_batches.py @@ -0,0 +1,69 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel, RandomDataset + + +def test_overfit_multiple_val_loaders(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx, dataloader_idx): + output = self.layer(batch[0]) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + pass + + def val_dataloader(self): + dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64)) + dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64)) + return [dl1, dl2] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + overfit_batches=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + +@pytest.mark.parametrize('overfit', [1, 2, 0.1, 0.25, 1.0]) +def test_overfit_basic(tmpdir, overfit): + """ + Tests that only training_step can be used + """ + + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + overfit_batches=overfit, + weights_summary=None, + ) + + trainer.fit(model) diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py new file mode 100644 index 00000000000000..7f3e9f6287cd80 --- /dev/null +++ b/tests/trainer/flags/test_val_check_interval.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning.trainer import Trainer +from tests.helpers import BoringModel + + +@pytest.mark.parametrize('max_epochs', [1, 2, 3]) +@pytest.mark.parametrize('denominator', [1, 3, 4]) +def test_val_check_interval(tmpdir, max_epochs, denominator): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.train_epoch_calls = 0 + self.val_epoch_calls = 0 + + def on_train_epoch_start(self) -> None: + self.train_epoch_calls += 1 + + def on_validation_epoch_start(self) -> None: + if not self.trainer.sanity_checking: + self.val_epoch_calls += 1 + + model = TestModel() + trainer = Trainer( + max_epochs=max_epochs, + val_check_interval=1 / denominator, + logger=False, + ) + trainer.fit(model) + + assert model.train_epoch_calls == max_epochs + assert model.val_epoch_calls == max_epochs * denominator diff --git a/tests/trainer/legacy_deprecate_flow_log/__init__.py b/tests/trainer/legacy_deprecate_flow_log/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py new file mode 100644 index 00000000000000..0894acd5fe72d0 --- /dev/null +++ b/tests/trainer/legacy_deprecate_flow_log/test_eval_loop_dict_return.py @@ -0,0 +1,157 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict +""" +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import LightningModule +from tests.helpers.deterministic_model import DeterministicModel + + +def test_validation_step_no_return(tmpdir): + """ + Test that val step can return nothing + """ + + class TestModel(DeterministicModel): + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.training_step = model.training_step__dict_return + model.validation_step = model.validation_step__no_return + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + weights_summary=None, + ) + trainer.fit(model) + + # out are the results of the full loop + # eval_results are output of _evaluate + with pytest.warns(RuntimeWarning, match="the running stage is set to None"): + out, eval_results = trainer.run_evaluation() + assert len(out) == 1 + assert len(eval_results) == 0 + + # make sure correct steps were called + assert model.validation_step_called + assert not model.validation_step_end_called + assert not model.validation_epoch_end_called + + +def test_validation_step_scalar_return(tmpdir): + """ + Test that val step can return a scalar + """ + model = DeterministicModel() + model.training_step = model.training_step__dict_return + model.validation_step = model.validation_step__scalar_return + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + weights_summary=None, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + ) + trainer.fit(model) + + # out are the results of the full loop + # eval_results are output of _evaluate + out, eval_results = trainer.run_evaluation() + assert len(out) == 1 + assert len(eval_results) == 2 + assert eval_results[0] == 171 and eval_results[1] == 171 + + # make sure correct steps were called + assert model.validation_step_called + assert not model.validation_step_end_called + assert not model.validation_epoch_end_called + + +def test_validation_step_arbitrary_dict_return(tmpdir): + """ + Test that val step can return an arbitrary dict + """ + model = DeterministicModel() + model.training_step = model.training_step__dict_return + model.validation_step = model.validation_step__dummy_dict_return + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + weights_summary=None, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + ) + trainer.fit(model) + + # out are the results of the full loop + # eval_results are output of _evaluate + callback_metrics, eval_results = trainer.run_evaluation() + assert len(callback_metrics) == 1 + assert len(eval_results) == 2 + assert eval_results[0]['some'] == 171 + assert eval_results[1]['some'] == 171 + + assert eval_results[0]['value'] == 'a' + assert eval_results[1]['value'] == 'a' + + # make sure correct steps were called + assert model.validation_step_called + assert not model.validation_step_end_called + assert not model.validation_epoch_end_called + + +def test_val_step_step_end_no_return(tmpdir): + """ + Test that val step + val step end work (with no return in val step end) + """ + + model = DeterministicModel() + model.training_step = model.training_step__dict_return + model.validation_step = model.validation_step__dict_return + model.validation_step_end = model.validation_step_end__no_return + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + weights_summary=None, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + ) + trainer.fit(model) + + # out are the results of the full loop + # eval_results are output of _evaluate + callback_metrics, eval_results = trainer.run_evaluation() + assert len(callback_metrics) == 1 + assert len(eval_results) == 0 + + # make sure correct steps were called + assert model.validation_step_called + assert model.validation_step_end_called + assert not model.validation_epoch_end_called diff --git a/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py b/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py new file mode 100644 index 00000000000000..b88788f3cc28b0 --- /dev/null +++ b/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_dict_return.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict +""" + +from pytorch_lightning import Trainer +from tests.helpers.deterministic_model import DeterministicModel + + +def test_training_step_dict(tmpdir): + """ + Tests that only training_step can be used + """ + model = DeterministicModel() + model.training_step = model.training_step__dict_return + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + + assert out.signal == 0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 12.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 7.0 + + train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + + train_step_out = train_step_out[0][0] + pbar_metrics = train_step_out['progress_bar'] + assert 'log' in train_step_out + assert 'progress_bar' in train_step_out + assert train_step_out['train_step_test'] == 549 + assert pbar_metrics['pbar_acc1'] == 17.0 + assert pbar_metrics['pbar_acc2'] == 19.0 + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.train_loop.training_step_and_backward( + batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens + ) + assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3) + + +def training_step_with_step_end(tmpdir): + """ + Checks train_step + training_step_end + """ + model = DeterministicModel() + model.training_step = model.training_step__dict_return + model.training_step_end = model.training_step_end__dict + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 14.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 9.0 + + train_step_end_out = out.training_step_output_for_epoch_end + pbar_metrics = train_step_end_out['progress_bar'] + assert 'train_step_end' in train_step_end_out + assert pbar_metrics['pbar_acc1'] == 19.0 + assert pbar_metrics['pbar_acc2'] == 21.0 diff --git a/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_scalar_return.py new file mode 100644 index 00000000000000..5836251f2c92a3 --- /dev/null +++ b/tests/trainer/legacy_deprecate_flow_log/test_trainer_steps_scalar_return.py @@ -0,0 +1,249 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a scalar +""" +import os +from unittest import mock + +import torch + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel +from tests.helpers.deterministic_model import DeterministicModel +from tests.helpers.runif import RunIf + + +def test_training_step_scalar(tmpdir): + """ + Tests that only training_step that returns a single scalar can be used + """ + model = DeterministicModel() + model.training_step = model.training_step__scalar_return + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 + assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + + train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] + assert isinstance(train_step_out['minimize'], torch.Tensor) + assert train_step_out['minimize'].item() == 171 + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.train_loop.training_step_and_backward( + batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens + ) + assert opt_closure_result['loss'].item() == 171 + + +def training_step_scalar_with_step_end(tmpdir): + """ + Checks train_step with scalar only + training_step_end + """ + model = DeterministicModel() + model.training_step = model.training_step__scalar_return + model.training_step_end = model.training_step_end__scalar + model.val_dataloader = None + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, weights_summary=None) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 + assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + + train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] + assert isinstance(train_step_out, torch.Tensor) + assert train_step_out.item() == 171 + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.train_loop.training_step_and_backward( + batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens + ) + assert opt_closure_result['loss'].item() == 171 + + +def test_full_training_loop_scalar(tmpdir): + """ + Checks train_step + training_step_end + training_epoch_end + (all with scalar return from train_step) + """ + + model = DeterministicModel() + model.training_step = model.training_step__scalar_return + model.training_step_end = model.training_step_end__scalar + model.training_epoch_end = model.training_epoch_end__scalar + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert model.training_epoch_end_called + + # assert epoch end metrics were added + assert len(trainer.logger_connector.callback_metrics) == 0 + assert len(trainer.logger_connector.progress_bar_metrics) == 0 + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 + assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + + train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] + assert isinstance(train_step_out['minimize'], torch.Tensor) + assert train_step_out['minimize'].item() == 171 + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.train_loop.training_step_and_backward( + batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens + ) + assert opt_closure_result['loss'].item() == 171 + + +def test_train_step_epoch_end_scalar(tmpdir): + """ + Checks train_step + training_epoch_end (NO training_step_end) + (with scalar return) + """ + + model = DeterministicModel() + model.training_step = model.training_step__scalar_return + model.training_step_end = None + model.training_epoch_end = model.training_epoch_end__scalar + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + # assert epoch end metrics were added + assert len(trainer.logger_connector.callback_metrics) == 0 + assert len(trainer.logger_connector.progress_bar_metrics) == 0 + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 + assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) + + train_step_out = out.training_step_output_for_epoch_end + assert len(train_step_out) == 1 + train_step_out = train_step_out[0][0] + assert isinstance(train_step_out['minimize'], torch.Tensor) + assert train_step_out['minimize'].item() == 171 + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.train_loop.training_step_and_backward( + batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens + ) + assert opt_closure_result['loss'].item() == 171 + + +class DPPReduceMeanPbarModel(BoringModel): + + logged = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + loss /= loss.clone().detach() + self.log('self_log', loss, prog_bar=True, sync_dist=True) + return {"loss": loss, "progress_bar": {"loss_2": loss}} + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=2) +def test_dpp_reduce_mean_pbar(tmpdir): + + model = DPPReduceMeanPbarModel() + model.training_step_end = None + model.training_epoch_end = None + + distributed_backend = "ddp_spawn" + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=10, + limit_test_batches=2, + limit_val_batches=2, + accelerator=distributed_backend, + gpus=2, + precision=32, + ) + + trainer.fit(model) + + # TODO: Move this test to DDP. pbar_added_metrics is empty with ddp_spawn for some reasons + + pbar_added_metrics = trainer.dev_debugger.pbar_added_metrics + is_in = False + for pbar_metrics in pbar_added_metrics: + if 'loss_2' in pbar_metrics: + is_in = True + assert pbar_metrics["loss_2"].item() == 1 + if distributed_backend == "ddp": + assert is_in is True diff --git a/tests/trainer/logging_/__init__.py b/tests/trainer/logging_/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py new file mode 100644 index 00000000000000..5832f387cc63d6 --- /dev/null +++ b/tests/trainer/logging_/test_distributed_logging.py @@ -0,0 +1,105 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock +from unittest.mock import Mock + +from pytorch_lightning import Callback, Trainer +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class TestModel(BoringModel): + + def on_pretrain_routine_end(self) -> None: + with mock.patch('pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics') as m: + self.trainer.logger_connector.log_metrics({'a': 2}, {}) + logged_times = m.call_count + expected = int(self.trainer.is_global_zero) + msg = f'actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}' + assert logged_times == expected, msg + + +@RunIf(skip_windows=True) +def test_global_zero_only_logging_ddp_cpu(tmpdir): + """ + Makes sure logging only happens from root zero + """ + model = TestModel() + model.training_epoch_end = None + trainer = Trainer( + accelerator='ddp_cpu', + num_processes=2, + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + +@RunIf(min_gpus=2) +def test_global_zero_only_logging_ddp_spawn(tmpdir): + """ + Makes sure logging only happens from root zero + """ + model = TestModel() + model.training_epoch_end = None + trainer = Trainer( + accelerator='ddp_spawn', + gpus=2, + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + +def test_first_logger_call_in_subprocess(tmpdir): + """ + Test that the Trainer does not call the logger too early. Only when the worker processes are initialized + do we have access to the rank and know which one is the main process. + """ + + class LoggerCallsObserver(Callback): + + def on_fit_start(self, trainer, pl_module): + # this hook is executed directly before Trainer.pre_dispatch + # logger should not write any logs until this point + assert not trainer.logger.method_calls + assert not os.listdir(trainer.logger.save_dir) + + def on_train_start(self, trainer, pl_module): + assert trainer.logger.method_call + trainer.logger.log_hyperparams.assert_called_once() + trainer.logger.log_graph.assert_called_once() + + logger = Mock() + logger.version = "0" + logger.name = "name" + logger.save_dir = tmpdir + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + logger=logger, + callbacks=[LoggerCallsObserver()] + ) + trainer.fit(model) diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py new file mode 100644 index 00000000000000..32bff96baf5667 --- /dev/null +++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py @@ -0,0 +1,927 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" +import collections +import itertools +import os +from unittest import mock +from unittest.mock import call + +import numpy as np +import pytest +import torch + +from pytorch_lightning import callbacks, seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers import TensorBoardLogger +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.deterministic_model import DeterministicModel + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__validation_step__log(tmpdir): + """ + Tests that validation_step can log + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.log('a', acc, on_step=True, on_epoch=True) + self.log('a2', 2) + + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.log('b', acc, on_step=True, on_epoch=True) + self.training_step_called = True + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.validation_step_end = None + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + expected_logged_metrics = { + 'a2', + 'a_step', + 'a_epoch', + 'b_step/epoch_0', + 'b_step/epoch_1', + 'b_epoch', + 'epoch', + } + logged_metrics = set(trainer.logged_metrics.keys()) + assert expected_logged_metrics == logged_metrics + + # we don't want to enable val metrics during steps because it is not something that users should do + # on purpose DO NOT allow step_b... it's silly to monitor val step metrics + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'} + assert expected_cb_metrics == callback_metrics + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__validation_step__step_end__epoch_end__log(tmpdir): + """ + Tests that validation_step can log + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.log('a', acc) + self.log('b', acc, on_step=True, on_epoch=True) + self.training_step_called = True + return acc + + def validation_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.log('c', acc) + self.log('d', acc, on_step=True, on_epoch=True) + self.validation_step_called = True + return acc + + def validation_step_end(self, acc): + self.validation_step_end_called = True + # self.log('e', acc) + # self.log('f', acc, on_step=True, on_epoch=True) + return ['random_thing'] + + def validation_epoch_end(self, outputs): + self.log('g', torch.tensor(2, device=self.device), on_epoch=True) + self.validation_epoch_end_called = True + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'epoch', + 'a', + 'b_step', + 'b_epoch', + 'c', + 'd_step/epoch_0', + 'd_step/epoch_1', + 'd_epoch', + # 'e', + # 'f_step/epoch_0', + # 'f_step/epoch_1', + # 'f_epoch', + 'g', + } + assert expected_logged_metrics == logged_metrics + + progress_bar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = set() + assert expected_pbar_metrics == progress_bar_metrics + + # we don't want to enable val metrics during steps because it is not something that users should do + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_cb_metrics = {'a', 'b', 'b_epoch', 'c', 'd', 'd_epoch', 'g', 'b_step'} + # expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'b_epoch', 'd_epoch', 'f_epoch', 'f', 'g', 'b_step'} + assert expected_cb_metrics == callback_metrics + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) +def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def validation_epoch_end(self, outputs): + self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True) + self.log('d/e/f', 2) + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + max_epochs=max_epochs, + log_every_n_steps=log_interval, + weights_summary=None, + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'c', + 'd/e/f', + 'epoch', + } + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = {'c'} + assert pbar_metrics == expected_pbar_metrics + + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.remove('epoch') + assert callback_metrics == expected_callback_metrics + + # assert the loggers received the expected number + assert len(trainer.dev_debugger.logged_metrics) == max_epochs + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_eval_float_logging(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('a', 12.0) + return {"x": loss} + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'a', + 'epoch', + } + assert logged_metrics == expected_logged_metrics + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_eval_logging_auto_reduce(tmpdir): + """ + Tests that only training_step can be used + """ + seed_everything(1234) + + class TestModel(BoringModel): + + def on_pretrain_routine_end(self) -> None: + self.seen_vals = [] + self.manual_epoch_end_mean = None + + def on_validation_epoch_start(self) -> None: + self.seen_vals = [] + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.seen_vals.append(loss) + self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + for passed_in, manually_tracked in zip(outputs, self.seen_vals): + assert passed_in['x'] == manually_tracked + self.manual_epoch_end_mean = torch.stack([x['x'] for x in outputs]).mean() + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=3, + limit_val_batches=3, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + ) + trainer.fit(model) + + # make sure all the metrics are available for callbacks + manual_mean = model.manual_epoch_end_mean + callback_metrics = set(trainer.callback_metrics.keys()) + assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'} + + # make sure values are correct + assert trainer.logged_metrics['val_loss_epoch'] == manual_mean + assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step/epoch_0'] + + # make sure correct values were logged + logged_val = trainer.dev_debugger.logged_metrics + + # 3 val batches + assert logged_val[0]['val_loss_step/epoch_0'] == model.seen_vals[0] + assert logged_val[1]['val_loss_step/epoch_0'] == model.seen_vals[1] + assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[2] + + # epoch mean + assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean + + # only those logged + assert len(logged_val) == 4 + + +@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) +def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): + """ + Tests that only test_epoch_end can be used to log, and we return them in the results. + """ + + class TestModel(BoringModel): + + def test_epoch_end(self, outputs): + self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True) + self.log('d/e/f', 2) + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + max_epochs=max_epochs, + log_every_n_steps=log_interval, + weights_summary=None, + ) + trainer.fit(model) + results = trainer.test(model) + + expected_result_metrics = { + 'c': torch.tensor(2), + 'd/e/f': 2, + } + for result in results: + assert result == expected_result_metrics + + +def test_monitor_val_epoch_end(tmpdir): + epoch_min_loss_override = 0 + model = BoringModel() + checkpoint_callback = callbacks.ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="avg_val_loss") + trainer = Trainer( + max_epochs=epoch_min_loss_override + 2, + logger=False, + callbacks=[checkpoint_callback], + ) + trainer.fit(model) + + +def test_multi_dataloaders_add_suffix_properly(tmpdir): + + class TestModel(BoringModel): + + def test_step(self, batch, *args): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("test_loss", loss, on_step=True, on_epoch=True) + + def test_dataloader(self): + return [ + torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64)) + ] + + model = TestModel() + model.test_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + results = trainer.test(model) + + assert {"test_loss/dataloader_idx_0", "test_loss_epoch/dataloader_idx_0"} == set(results[0]) + assert {"test_loss/dataloader_idx_1", "test_loss_epoch/dataloader_idx_1"} == set(results[1]) + + +def test_single_dataloader_no_suffix_added(tmpdir): + + class TestModel(BoringModel): + + def test_step(self, batch, *args): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("test_loss", loss, on_step=True, on_epoch=True) + + model = TestModel() + model.test_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=5, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + results = trainer.test(model) + + assert len(results) == 1 + assert {"test_loss", "test_loss_epoch"} == set(results[0]) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_log_works_in_val_callback(tmpdir): + """ + Tests that log can be called within callback + """ + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + self.funcs_called_count[func_name] += 1 + product = [on_steps, on_epochs, prob_bars] + for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*product))): + # run logging + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log( + custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar + ) + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": on_step and on_epoch, + "func_name": func_name + } + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name + } + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name + } + + def on_validation_start(self, trainer, pl_module): + self.make_logging( + pl_module, + 'on_validation_start', + 1, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + + def on_epoch_start(self, trainer, pl_module): + if trainer.validating: + self.make_logging( + pl_module, + 'on_epoch_start', + 2, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + + def on_validation_epoch_start(self, trainer, pl_module): + self.make_logging( + pl_module, + 'on_validation_epoch_start', + 3, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + + def on_batch_end(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging( + pl_module, + 'on_validation_batch_end', + 7, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_validation_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + if trainer.validating: + self.make_logging( + pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + ) + + def on_validation_epoch_end(self, trainer, pl_module): + self.make_logging( + pl_module, + 'on_validation_epoch_end', + 9, + on_steps=[False], + on_epochs=self.choices, + prob_bars=self.choices + ) + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('val_loss', loss) + + max_epochs = 1 + model = TestModel() + model.validation_epoch_end = None + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=4, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback], + ) + trainer.fit(model) + + assert test_callback.funcs_called_count["on_epoch_start"] == 1 + # assert test_callback.funcs_called_count["on_batch_start"] == 1 + assert test_callback.funcs_called_count["on_batch_end"] == 1 + assert test_callback.funcs_called_count["on_validation_start"] == 1 + assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 + # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 + assert test_callback.funcs_called_count["on_epoch_end"] == 1 + assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 + assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + trainer.callback_metrics.pop("val_loss") + for func_name, output_value in trainer.callback_metrics.items(): + # not sure how to handle this now + if "epoch_0" in func_name: + func_name = '/'.join(func_name.split('/')[:-1]) + continue + + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_log_works_in_test_callback(tmpdir): + """ + Tests that log can be called within callback + """ + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): + original_func_name = func_name[:] + self.funcs_called_count[original_func_name] += 1 + product = [on_steps, on_epochs, prob_bars] + for idx, t in enumerate(list(itertools.product(*product))): + # run logging + func_name = original_func_name[:] + on_step, on_epoch, prog_bar = t + custom_func_name = f"{func_idx}_{idx}_{func_name}" + + pl_module.log( + custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar + ) + + num_dl_ext = '' + if pl_module._current_dataloader_idx is not None: + dl_idx = pl_module._current_dataloader_idx + num_dl_ext = f"/dataloader_idx_{dl_idx}" + func_name += num_dl_ext + + # catch information for verification + self.callback_funcs_called[func_name].append([self.count * func_idx]) + self.funcs_attr[custom_func_name + num_dl_ext] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": on_step and on_epoch, + "func_name": func_name + } + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step" + num_dl_ext] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name + } + + self.funcs_attr[f"{custom_func_name}_epoch" + num_dl_ext] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name + } + + def on_test_start(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + + def on_test_epoch_start(self, trainer, pl_module): + self.make_logging( + pl_module, + 'on_test_epoch_start', + 3, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging( + pl_module, + 'on_test_batch_end', + 5, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_test_epoch_end(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + ) + + max_epochs = 2 + num_dataloaders = 2 + + class TestModel(BoringModel): + + manual_mean = collections.defaultdict(list) + + def test_step(self, batch, batch_idx, dataloader_idx=None): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('test_loss', loss) + self.manual_mean[str(dataloader_idx)].append(loss) + + def test_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] + + model = TestModel() + model.test_epoch_end = None + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=0, + limit_test_batches=2, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback], + ) + trainer.test(model) + + assert test_callback.funcs_called_count["on_test_start"] == 1 + assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 + assert test_callback.funcs_called_count["on_test_batch_end"] == 4 + assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + # Apply mean on values + if func_attr["on_epoch"] and not func_attr["on_step"]: + expected_output = np.mean(original_values) + else: + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + assert "debug_epoch" in trainer.callback_metrics + trainer.callback_metrics.pop("debug_epoch") + + for dl_idx in range(num_dataloaders): + key = f"test_loss/dataloader_idx_{dl_idx}" + assert key in trainer.callback_metrics + assert torch.stack(model.manual_mean[str(dl_idx)]).mean() == trainer.callback_metrics[key] + trainer.callback_metrics.pop(key) + + for func_name, output_value in trainer.callback_metrics.items(): + # not sure how to handle this now + if "epoch_1" in func_name: + func_name = '/'.join(func_name.split('/')[:-1]) + continue + + if torch.is_tensor(output_value): + output_value = output_value.item() + + # get func attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics + + +@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir): + """ + This tests make sure we properly log_metrics to loggers + """ + + class ExtendedModel(BoringModel): + + val_losses = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('train_loss', loss) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.val_losses.append(loss) + self.log('valid_loss_0', loss, on_step=True, on_epoch=True) + self.log('valid_loss_1', loss, on_step=False, on_epoch=True) + self.log('valid_loss_2', loss, on_step=True, on_epoch=False) + self.log('valid_loss_3', loss, on_step=False, on_epoch=False) + return {"val_loss": loss} # not added to callback_metrics + + def test_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('test_loss', loss) + return {"y": loss} + + model = ExtendedModel() + model.validation_epoch_end = None + + # Initialize a trainer + trainer = Trainer( + default_root_dir=tmpdir, + logger=TensorBoardLogger(tmpdir), + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=2, + progress_bar_refresh_rate=1, + ) + + # Train the model ⚡ + trainer.fit(model) + + # hp_metric + 2 steps + epoch + 2 steps + epoch + expected_num_calls = 1 + 2 + 1 + 2 + 1 + + assert len(mock_log_metrics.mock_calls) == expected_num_calls + + assert mock_log_metrics.mock_calls[0] == call({'hp_metric': -1}, 0) + + def get_metrics_at_idx(idx): + mock_calls = list(mock_log_metrics.mock_calls) + if isinstance(mock_calls[idx].kwargs, dict): + return mock_calls[idx].kwargs["metrics"] + else: + return mock_calls[idx][2]["metrics"] + + expected = ['valid_loss_0_step/epoch_0', 'valid_loss_2/epoch_0', 'global_step'] + assert sorted(get_metrics_at_idx(1)) == sorted(expected) + assert sorted(get_metrics_at_idx(2)) == sorted(expected) + + expected = model.val_losses[2] + assert get_metrics_at_idx(1)["valid_loss_0_step/epoch_0"] == expected + expected = model.val_losses[3] + assert get_metrics_at_idx(2)["valid_loss_0_step/epoch_0"] == expected + + expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] + assert sorted(get_metrics_at_idx(3)) == sorted(expected) + + expected = torch.stack(model.val_losses[2:4]).mean() + assert get_metrics_at_idx(3)["valid_loss_1"] == expected + expected = ['valid_loss_0_step/epoch_1', 'valid_loss_2/epoch_1', 'global_step'] + + assert sorted(get_metrics_at_idx(4)) == sorted(expected) + assert sorted(get_metrics_at_idx(5)) == sorted(expected) + + expected = model.val_losses[4] + assert get_metrics_at_idx(4)["valid_loss_0_step/epoch_1"] == expected + expected = model.val_losses[5] + assert get_metrics_at_idx(5)["valid_loss_0_step/epoch_1"] == expected + + expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] + assert sorted(get_metrics_at_idx(6)) == sorted(expected) + + expected = torch.stack(model.val_losses[4:]).mean() + assert get_metrics_at_idx(6)["valid_loss_1"] == expected + + results = trainer.test(model) + expected_callback_metrics = { + 'train_loss', + 'valid_loss_0_epoch', + 'valid_loss_0', + 'debug_epoch', + 'valid_loss_1', + 'test_loss', + } + assert set(trainer.callback_metrics) == expected_callback_metrics + assert set(results[0]) == {'test_loss', 'debug_epoch'} diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py new file mode 100644 index 00000000000000..c4ba371e6c561e --- /dev/null +++ b/tests/trainer/logging_/test_logger_connector.py @@ -0,0 +1,563 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" +from copy import deepcopy +from typing import Any, Callable + +import pytest +import torch +from torch.utils.data import DataLoader + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.metrics import Accuracy +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.runif import RunIf + + +def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable: + + def decorator(func: Callable) -> Callable: + + def wrapper(self, *args, **kwargs) -> Any: + # Set information + self._current_fx_name = fx_name + self._current_hook_fx_name = hook_fx_name + self._results = Result() + + result = func(self, *args, **kwargs) + + # cache metrics + self.trainer.logger_connector.cache_logged_metrics() + return result + + return wrapper + + return decorator + + +def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch): + """ + Tests that LoggerConnector will properly capture logged information + and reduce them + """ + monkeypatch.setenv("PL_DEV_DEBUG", "1") + + class TestModel(BoringModel): + + train_losses = [] + + @decorator_with_arguments(fx_name="training_step") + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + + self.train_losses.append(loss) + + self.log("train_loss", loss, on_step=True, on_epoch=True) + + return {"loss": loss} + + def training_step_end(self, *_): + self.train_results = deepcopy(self.trainer.logger_connector.cached_results) + + model = TestModel() + model.training_epoch_end = None + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=4, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + train_results = model.train_results + + assert len(train_results(fx_name="training_step", dl_idx=0, opt_idx=0)) == 2 + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0, split_idx=0)["train_loss"] + assert generated == model.train_losses[0] + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=1, split_idx=0)["train_loss"] + assert generated == model.train_losses[1] + + assert train_results.has_reduced is not True + + train_results.has_batch_loop_finished = True + + assert train_results.has_reduced is True + + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item() + excepted = torch.stack(model.train_losses).mean().item() + assert generated == excepted + + +def test__logger_connector__epoch_result_store__train__tbptt(tmpdir): + """ + Tests that LoggerConnector will properly capture logged information with ttbt + and reduce them + """ + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class TestModel(BoringModel): + + train_losses = [] + + def __init__(self): + super().__init__() + self.test_hidden = None + self.layer = torch.nn.Linear(2, 2) + + @decorator_with_arguments(fx_name="training_step") + def training_step(self, batch, batch_idx, hiddens): + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + self.train_losses.append(loss) + + self.log('a', loss, on_epoch=True) + + return {'loss': loss, 'hiddens': self.test_hidden} + + def on_train_epoch_start(self) -> None: + self.test_hidden = None + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + def training_step_end(self, *_): + self.train_results = deepcopy(self.trainer.logger_connector.cached_results) + + model = TestModel() + model.training_epoch_end = None + model.example_input_array = torch.randn(5, truncated_bptt_steps) + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + limit_val_batches=0, + truncated_bptt_steps=truncated_bptt_steps, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + train_results = model.train_results + + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0) + assert len(generated) == len(model.train_losses) + + # assert reduction didn't happen yet + assert train_results.has_reduced is False + + # Launch reduction + train_results.has_batch_loop_finished = True + + # assert reduction did happen + assert train_results.has_reduced is True + + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['a_epoch'].item() + assert generated == torch.stack(model.train_losses).mean().item() + + +@pytest.mark.parametrize('num_dataloaders', [1, 2]) +def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders): + """ + Tests that LoggerConnector will properly capture logged information in multi dataloaders scenario + """ + monkeypatch.setenv("PL_DEV_DEBUG", "1") + + class TestModel(BoringModel): + test_losses = {dl_idx: [] for dl_idx in range(num_dataloaders)} + + @decorator_with_arguments(fx_name="test_step") + def test_step(self, batch, batch_idx, dl_idx=0): + output = self.layer(batch) + loss = self.loss(batch, output) + self.test_losses[dl_idx].append(loss) + self.log("test_loss", loss, on_step=True, on_epoch=True) + return {"test_loss": loss} + + def on_test_batch_end(self, *args, **kwargs): + # save objects as it will be reset at the end of epoch. + self.batch_results = deepcopy(self.trainer.logger_connector.cached_results) + + def on_test_epoch_end(self): + # save objects as it will be reset at the end of epoch. + self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results) + + def test_dataloader(self): + return [super().test_dataloader()] * num_dataloaders + + model = TestModel() + model.test_epoch_end = None + limit_test_batches = 4 + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=limit_test_batches, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.test(model) + + test_results = model.batch_results + + generated = test_results(fx_name="test_step") + assert len(generated) == num_dataloaders + + for dl_idx in range(num_dataloaders): + generated = test_results(fx_name="test_step", dl_idx=dl_idx) + assert len(generated) == limit_test_batches + + test_results = model.reduce_results + + for dl_idx in range(num_dataloaders): + expected = torch.stack(model.test_losses[dl_idx]).mean() + generated = test_results(fx_name="test_step", dl_idx=dl_idx, reduced=True)["test_loss_epoch"] + torch.testing.assert_allclose(generated, expected) + + +def test_call_back_validator(tmpdir): + + funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) + + callbacks_func = [ + 'on_after_backward', + 'on_batch_end', + 'on_batch_start', + 'on_before_accelerator_backend_setup', + 'on_before_zero_grad', + 'on_epoch_end', + 'on_epoch_start', + 'on_fit_end', + 'on_configure_sharded_model', + 'on_fit_start', + 'on_init_end', + 'on_init_start', + 'on_keyboard_interrupt', + 'on_load_checkpoint', + 'on_pretrain_routine_end', + 'on_pretrain_routine_start', + 'on_sanity_check_end', + 'on_sanity_check_start', + 'on_save_checkpoint', + 'on_test_batch_end', + 'on_test_batch_start', + 'on_test_end', + 'on_test_epoch_end', + 'on_test_epoch_start', + 'on_test_start', + 'on_train_batch_end', + 'on_train_batch_start', + 'on_train_end', + 'on_train_epoch_end', + 'on_train_epoch_start', + 'on_train_start', + 'on_validation_batch_end', + 'on_validation_batch_start', + 'on_validation_end', + 'on_validation_epoch_end', + 'on_validation_epoch_start', + 'on_validation_start', + 'setup', + 'teardown', + ] + + not_supported = [ + "on_before_accelerator_backend_setup", + "on_fit_end", + "on_fit_start", + "on_configure_sharded_model", + "on_init_end", + "on_init_start", + "on_keyboard_interrupt", + "on_load_checkpoint", + "on_pretrain_routine_end", + "on_pretrain_routine_start", + "on_sanity_check_end", + "on_sanity_check_start", + "on_save_checkpoint", + "on_test_end", + "on_train_end", + "on_validation_end", + "setup", + "teardown", + ] + + assert ( + funcs_name == sorted(callbacks_func) + ), """Detected new callback function. + Need to add its logging permission to CallbackHookNameValidator and update this test""" + + validator = CallbackHookNameValidator() + + for func_name in funcs_name: + # This summarizes where and what is currently possible to log using `self.log` + is_stage = "train" in func_name or "test" in func_name or "validation" in func_name + is_start = "start" in func_name or "batch" in func_name + on_step = is_stage and is_start + on_epoch = True + # creating allowed condition + allowed = ( + is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name + ) + allowed = ( + allowed and "pretrain" not in func_name + and func_name not in ["on_train_end", "on_test_end", "on_validation_end"] + ) + if allowed: + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, on_step=on_step, on_epoch=on_epoch) + if not is_start and is_stage: + with pytest.raises(MisconfigurationException, match="function supports only"): + validator.check_logging_in_callbacks( + current_hook_fx_name=func_name, on_step=True, on_epoch=on_epoch + ) + else: + assert func_name in not_supported + with pytest.raises(MisconfigurationException, match="function doesn't support"): + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, on_step=on_step, on_epoch=on_epoch) + + # should not fail + validator.check_logging_in_callbacks(current_hook_fx_name=None, on_step=None, on_epoch=None) + + +@RunIf(min_gpus=2) +def test_epoch_results_cache_dp(tmpdir): + + root_device = torch.device("cuda", 0) + + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + result = super().training_step(*args, **kwargs) + self.log("train_loss_epoch", result["loss"], on_step=False, on_epoch=True) + return result + + def training_step_end(self, training_step_outputs): # required for dp + loss = training_step_outputs["loss"].mean() + return loss + + def training_epoch_end(self, outputs): + assert all(out["loss"].device == root_device for out in outputs) + assert self.trainer.callback_metrics["train_loss_epoch"].device == root_device + + def validation_step(self, *args, **kwargs): + val_loss = torch.rand(1, device=torch.device("cuda", 1)) + self.log("val_loss_epoch", val_loss, on_step=False, on_epoch=True) + return val_loss + + def validation_epoch_end(self, outputs): + assert all(loss.device == root_device for loss in outputs) + assert self.trainer.callback_metrics["val_loss_epoch"].device == root_device + + def test_step(self, *args, **kwargs): + test_loss = torch.rand(1, device=torch.device("cuda", 1)) + self.log("test_loss_epoch", test_loss, on_step=False, on_epoch=True) + return test_loss + + def test_epoch_end(self, outputs): + assert all(loss.device == root_device for loss in outputs) + assert self.trainer.callback_metrics["test_loss_epoch"].device == root_device + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=4) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=4) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=4) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + accelerator="dp", + gpus=2, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + trainer.test(model, ckpt_path=None) + + +@pytest.mark.parametrize('to_float', [False, True]) +def test_metrics_holder(to_float, tmpdir): + + device = "cuda" if torch.cuda.is_available() else "cpu" + preds = torch.tensor([[0.9, 0.1]], device=device) + + def is_float(value: Any) -> bool: + return isinstance(value, float) + + excepted_function = is_float if to_float else torch.is_tensor + targets = torch.tensor([1], device=device) + acc = Accuracy().to(device) + metric_holder = MetricsHolder(to_float=to_float) + metric_holder.update({ + "x": 1, + "y": torch.tensor(2), + "z": acc(preds, targets), + }) + metric_holder.convert(device) + metrics = metric_holder.metrics + assert excepted_function(metrics["x"]) + assert excepted_function(metrics["y"]) + assert excepted_function(metrics["z"]) + + +def test_metric_holder_raises(tmpdir): + """Check that an error is raised when trying to convert non-scalar tensors""" + + class TestModel(BoringModel): + + def validation_step(self, batch, *args, **kwargs): + output = self(batch) + self.log('test', output) + + def test_step(self, *args, **kwargs): + return self.validation_step(*args, **kwargs) + + model = TestModel() + model.validation_epoch_end = None + model.test_epoch_end = None + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + + match = "The metric `test` does not contain a single element" + with pytest.raises(MisconfigurationException, match=match): + trainer.validate(model) + with pytest.raises(MisconfigurationException, match=match): + trainer.test(model) + + +def test_can_return_tensor_with_more_than_one_element(tmpdir): + """Ensure {validation,test}_step return values are not included as callback metrics. #6623""" + + class TestModel(BoringModel): + + def validation_step(self, batch, *args, **kwargs): + return {"val": torch.tensor([0, 1])} + + def validation_epoch_end(self, outputs): + # ensure validation step returns still appear here + assert len(outputs) == 2 + assert all(list(d) == ["val"] for d in outputs) # check keys + assert all(torch.equal(d["val"], torch.tensor([0, 1])) for d in outputs) # check values + + def test_step(self, batch, *args, **kwargs): + return {"test": torch.tensor([0, 1])} + + def test_epoch_end(self, outputs): + assert len(outputs) == 2 + assert all(list(d) == ["test"] for d in outputs) # check keys + assert all(torch.equal(d["test"], torch.tensor([0, 1])) for d in outputs) # check values + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, progress_bar_refresh_rate=0) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + +def test_logging_to_progress_bar_with_reserved_key(tmpdir): + """ Test that logging a metric with a reserved name to the progress bar raises a warning. """ + + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + self.log("loss", output["loss"], prog_bar=True) + return output + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): + trainer.fit(model) + + +@pytest.mark.parametrize("add_dataloader_idx", [False, True]) +def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx): + """ test that auto_add_dataloader_idx argument works """ + + class TestModel(BoringModel): + + def val_dataloader(self): + dl = super().val_dataloader() + return [dl, dl] + + def validation_step(self, *args, **kwargs): + output = super().validation_step(*args[:-1], **kwargs) + if add_dataloader_idx: + name = "val_loss" + else: + name = f"val_loss_custom_naming_{args[-1]}" + + self.log(name, output["x"], add_dataloader_idx=add_dataloader_idx) + return output + + model = TestModel() + model.validation_epoch_end = None + + trainer = Trainer(default_root_dir=tmpdir, max_steps=5) + trainer.fit(model) + logged = trainer.logged_metrics + + # Check that the correct keys exist + if add_dataloader_idx: + assert 'val_loss/dataloader_idx_0' in logged + assert 'val_loss/dataloader_idx_1' in logged + else: + assert 'val_loss_custom_naming_0' in logged + assert 'val_loss_custom_naming_1' in logged diff --git a/tests/trainer/logging_/test_progress_bar_logging.py b/tests/trainer/logging_/test_progress_bar_logging.py new file mode 100644 index 00000000000000..75acbc2c509c30 --- /dev/null +++ b/tests/trainer/logging_/test_progress_bar_logging.py @@ -0,0 +1,23 @@ +import pytest + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +def test_logging_to_progress_bar_with_reserved_key(tmpdir): + """ Test that logging a metric with a reserved name to the progress bar raises a warning. """ + + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + self.log("loss", output["loss"], prog_bar=True) + return output + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) + with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): + trainer.fit(model) diff --git a/tests/trainer/logging_/test_train_loop_logging_1_0.py b/tests/trainer/logging_/test_train_loop_logging_1_0.py new file mode 100644 index 00000000000000..e01d19f25a92c7 --- /dev/null +++ b/tests/trainer/logging_/test_train_loop_logging_1_0.py @@ -0,0 +1,940 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict (1.0) +""" + +import collections +import itertools +import os +from unittest import mock + +import numpy as np +import pytest +import torch +from torch.utils.data import Dataset + +import pytorch_lightning as pl +from pytorch_lightning import callbacks, Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.core.lightning import LightningModule +from tests.helpers.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset +from tests.helpers.deterministic_model import DeterministicModel +from tests.helpers.runif import RunIf + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__log(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + + # ----------- + # default + # ----------- + self.log('default', acc) + + # ----------- + # logger + # ----------- + # on_step T on_epoch F + self.log('l_s', acc, on_step=True, on_epoch=False, prog_bar=False, logger=True) + + # on_step F on_epoch T + self.log('l_e', acc, on_step=False, on_epoch=True, prog_bar=False, logger=True) + + # on_step T on_epoch T + self.log('l_se', acc, on_step=True, on_epoch=True, prog_bar=False, logger=True) + + # ----------- + # pbar + # ----------- + # on_step T on_epoch F + self.log('p_s', acc, on_step=True, on_epoch=False, prog_bar=True, logger=False) + + # on_step F on_epoch T + self.log('p_e', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False) + + # on_step T on_epoch T + self.log('p_se', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False) + + self.training_step_called = True + return acc + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + callbacks=[ModelCheckpoint(monitor='l_se')], + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'epoch', + 'default', + 'l_e', + 'l_s', + 'l_se_step', + 'l_se_epoch', + } + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = { + 'p_e', + 'p_s', + 'p_se_step', + 'p_se_epoch', + } + assert pbar_metrics == expected_pbar_metrics + + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.update({'p_se', 'l_se'}) + expected_callback_metrics.remove('epoch') + assert callback_metrics == expected_callback_metrics + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test__training_step__epoch_end__log(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(DeterministicModel): + + def training_step(self, batch, batch_idx): + self.training_step_called = True + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.log('a', acc, on_step=True, on_epoch=True) + self.log_dict({'a1': acc, 'a2': acc}) + return acc + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + self.log('b1', outputs[0]['loss']) + self.log('b', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True) + + def backward(self, loss, optimizer, optimizer_idx): + return LightningModule.backward(self, loss, optimizer, optimizer_idx) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = {'epoch', 'a_step', 'a_epoch', 'b', 'b1', 'a1', 'a2'} + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = {'b'} + assert pbar_metrics == expected_pbar_metrics + + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.remove('epoch') + expected_callback_metrics.add('a') + assert callback_metrics == expected_callback_metrics + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)]) +def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval, max_epochs): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.training_step_called = True + loss = self.step(batch[0]) + self.log('a', loss, on_step=True, on_epoch=True) + return loss + + def training_step_end(self, out): + self.training_step_end_called = True + self.log('b', out, on_step=True, on_epoch=True, prog_bar=True, logger=True) + return out + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + self.log('c', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True) + self.log('d/e/f', 2) + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + max_epochs=max_epochs, + log_every_n_steps=log_interval, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert model.training_epoch_end_called + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = {'a_step', 'a_epoch', 'b_step', 'b_epoch', 'c', 'd/e/f', 'epoch'} + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = {'c', 'b_epoch', 'b_step'} + assert pbar_metrics == expected_pbar_metrics + + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.update({'a', 'b'}) + expected_callback_metrics.remove('epoch') + assert callback_metrics == expected_callback_metrics + + # assert the loggers received the expected number + assert len(trainer.dev_debugger.logged_metrics) == ((batches / log_interval) * max_epochs) + max_epochs + + +@pytest.mark.parametrize(['batches', 'fx', 'result'], [(1, min, 0), (2, max, 1), (11, max, 10)]) +def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): + """ + Tests that log works correctly with different tensor types + """ + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + self.log('foo', torch.tensor(batch_idx).long(), on_step=False, on_epoch=True, reduce_fx=fx) + return acc + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('bar', torch.tensor(batch_idx).float(), on_step=False, on_epoch=True, reduce_fx=fx) + return {"x": loss} + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + max_epochs=2, + weights_summary=None, + ) + trainer.fit(model) + + # make sure types are correct + assert trainer.logged_metrics['foo'] == result + assert trainer.logged_metrics['bar'] == result + + +def test_tbptt_log(tmpdir): + """ + Tests that only training_step can be used + """ + truncated_bptt_steps = 2 + sequence_size = 30 + batch_size = 30 + + x_seq = torch.rand(batch_size, sequence_size, 1) + y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist() + + class MockSeq2SeqDataset(torch.utils.data.Dataset): + + def __getitem__(self, i): + return x_seq, y_seq_list + + def __len__(self): + return 1 + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.test_hidden = None + self.layer = torch.nn.Linear(2, 2) + + def training_step(self, batch, batch_idx, hiddens): + assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps" + self.test_hidden = torch.rand(1) + + x_tensor, y_list = batch + assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed" + + y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype) + assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" + + pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) + loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps)) + + self.log('a', loss, on_epoch=True) + + return {'loss': loss, 'hiddens': self.test_hidden} + + def on_train_epoch_start(self) -> None: + self.test_hidden = None + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset=MockSeq2SeqDataset(), + batch_size=batch_size, + shuffle=False, + sampler=None, + ) + + model = TestModel() + model.training_epoch_end = None + model.example_input_array = torch.randn(5, truncated_bptt_steps) + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + limit_val_batches=0, + truncated_bptt_steps=truncated_bptt_steps, + max_epochs=2, + log_every_n_steps=2, + weights_summary=None, + ) + trainer.fit(model) + + generated = set(trainer.logged_metrics.keys()) + expected = {'a_step', 'a_epoch', 'epoch'} + assert generated == expected + + +def test_different_batch_types_for_sizing(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + assert isinstance(batch, dict) + a = batch['a'] + acc = self.step(a) + self.log('a', {'d1': 2, 'd2': torch.tensor(1)}, on_step=True, on_epoch=True) + return acc + + def validation_step(self, batch, batch_idx): + assert isinstance(batch, dict) + a = batch['a'] + output = self.layer(a) + loss = self.loss(batch, output) + self.log('n', {'d3': 2, 'd4': torch.tensor(1)}, on_step=True, on_epoch=True) + return {"x": loss} + + def train_dataloader(self): + return torch.utils.data.DataLoader(RandomDictDataset(32, 64), batch_size=32) + + def val_dataloader(self): + return torch.utils.data.DataLoader(RandomDictDataset(32, 64), batch_size=32) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + generated = set(trainer.logger_connector.logged_metrics) + expected = {'a_step', 'a_epoch', 'n_step/epoch_0', 'n_epoch', 'epoch'} + + assert generated == expected + + +def test_validation_step_with_string_data_logging(tmpdir): + + class TestModel(BoringModel): + + def on_train_epoch_start(self) -> None: + print("override any method to prove your bug") + + def training_step(self, batch, batch_idx): + output = self.layer(batch["x"]) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + output = self.layer(batch["x"]) + loss = self.loss(batch, output) + self.log("x", loss) + return {"x": loss} + + # fake data + train_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64)) + val_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64)) + + # model + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model, train_data, val_data) + + +def test_nested_datasouce_batch(tmpdir): + + class NestedDictStringDataset(Dataset): + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + x = { + 'post_text': ['bird is fast', 'big cat'], + 'dense_0': [ + torch.tensor([-0.1000, 0.2000], dtype=torch.float64), + torch.tensor([1, 1], dtype=torch.uint8), + ], + 'post_id': ['115', '116'], + 'label': [torch.tensor([0, 1]), torch.tensor([1, 1], dtype=torch.uint8)] + } + return x + + def __len__(self): + return self.len + + class TestModel(BoringModel): + + def on_train_epoch_start(self) -> None: + print("override any method to prove your bug") + + def training_step(self, batch, batch_idx): + output = self.layer(torch.rand(32)) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + output = self.layer(torch.rand(32)) + loss = self.loss(batch, output) + self.log("x", loss) + return {"x": loss} + + # fake data + train_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64)) + val_data = torch.utils.data.DataLoader(NestedDictStringDataset(32, 64)) + + # model + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model, train_data, val_data) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_log_works_in_train_callback(tmpdir): + """ + Tests that log can be called within callback + """ + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging( + self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[] + ): + self.funcs_called_count[func_name] += 1 + iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars])) + for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate): + # run logging + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log( + custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar + ) + + # catch information for verification + + # on on_train_start is outside the main loop. Won't be called + if func_name == "on_train_start": + self.callback_funcs_called[func_name].append([self.count * func_idx]) + + # Saved only values from second epoch, so we can compute its mean or latest. + if pl_module.trainer.current_epoch == 1: + self.callback_funcs_called[func_name].append([self.count * func_idx]) + + forked = on_step and on_epoch + + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": forked, + "func_name": func_name + } + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name + } + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name + } + + def on_train_start(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + + def on_train_epoch_start(self, trainer, pl_module): + self.make_logging( + pl_module, + 'on_train_epoch_start', + 3, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + + def on_batch_end(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + ) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging( + pl_module, + 'on_train_batch_end', + 7, + on_steps=self.choices, + on_epochs=self.choices, + prob_bars=self.choices + ) + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_train_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.make_logging( + pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + ) + + def on_epoch_end(self, trainer, pl_module): + self.make_logging( + pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + ) + + class TestModel(BoringModel): + + manual_loss = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.manual_loss.append(loss) + self.log('train_loss', loss) + return {"loss": loss} + + max_epochs = 2 + limit_train_batches = 2 + model = TestModel() + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback] + ) + trainer.fit(model) + + assert test_callback.funcs_called_count["on_train_start"] == 1 + assert test_callback.funcs_called_count["on_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_train_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_batch_end"] == 4 + assert test_callback.funcs_called_count["on_epoch_end"] == 2 + assert test_callback.funcs_called_count["on_train_batch_end"] == 4 + assert test_callback.funcs_called_count["on_epoch_end"] == 2 + assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] + assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] + trainer.callback_metrics.pop("train_loss") + + for func_name, output_value in trainer.callback_metrics.items(): + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics + + +def test_logging_sync_dist_true_cpu(tmpdir): + """ + Tests to ensure that the sync_dist flag works with CPU (should just return the original value) + """ + fake_result = 1 + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') + self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') + return acc + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') + return {"x": loss} + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=2, + weights_summary=None, + ) + trainer.fit(model) + + assert trainer.logged_metrics['foo'] == fake_result + assert trainer.logged_metrics['foo_2'] == 2 + assert trainer.logged_metrics['bar'] == fake_result + + +@RunIf(min_gpus=2, special=True) +def test_logging_sync_dist_true_ddp(tmpdir): + """ + Tests to ensure that the sync_dist flag works with ddp + """ + + class TestLoggingSyncDistModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM') + self.log('cho', acc, on_step=False, on_epoch=True) + return acc + + def validation_step(self, batch, batch_idx): + self.training_step_called = True + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG') + return {"x": loss} + + model = TestLoggingSyncDistModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=2, + weights_summary=None, + accelerator="ddp", + gpus=2, + profiler="pytorch" + ) + + if os.getenv("LOCAL_RANK") == '0': + with pytest.warns(UserWarning, match="The value associated to the key cho:"): + trainer.fit(model) + else: + trainer.fit(model) + assert trainer.logged_metrics['foo'] == 2 + assert trainer.logged_metrics['bar'] == 2 + + +@RunIf(min_gpus=1) +def test_logging_sync_dist_true_gpu(tmpdir): + """ + Tests to ensure that the sync_dist flag works with GPU (should just return the original value) + """ + fake_result = 1 + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + acc = self.step(batch[0]) + self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') + return acc + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum') + return {"x": loss} + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=2, + gpus=1, + weights_summary=None, + ) + trainer.fit(model) + + assert trainer.logged_metrics['foo'] == fake_result + assert trainer.logged_metrics['bar'] == fake_result + + +def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, *args): + self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True) + return super().training_step(*args) + + def on_train_epoch_end(self, *_): + self.on_train_epoch_end_called = True + self.epoch_end_called = True + self.log( + 'foo_2', + torch.tensor(self.current_epoch), + prog_bar=True, + on_epoch=True, + sync_dist=True, + sync_dist_op='sum' + ) + + def on_epoch_end(self): + self.epoch_end_called = True + assert self.trainer.progress_bar_dict["foo"] == self.current_epoch + assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=1, + limit_val_batches=0, + checkpoint_callback=False, + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + model = TestModel() + trainer.fit(model) + assert model.epoch_end_called + assert model.on_train_epoch_end_called + + +def test_logging_in_callbacks_with_log_function(tmpdir): + """ + Tests ensure self.log can be used directly in callbacks. + """ + + class LoggingCallback(callbacks.Callback): + + def on_train_start(self, trainer, pl_module): + self.log("on_train_start", 1) + + def on_train_epoch_start(self, trainer, pl_module): + self.log("on_train_epoch_start", 2) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.log("on_train_batch_end", 3) + + def on_batch_end(self, trainer, pl_module): + self.log("on_batch_end", 4) + + def on_epoch_end(self, trainer, pl_module): + self.log("on_epoch_end", 5) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.log("on_train_epoch_end", 6) + self.callback_metrics = trainer.logger_connector.callback_metrics + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + callbacks=[LoggingCallback()] + ) + trainer.fit(model) + + expected = { + 'on_train_start': 1, + 'on_train_epoch_start': 2, + 'on_train_batch_end': 3, + 'on_batch_end': 4, + 'on_epoch_end': 5, + 'on_train_epoch_end': 6 + } + assert trainer.callback_metrics == expected + + +@RunIf(min_gpus=1) +def test_metric_are_properly_reduced(tmpdir): + + class TestingModel(BoringModel): + + def __init__(self, *args, **kwargs): + super().__init__() + self.val_acc = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + self.log("train_loss", output["loss"]) + return output + + def validation_step(self, batch, batch_idx): + preds = torch.tensor([[0.9, 0.1]], device=self.device) + targets = torch.tensor([1], device=self.device) + if batch_idx < 8: + preds = torch.tensor([[0.1, 0.9]], device=self.device) + self.val_acc(preds, targets) + self.log('val_acc', self.val_acc, on_step=True, on_epoch=True) + return super().validation_step(batch, batch_idx) + + early_stop = EarlyStopping(monitor='val_acc', mode='max') + + checkpoint = ModelCheckpoint( + monitor='val_acc', + save_last=True, + save_top_k=2, + mode='max', + ) + + model = TestingModel() + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + max_epochs=2, + limit_train_batches=5, + limit_val_batches=32, + callbacks=[early_stop, checkpoint] + ) + trainer.fit(model) + + assert trainer.callback_metrics["val_acc"] == 8 / 32. + assert "train_loss" in trainer.callback_metrics diff --git a/tests/trainer/optimization/__init__.py b/tests/trainer/optimization/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/optimization/test_backward_calls.py b/tests/trainer/optimization/test_backward_calls.py new file mode 100644 index 00000000000000..f53096588ff204 --- /dev/null +++ b/tests/trainer/optimization/test_backward_calls.py @@ -0,0 +1,52 @@ +from unittest.mock import patch + +import pytest + +from pytorch_lightning import Trainer +from tests.base import EvalModelTemplate + + +@pytest.mark.parametrize("num_steps", [1, 2, 3]) +@patch("torch.Tensor.backward") +def test_backward_count_simple(torch_backward, num_steps): + """ Test that backward is called exactly once per step. """ + model = EvalModelTemplate() + trainer = Trainer(max_steps=num_steps) + trainer.fit(model) + assert torch_backward.call_count == num_steps + + torch_backward.reset_mock() + + trainer.test(model) + assert torch_backward.call_count == 0 + + +@patch("torch.Tensor.backward") +def test_backward_count_with_grad_accumulation(torch_backward): + """ Test that backward is called the correct number of times when accumulating gradients. """ + model = EvalModelTemplate() + trainer = Trainer(max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2) + trainer.fit(model) + assert torch_backward.call_count == 6 + + torch_backward.reset_mock() + + trainer = Trainer(max_steps=6, accumulate_grad_batches=2) + trainer.fit(model) + assert torch_backward.call_count == 12 + + +@patch("torch.Tensor.backward") +def test_backward_count_with_closure(torch_backward): + """ Using a closure (e.g. with LBFGS) should lead to no extra backward calls. """ + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__lbfgs + trainer = Trainer(max_steps=5) + trainer.fit(model) + assert torch_backward.call_count == 5 + + torch_backward.reset_mock() + + trainer = Trainer(max_steps=5, accumulate_grad_batches=2) + trainer.fit(model) + assert torch_backward.call_count == 10 diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py new file mode 100644 index 00000000000000..e197e9b35adc90 --- /dev/null +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -0,0 +1,1150 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import os +from copy import deepcopy +from unittest import mock +from unittest.mock import ANY, call, patch + +import pytest +import torch +import torch.distributed as torch_distrib +import torch.nn.functional as F + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import Callback +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_multiple_optimizers_manual(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt_a, opt_b = self.optimizers() + loss_1 = self.step(batch[0]) + + # make sure there are no grads + if batch_idx > 0: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_multiple_optimizers_manual_return(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt_a, opt_b = self.optimizers() + loss_1 = self.step(batch[0]) + + # make sure there are no grads + if batch_idx > 0: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + return {'something': 'else'} + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_multiple_optimizers_manual_return_and_log(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt_a, opt_b = self.optimizers() + loss_1 = self.step(batch[0]) + + # make sure there are no grads + if batch_idx > 0: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + self.log('a', loss_2, on_epoch=True) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + return {'something': 'else'} + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + + expected = {'a_step', 'a_epoch', 'epoch'} + logged = set(trainer.logged_metrics.keys()) + assert expected == logged + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=1) +def test_multiple_optimizers_manual_native_amp(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt_a, opt_b = self.optimizers() + loss_1 = self.step(batch[0]) + + # make sure there are no grads + if batch_idx > 0: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + # fake discriminator + loss_2 = self.step(batch[0]) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + precision=16, + gpus=1, + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=1, amp_apex=True) +def test_multiple_optimizers_manual_apex(tmpdir): + """ + Tests that only training_step can be used + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt_a, opt_b = self.optimizers() + x = batch[0] + + loss_1 = self(x) + loss_1 = self.loss(loss_1, loss_1) + + # make sure there are no grads + if batch_idx > 0: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + # fake discriminator + loss_2 = self(x) + loss_2 = self.loss(loss_2, loss_2) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, retain_graph=True) + self.manual_backward(loss_2) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + assert torch.all(self.layer.weight.grad == 0) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + precision=16, + amp_level='O2', + amp_backend='apex', + gpus=1 + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + + +class ManualOptimizationExtendedModel(BoringModel): + + count = 0 + called = collections.defaultdict(int) + detach = False + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + @property + def should_update(self): + return self.count % 2 == 0 + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_start"] += 1 + self.weight_before = self.layer.weight.clone() + + def training_step(self, batch, batch_idx): + self.called["training_step"] += 1 + opt = self.optimizers() + output = self.layer(batch) + + loss = self.loss(batch, output) + loss /= loss.clone().detach() + loss *= 0.1 + + if self.should_update: + + self.manual_backward(loss, opt) + opt.step() + opt.zero_grad() + + return loss.detach() if self.detach else loss + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_end"] += 1 + after_before = self.layer.weight.clone() + if self.should_update: + try: + assert not torch.equal(self.weight_before, after_before), self.count + # todo: specify the possible exception + except Exception: + # TODO: Figure out why 1 every 3 runs, weights don't get updated on count = 4" + pass + else: + try: + assert torch.equal(self.weight_before, after_before) + # todo: specify the possible exception + except Exception: + # almost no diff between before and after + assert torch.abs(torch.sum(self.weight_before) - torch.sum(after_before)).item() < 10e-6 + assert torch.all(self.layer.weight.grad == 0) + self.count += 1 + + def on_train_end(self): + assert self.called["training_step"] == 10 + assert self.called["on_train_batch_start"] == 10 + assert self.called["on_train_batch_end"] == 10 + + +@RunIf(min_gpus=2) +def test_manual_optimization_and_return_tensor(tmpdir): + """ + This test verify that in `manual_optimization` + we don't add gradient when the user return loss in `training_step` + """ + + model = ManualOptimizationExtendedModel() + model.training_step_end = None + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=10, + limit_test_batches=0, + limit_val_batches=0, + precision=16, + amp_backend='native', + accelerator="ddp_spawn", + gpus=2, + ) + trainer.fit(model) + + +@RunIf(min_gpus=2) +def test_manual_optimization_and_return_detached_tensor(tmpdir): + """ + This test verify that in `manual_optimization` + we don't add gradient when the user return loss in `training_step` + When the tensor is detached, return MisConfiguration Error. + """ + + model = ManualOptimizationExtendedModel() + model.detach = True + model.training_step_end = None + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=10, + limit_test_batches=0, + limit_val_batches=0, + precision=16, + amp_backend='native', + accelerator="ddp_spawn", + gpus=2, + ) + expected_message = "In manual optimization, `training_step` should not return a Tensor" + with pytest.raises(Exception, match=expected_message): + trainer.fit(model) + + +@RunIf(min_gpus=1) +def test_manual_optimization_and_accumulated_gradient(tmpdir): + """ + This test verify that in `automatic_optimization=False`, + step is being called only when we shouldn't accumulate. + """ + seed_everything(234) + + class ExtendedModel(BoringModel): + + count = 1 + called = collections.defaultdict(int) + detach = False + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + @property + def should_update(self): + return self.count % 2 == 0 + + @property + def should_have_updated(self): + return self.count % 4 == 0 + + @property + def has_gradient(self): + return self.layer.weight.grad is not None + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_start"] += 1 + self.weight_before = self.layer.weight.clone() + + def training_step(self, batch, batch_idx): + self.called["training_step"] += 1 + opt = self.optimizers() + output = self.layer(batch) + + loss = self.loss(batch, output) + loss /= loss.clone().detach() + loss *= 0.1 + + if self.should_update: + + self.manual_backward(loss, opt) + if self.should_have_updated: + opt.step() + opt.zero_grad() + + return loss.detach() if self.detach else loss + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.called["on_train_batch_end"] += 1 + after_before = self.layer.weight.clone() + if self.should_update and self.should_have_updated: + assert not torch.equal(self.weight_before, after_before), self.count + assert torch.all(self.layer.weight.grad == 0) + else: + assert torch.equal(self.weight_before, after_before) + if self.count > 1: + if self.count % 4 == 1: + assert torch.all(self.layer.weight.grad == 0) + else: + assert torch.sum(self.layer.weight.grad) != 0 + self.count += 1 + + def on_train_epoch_end(self, *_, **__): + assert self.called["training_step"] == 20 + assert self.called["on_train_batch_start"] == 20 + assert self.called["on_train_batch_end"] == 20 + + model = ExtendedModel() + model.training_step_end = None + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=20, + limit_test_batches=0, + limit_val_batches=0, + precision=16, + amp_backend='native', + accumulate_grad_batches=4, + gpus=1, + ) + trainer.fit(model) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@RunIf(min_gpus=1) +def test_multiple_optimizers_step(tmpdir): + """ + Tests that `step` works with several optimizers + """ + + class TestModel(BoringModel): + + called = False + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def on_after_backward(self): + self.called = True + norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) + if not (torch.isinf(norm) or torch.isnan(norm)): + assert norm.item() < 100, norm.item() + + def training_step(self, batch, batch_idx): + # manual + opt_a, opt_b = self.optimizers() + x = batch[0] + + loss_1 = self(x) + loss_1 = self.loss(loss_1, loss_1) + + # make sure there are no grads + if self.layer.weight.grad is not None: + assert torch.all(self.layer.weight.grad == 0) + + self.manual_backward(loss_1, opt_a) + opt_a.step() + + # fake discriminator + loss_2 = self(x) + loss_2 = self.loss(loss_2, loss_2) + + # ensure we forward the correct params to the optimizer + # without retain_graph we can't do multiple backward passes + self.manual_backward(loss_2, opt_b, retain_graph=True) + self.manual_backward(loss_2, opt_a, retain_graph=True) + + assert self.layer.weight.grad is not None + opt_b.step() + opt_b.zero_grad() + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer, optimizer_2 + + model = TestModel() + model.val_dataloader = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + precision=16, + amp_backend='native', + gpus=1, + ) + + trainer.fit(model) + + num_manual_backward_calls = 3 + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls + assert model.called + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_step_with_optimizer_closure(tmpdir): + """ + Tests that `step` works with optimizer_closure + """ + + class TestModel(BoringModel): + + _losses = [] + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + + # make sure there are no grads + if self.layer.weight.grad is not None: + assert torch.all(self.layer.weight.grad == 0) + + opt = self.optimizers() + + def compute_loss(): + x = batch[0] + x = F.dropout(x, 0.1) + predictions = self(x) + predictions = F.dropout(predictions, 0.1) + loss = self.loss(None, predictions) + return loss + + def optimizer_closure(): + # emulate bayesian optimization. + num_backward = 2 + losses = [] + for backward_idx in range(num_backward): + loss = compute_loss() + losses.append(loss) + retain_graph = (num_backward - 1) != backward_idx + self.manual_backward(loss, opt, retain_graph=retain_graph) + # emulate MC dropout training + loss = torch.stack(losses).mean() + self._losses.append(loss) + self.log("train_loss", loss, on_step=True, prog_bar=True, on_epoch=True) + assert losses[0] != losses[1] + + weight_before = self.layer.weight.clone() + + opt.step(closure=optimizer_closure) + opt.zero_grad() + + weight_after = self.layer.weight.clone() + assert not torch.equal(weight_before, weight_after) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer + + model = TestModel() + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + ) + + trainer.fit(model) + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2 + assert trainer.logger_connector.progress_bar_metrics["train_loss_step"] == model._losses[-1] + assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean() + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir): + """ + Tests that `step` works with optimizer_closure and accumulated_grad + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt = self.optimizers() + x = batch[0] + + loss_1 = self(x) + loss_1 = self.loss(loss_1, loss_1) + + def optimizer_closure(): + # emulate bayesian optimization. + num_backward = 1 + for backward_idx in range(num_backward + 1): + retain_graph = num_backward != backward_idx # noqa E225 + self.manual_backward(loss_1, opt, retain_graph=retain_graph) + + weight_before = self.layer.weight.clone() + + opt.step(closure=optimizer_closure) + + weight_after = self.layer.weight.clone() + if not self.trainer.train_loop.should_accumulate(): + assert not torch.equal(weight_before, weight_after) + else: + assert self.layer.weight.grad is not None + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer + + model = TestModel() + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 4 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + ) + + trainer.fit(model) + assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2 + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@patch("torch.optim.SGD.step") +def test_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir): + """ + Tests that `step` works with optimizer_closure and extra arguments + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # manual + opt = self.optimizers() + x = batch[0] + + loss_1 = self(x) + loss_1 = self.loss(loss_1, loss_1) + + def optimizer_closure(): + # emulate bayesian optimization. + num_backward = 1 + for backward_idx in range(num_backward + 1): + retain_graph = num_backward != backward_idx # noqa E225 + self.manual_backward(loss_1, opt, retain_graph=retain_graph) + + opt.step(closure=optimizer_closure) + opt.zero_grad() + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return optimizer + + model = TestModel() + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 4 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + ) + + trainer.fit(model) + expected_calls = [call(closure=ANY) for s in range(2)] + step_mock.assert_has_calls(expected_calls) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@patch("torch.optim.Adam.step") +@patch("torch.optim.SGD.step") +def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir): + """ + Tests that `step` works with optimizer_closure and different accumulated_gradient frequency + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + + # emulate gans training + opt_gen, opt_dis = self.optimizers() + + # Note: Be careful, don't log on the same key in self.log in both closure + # as they will be aggregated together on epoch_end + + def compute_loss(): + x = batch[0] + x = F.dropout(x, 0.1) + predictions = self(x) + predictions = F.dropout(predictions, 0.1) + loss = self.loss(None, predictions) + return loss + + def gen_closure(): + loss_gen = compute_loss() + self.log("loss_gen", loss_gen, on_step=True, on_epoch=True) + self.manual_backward(loss_gen, opt_gen) + + def dis_closure(): + loss_dis = compute_loss() + self.log("loss_dis", loss_dis, on_step=True, on_epoch=True) + self.manual_backward(loss_dis, opt_dis) + + # this will accumulate gradients for 2 batches and then call opt_gen.step() + gen_closure() + if batch_idx % 2 == 0: + opt_gen.step(closure=gen_closure, optim='sgd') + opt_gen.zero_grad() + + # update discriminator every 4 baches + # therefore, no gradient accumulation for discriminator + if batch_idx % 4 == 0: + opt_dis.step(closure=dis_closure) + opt_dis.zero_grad() + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) + return [optimizer_gen, optimizer_dis] + + model = TestModel() + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 8 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + ) + + trainer.fit(model) + expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)] + mock_sgd_step.assert_has_calls(expected_calls) + expected_calls = [call(closure=ANY) for s in range(2)] + mock_adam_step.assert_has_calls(expected_calls) + + +class TestManualOptimizationDDPCallack(Callback): + + def on_train_end(self, trainer, pl_module): + + opt_a, opt_b = pl_module.optimizers() + assert opt_a._total_optimizer_step_calls == 4 + assert opt_b._total_optimizer_step_calls == 2 + + +class TesManualOptimizationDDPModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def loss_ones(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def loss_zeros(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.zeros_like(prediction)) + + def manual_sync_grad(self) -> bool: + torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False) + return True + + def training_step(self, batch, batch_idx): + + # emulate gans training + opt_gen, opt_dis = self.optimizers() + + # Note: Be careful, don't log on the same key in self.log in both closure + # as they will be aggregated together on epoch_end + + world_size = torch_distrib.get_world_size(torch_distrib.group.WORLD) + assert world_size == 2 + + make_gen_optimizer_step = batch_idx % 2 == 1 + make_dis_optimizer_step = batch_idx % 4 == 0 + + def compute_loss(): + x = batch[0] + x = F.dropout(x, 0.1) + predictions = self(x) + predictions = F.dropout(predictions, 0.1) + loss_ones = self.loss_ones(None, predictions) + loss_zeros = self.loss_zeros(None, predictions) + return loss_ones, loss_zeros + + def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True): + self.manual_backward(loss, opt, retain_graph=retain_graph) + if make_optimizer_step: + grad_clone = self.layer.weight.grad.clone() + assert self.manual_sync_grad() + self.layer.weight.grad /= world_size + assert torch.equal(self.layer.weight.grad, grad_clone) + + def gen_closure(): + loss_ones_gen, loss_zeros = compute_loss() + make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step) + make_manual_backward(loss_ones_gen, opt_gen, make_optimizer_step=make_gen_optimizer_step) + + def dis_closure(): + loss_ones_gen, loss_zeros = compute_loss() + make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True, make_optimizer_step=make_dis_optimizer_step) + make_manual_backward(loss_ones_gen, opt_dis, make_optimizer_step=make_dis_optimizer_step) + + # this will accumulate gradients for 2 batches and then call opt_gen.step() + if make_gen_optimizer_step: + opt_gen.step(closure=gen_closure) + opt_gen.zero_grad() + + # update discriminator every 4 baches + # therefore, no gradient accumulation for discriminator + if make_dis_optimizer_step: + opt_dis.step(closure=dis_closure) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + def configure_optimizers(self): + optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1) + optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) + return [optimizer_gen, optimizer_dis] + + +def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizationDDPModel): + + seed_everything(42) + + model = model_cls() + model_copy = deepcopy(model) + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 8 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + gpus=2, + accelerator=accelerator, + callbacks=[TestManualOptimizationDDPCallack()] + ) + + trainer.fit(model) + + for param, param_copy in zip(model.parameters(), model_copy.parameters()): + assert not torch.equal(param.cpu().data, param_copy.data) + + +@RunIf(min_gpus=2, special=True) +def test_step_with_optimizer_closure_with_different_frequencies_ddp(tmpdir): + """ + Tests that `step` works with optimizer_closure and different accumulated_gradient frequency + """ + + train_manual_optimization(tmpdir, "ddp") + + +@RunIf(min_gpus=2) +def test_step_with_optimizer_closure_with_different_frequencies_ddp_spawn(tmpdir): + """ + Tests that `step` works with optimizer_closure and different accumulated_gradient frequency + """ + + train_manual_optimization(tmpdir, "ddp_spawn") + + +class TestManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel): + + def training_step(self, batch, batch_idx): + + # emulate gans training + opt_gen, opt_dis = self.optimizers() + + # Note: Be careful, don't log on the same key in self.log in both closure + # as they will be aggregated together on epoch_end + + world_size = torch_distrib.get_world_size(torch_distrib.group.WORLD) + assert world_size == 2 + + make_gen_optimizer_step = batch_idx % 2 == 1 + make_dis_optimizer_step = batch_idx % 4 == 0 + + def compute_loss(): + x = batch[0] + x = F.dropout(x, 0.1) + predictions = self(x) + predictions = F.dropout(predictions, 0.1) + loss_ones = self.loss_ones(None, predictions) + loss_zeros = self.loss_zeros(None, predictions) + return loss_ones, loss_zeros + + def make_manual_backward(loss, opt, retain_graph=False, make_optimizer_step=True): + self.manual_backward(loss, opt, retain_graph=retain_graph) + if make_optimizer_step: + grad_clone = self.layer.weight.grad.clone() + assert self.manual_sync_grad() + self.layer.weight.grad /= world_size + assert torch.equal(self.layer.weight.grad, grad_clone) + + def gen_closure(): + loss_ones_gen, loss_zeros = compute_loss() + make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step) + make_manual_backward(loss_ones_gen, opt_gen, make_optimizer_step=make_gen_optimizer_step) + + def dis_closure(): + loss_ones_gen, loss_zeros = compute_loss() + make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True, make_optimizer_step=make_dis_optimizer_step) + make_manual_backward(loss_ones_gen, opt_dis, make_optimizer_step=make_dis_optimizer_step) + + # this will accumulate gradients for 2 batches and then call opt_gen.step() + with opt_gen.toggle_model(sync_grad=make_gen_optimizer_step): + gen_closure() + if make_gen_optimizer_step: + opt_gen.step() + opt_gen.zero_grad() + + with opt_dis.toggle_model(sync_grad=make_dis_optimizer_step): + dis_closure() + if make_dis_optimizer_step: + opt_dis.step() + opt_dis.zero_grad() + + +@RunIf(min_gpus=2, special=True) +def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir): + train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel) diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py new file mode 100644 index 00000000000000..5f0ca34015df06 --- /dev/null +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -0,0 +1,169 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the behaviours related to multiple optimizers works +""" +import pytest +import torch + +import pytorch_lightning as pl +from tests.helpers.boring_model import BoringModel + + +class MultiOptModel(BoringModel): + + def configure_optimizers(self): + opt_a = torch.optim.SGD(self.layer.parameters(), lr=0.001) + opt_b = torch.optim.SGD(self.layer.parameters(), lr=0.001) + return opt_a, opt_b + + +def test_unbalanced_logging_with_multiple_optimizers(tmpdir): + """ + This tests ensures reduction works in unbalanced logging settings, + even when a Callback also logs. + """ + + class TestModel(MultiOptModel): + + actual = {0: [], 1: []} + + def training_step(self, batch, batch_idx, optimizer_idx): + out = super().training_step(batch, batch_idx) + loss = out["loss"] + self.log(f"loss_{optimizer_idx}", loss, on_epoch=True) + self.actual[optimizer_idx].append(loss) + return out + + model = TestModel() + model.training_epoch_end = None + + class TestCallback(pl.Callback): + + def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_idx): + # when this is called, the EpochResultStore state has not been reset yet because we are still + # "INSIDE_BATCH_TRAIN_LOOP" and the LoggerConnector runs its `on_train_batch_end` after the + # Callback (see `TrainLoop.on_train_batch_end`). For this reason, opt_idx here is the index + # of the last optimizer updated (the second, index 1). This produced a KeyError as reported in #5459 + pl_module.log("test_train_batch_end", trainer.logger_connector.cached_results._opt_idx) + + # Initialize a trainer + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=5, + callbacks=[TestCallback()], + weights_summary=None, + ) + trainer.fit(model) + + for k, v in model.actual.items(): + assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1]) + # test loss is properly reduced + torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean()) + + assert trainer.callback_metrics["test_train_batch_end"] == len(model.optimizers()) - 1 + + +def test_multiple_optimizers(tmpdir): + + class TestModel(MultiOptModel): + + seen = [False, False] + + def training_step(self, batch, batch_idx, optimizer_idx): + self.seen[optimizer_idx] = True + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + model = TestModel() + model.val_dataloader = None + + trainer = pl.Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + assert all(model.seen) + + +def test_multiple_optimizers_manual(tmpdir): + + class TestModel(MultiOptModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + self.training_step_called = True + + # manual optimization + opt_a, opt_b = self.optimizers() + loss_1 = self.step(batch[0]) + + # fake generator + self.manual_backward(loss_1, opt_a) + opt_a.step() + opt_a.zero_grad() + + # fake discriminator + loss_2 = self.step(batch[0]) + self.manual_backward(loss_2, opt_b) + opt_b.step() + opt_b.zero_grad() + + def training_epoch_end(self, outputs) -> None: + # outputs should be an array with an entry per optimizer + assert len(outputs) == 2 + + model = TestModel() + model.val_dataloader = None + + trainer = pl.Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + assert model.training_step_called + + +def test_multiple_optimizers_no_opt_idx_argument(tmpdir): + """ + Test that an error is raised if no optimizer_idx is present when + multiple optimizeres are passed in case of automatic_optimization + """ + + class TestModel(MultiOptModel): + + def training_step(self, batch, batch_idx): + return super().training_step(batch, batch_idx) + + trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=2) + + with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'): + trainer.fit(TestModel()) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py new file mode 100644 index 00000000000000..f13448187364c4 --- /dev/null +++ b/tests/trainer/optimization/test_optimizers.py @@ -0,0 +1,478 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate +from tests.helpers.boring_model import BoringModel + + +def test_optimizer_with_scheduling(tmpdir): + """ Verify that learning rate scheduling is working """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + model.configure_optimizers = model.configure_optimizers__single_scheduler + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + val_check_interval=0.5, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + init_lr = hparams.get('learning_rate') + adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups] + + assert len(trainer.lr_schedulers) == 1, \ + 'lr scheduler not initialized properly, it has %i elements instread of 1' % len(trainer.lr_schedulers) + + assert all(a == adjusted_lr[0] for a in adjusted_lr), \ + 'Lr not equally adjusted for all param groups' + adjusted_lr = adjusted_lr[0] + + assert init_lr * 0.1 == adjusted_lr, \ + 'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr) + + +def test_multi_optimizer_with_scheduling_stepping(tmpdir): + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + model.configure_optimizers = model.configure_optimizers__multiple_schedulers + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + init_lr = hparams.get('learning_rate') + adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups] + adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups] + + assert len(trainer.lr_schedulers) == 2, 'all lr scheduler not initialized properly' + + assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \ + 'lr not equally adjusted for all param groups for optimizer 1' + adjusted_lr1 = adjusted_lr1[0] + + assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \ + 'lr not equally adjusted for all param groups for optimizer 2' + adjusted_lr2 = adjusted_lr2[0] + + # Called ones after end of epoch + assert init_lr * 0.1 == adjusted_lr1, 'lr for optimizer 1 not adjusted correctly' + # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times + assert init_lr * 0.1 == adjusted_lr2, 'lr for optimizer 2 not adjusted correctly' + + +def test_reducelronplateau_with_no_monitor_raises(tmpdir): + """ + Test exception when a ReduceLROnPlateau is used with no monitor + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: ([optimizer], [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)]) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises( + MisconfigurationException, match='`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`' + ): + trainer.fit(model) + + +def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir): + """ + Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), + }, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match='must include a monitor when a `ReduceLROnPlateau`'): + trainer.fit(model) + + +def test_reducelronplateau_scheduling(tmpdir): + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': 'val_acc', + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + lr_scheduler = trainer.lr_schedulers[0] + assert lr_scheduler == dict( + scheduler=lr_scheduler['scheduler'], + monitor='val_acc', + interval='epoch', + frequency=1, + reduce_on_plateau=True, + strict=True, + name=None, + ), 'lr scheduler was not correctly converted to dict' + + +def test_optimizer_return_options(): + trainer = Trainer() + model = EvalModelTemplate() + + # single optimizer + opt_a = torch.optim.Adam(model.parameters(), lr=0.002) + opt_b = torch.optim.SGD(model.parameters(), lr=0.002) + scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10) + scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10) + + # single optimizer + model.configure_optimizers = lambda: opt_a + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 1 and len(lr_sched) == len(freq) == 0 + + # opt tuple + model.configure_optimizers = lambda: (opt_a, opt_b) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert optim == [opt_a, opt_b] + assert len(lr_sched) == len(freq) == 0 + + # opt list + model.configure_optimizers = lambda: [opt_a, opt_b] + optim, lr_sched, freq = trainer.init_optimizers(model) + assert optim == [opt_a, opt_b] + assert len(lr_sched) == len(freq) == 0 + + ref_lr_sched = dict( + scheduler=scheduler_a, + interval='epoch', + frequency=1, + reduce_on_plateau=False, + monitor=None, + strict=True, + name=None, + ) + + # opt tuple of 2 lists + model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == len(lr_sched) == 1 + assert len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == ref_lr_sched + + # opt tuple of 1 list + model.configure_optimizers = lambda: ([opt_a], scheduler_a) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == len(lr_sched) == 1 + assert len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == ref_lr_sched + + # opt single dictionary + model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == len(lr_sched) == 1 + assert len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == ref_lr_sched + + # opt multiple dictionaries with frequencies + model.configure_optimizers = lambda: ( + { + "optimizer": opt_a, + "lr_scheduler": scheduler_a, + "frequency": 1 + }, + { + "optimizer": opt_b, + "lr_scheduler": scheduler_b, + "frequency": 5 + }, + ) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == len(lr_sched) == len(freq) == 2 + assert optim[0] == opt_a + assert lr_sched[0] == ref_lr_sched + assert freq == [1, 5] + + +def test_none_optimizer_warning(): + + trainer = Trainer() + + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__empty + + with pytest.warns(UserWarning, match='will run with no optimizer'): + _, __, ___ = trainer.init_optimizers(model) + + +def test_none_optimizer(tmpdir): + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + model.configure_optimizers = model.configure_optimizers__empty + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + trainer.fit(model) + + # verify training completed + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +def test_configure_optimizer_from_dict(tmpdir): + """Tests if `configure_optimizer` method could return a dictionary with `optimizer` field only.""" + + class CurrentModel(EvalModelTemplate): + + def configure_optimizers(self): + config = {'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)} + return config + + hparams = EvalModelTemplate.get_default_hparams() + model = CurrentModel(**hparams) + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +def test_configure_optimizers_with_frequency(tmpdir): + """ + Test that multiple optimizers work when corresponding frequency is set. + """ + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__multiple_optimizers_frequency + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_init_optimizers_during_evaluation(tmpdir, fn): + """ + Test that optimizers is an empty list during evaluation + """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + optimizer1 = torch.optim.Adam(self.parameters(), lr=0.1) + optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1) + lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=10, limit_test_batches=10) + validate_or_test = getattr(trainer, fn) + validate_or_test(TestModel(), ckpt_path=None) + + assert len(trainer.lr_schedulers) == 0 + assert len(trainer.optimizers) == 0 + assert len(trainer.optimizer_frequencies) == 0 + + +def test_multiple_optimizers_callbacks(tmpdir): + """ + Tests that multiple optimizers can be used with callbacks + """ + + class CB(Callback): + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + pass + + def on_train_epoch_start(self, trainer, pl_module): + pass + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer_1 = torch.nn.Linear(32, 2) + self.layer_2 = torch.nn.Linear(32, 2) + + def training_step(self, batch, batch_idx, optimizer_idx): + if optimizer_idx == 0: + a = batch[0] + acc = self.layer_1(a) + else: + a = batch[0] + acc = self.layer_2(a) + + acc = self.loss(acc, acc) + return acc + + def configure_optimizers(self): + a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2) + b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2) + return a, b + + model = TestModel() + model.training_epoch_end = None + trainer = Trainer( + callbacks=[CB()], + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + +def test_lr_scheduler_strict(tmpdir): + """ + Test "strict" support in lr_scheduler dict + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'monitor': 'giraffe', + 'strict': True + }, + } + with pytest.raises( + MisconfigurationException, + match=r'ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:', + ): + trainer.fit(model) + + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'monitor': 'giraffe', + 'strict': False, + }, + } + with pytest.warns( + RuntimeWarning, match=r'ReduceLROnPlateau conditioned on metric .* which is not available but strict' + ): + assert trainer.fit(model) + + +def test_unknown_configure_optimizers_raises(tmpdir): + """ + Test exception with an unsupported configure_optimizers return + """ + model = EvalModelTemplate() + model.configure_optimizers = lambda: 1 + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match="Unknown configuration for model optimizers"): + trainer.fit(model) + + +def test_lr_scheduler_with_unknown_interval_raises(tmpdir): + """ + Test exception when lr_scheduler dict has unknown interval param value + """ + model = BoringModel() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1), + 'interval': "incorrect_unknown_value" + }, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'): + trainer.fit(model) + + +def test_lr_scheduler_with_extra_keys_warns(tmpdir): + """ + Test warning when lr_scheduler dict has extra keys + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1), + 'foo': 1, + 'bar': 2, + }, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.warns(RuntimeWarning, match=r'Found unsupported keys in the lr scheduler dict: \[.+\]'): + trainer.fit(model) + + +def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir): + """ + Test exception when lr_scheduler dict has no scheduler + """ + model = EvalModelTemplate() + model.configure_optimizers = lambda: { + 'optimizer': torch.optim.Adam(model.parameters()), + 'lr_scheduler': {}, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match='The lr scheduler dict must have the key "scheduler"'): + trainer.fit(model) + + +def test_invalid_optimizer_in_scheduler(tmpdir): + """ + Test exception when optimizer attatched to lr_schedulers wasn't returned + """ + + class InvalidOptimizerModel(BoringModel): + + def configure_optimizers(self): + opt1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + opt2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(opt2, step_size=1) + return [opt1], [lr_scheduler] + + model = InvalidOptimizerModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match="attatched with an optimizer that wasn't returned"): + trainer.fit(model) diff --git a/tests/trainer/properties/__init__.py b/tests/trainer/properties/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/log_dir.py new file mode 100644 index 00000000000000..730e2a1512c23e --- /dev/null +++ b/tests/trainer/properties/log_dir.py @@ -0,0 +1,142 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from tests.helpers.boring_model import BoringModel + + +class TestModel(BoringModel): + + def __init__(self, expected_log_dir): + super().__init__() + self.expected_log_dir = expected_log_dir + + def training_step(self, *args, **kwargs): + assert self.trainer.log_dir == self.expected_log_dir + return super().training_step(*args, **kwargs) + + +def test_logdir(tmpdir): + """ + Tests that the path is correct when checkpoint and loggers are used + """ + expected = os.path.join(tmpdir, 'lightning_logs', 'version_0') + + model = TestModel(expected) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + ) + + assert trainer.log_dir == expected + trainer.fit(model) + assert trainer.log_dir == expected + + +def test_logdir_no_checkpoint_cb(tmpdir): + """ + Tests that the path is correct with no checkpoint + """ + expected = os.path.join(tmpdir, 'lightning_logs', 'version_0') + model = TestModel(expected) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + checkpoint_callback=False, + ) + + assert trainer.log_dir == expected + trainer.fit(model) + assert trainer.log_dir == expected + + +def test_logdir_no_logger(tmpdir): + """ + Tests that the path is correct even when there is no logger + """ + expected = os.path.join(tmpdir) + model = TestModel(expected) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + logger=False, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + ) + + assert trainer.log_dir == expected + trainer.fit(model) + assert trainer.log_dir == expected + + +def test_logdir_no_logger_no_checkpoint(tmpdir): + """ + Tests that the path is correct even when there is no logger + """ + expected = os.path.join(tmpdir) + model = TestModel(expected) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + logger=False, + checkpoint_callback=False, + ) + + assert trainer.log_dir == expected + trainer.fit(model) + assert trainer.log_dir == expected + + +def test_logdir_custom_callback(tmpdir): + """ + Tests that the path is correct even when there is a custom callback + """ + expected = os.path.join(tmpdir, 'lightning_logs', 'version_0') + model = TestModel(expected) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + callbacks=[ModelCheckpoint(dirpath=os.path.join(tmpdir, 'ckpts'))], + ) + + assert trainer.log_dir == expected + trainer.fit(model) + assert trainer.log_dir == expected + + +def test_logdir_custom_logger(tmpdir): + """ + Tests that the path is correct even when there is a custom logger + """ + expected = os.path.join(tmpdir, 'custom_logs', 'version_0') + model = TestModel(expected) + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + logger=TensorBoardLogger(save_dir=tmpdir, name='custom_logs') + ) + + assert trainer.log_dir == expected + trainer.fit(model) + assert trainer.log_dir == expected diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py new file mode 100644 index 00000000000000..5dc1ea5de4e8a0 --- /dev/null +++ b/tests/trainer/properties/test_get_model.py @@ -0,0 +1,82 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +class TrainerGetModel(BoringModel): + + def on_fit_start(self): + assert self == self.trainer.lightning_module + + def on_fit_end(self): + assert self == self.trainer.lightning_module + + +def test_get_model(tmpdir): + """ + Tests that `trainer.lightning_module` extracts the model correctly + """ + + model = TrainerGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + +@RunIf(skip_windows=True) +def test_get_model_ddp_cpu(tmpdir): + """ + Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu + """ + + model = TrainerGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + accelerator='ddp_cpu', + num_processes=2, + ) + trainer.fit(model) + + +@RunIf(min_gpus=1) +def test_get_model_gpu(tmpdir): + """ + Tests that `trainer.lightning_module` extracts the model correctly when using GPU + """ + + model = TrainerGetModel() + + limit_train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + gpus=1, + ) + trainer.fit(model) diff --git a/tests/trainer/test_checks.py b/tests/trainer/test_checks.py deleted file mode 100755 index 45155d67e65d79..00000000000000 --- a/tests/trainer/test_checks.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest - -import tests.base.utils as tutils -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate - -# TODO: add matching messages - - -def test_wrong_train_setting(tmpdir): - """ - * Test that an error is thrown when no `training_dataloader()` is defined - * Test that an error is thrown when no `training_step()` is defined - """ - tutils.reset_seed() - hparams = tutils.get_default_hparams() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.train_dataloader = None - trainer.fit(model) - - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.training_step = None - trainer.fit(model) - - -def test_wrong_configure_optimizers(tmpdir): - """ Test that an error is thrown when no `configure_optimizers()` is defined """ - tutils.reset_seed() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(tutils.get_default_hparams()) - model.configure_optimizers = None - trainer.fit(model) - - -def test_wrong_validation_settings(tmpdir): - """ Test the following cases related to validation configuration of model: - * error if `val_dataloader()` is overriden but `validation_step()` is not - * if both `val_dataloader()` and `validation_step()` is overriden, - throw warning if `val_epoch_end()` is not defined - * error if `validation_step()` is overriden but `val_dataloader()` is not - """ - tutils.reset_seed() - hparams = tutils.get_default_hparams() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - - # check val_dataloader -> val_step - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.validation_step = None - trainer.fit(model) - - # check val_dataloader + val_step -> val_epoch_end - with pytest.warns(RuntimeWarning): - model = EvalModelTemplate(hparams) - model.validation_epoch_end = None - trainer.fit(model) - - # check val_step -> val_dataloader - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.val_dataloader = None - trainer.fit(model) - - -def test_wrong_test_settigs(tmpdir): - """ Test the following cases related to test configuration of model: - * error if `test_dataloader()` is overriden but `test_step()` is not - * if both `test_dataloader()` and `test_step()` is overriden, - throw warning if `test_epoch_end()` is not defined - * error if `test_step()` is overriden but `test_dataloader()` is not - """ - hparams = tutils.get_default_hparams() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - - # ---------------- - # if have test_dataloader should have test_step - # ---------------- - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.test_step = None - trainer.fit(model) - - # ---------------- - # if have test_dataloader and test_step recommend test_epoch_end - # ---------------- - with pytest.warns(RuntimeWarning): - model = EvalModelTemplate(hparams) - model.test_epoch_end = None - trainer.test(model) - - # ---------------- - # if have test_step and NO test_dataloader passed in tell user to pass test_dataloader - # ---------------- - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.test_dataloader = lambda: None - trainer.test(model) - - # ---------------- - # if have test_dataloader and NO test_step tell user to implement test_step - # ---------------- - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.test_dataloader = lambda: None - model.test_step = None - trainer.test(model, test_dataloaders=model.dataloader(train=False)) - - # ---------------- - # if have test_dataloader and test_step but no test_epoch_end warn user - # ---------------- - with pytest.warns(RuntimeWarning): - model = EvalModelTemplate(hparams) - model.test_dataloader = lambda: None - model.test_epoch_end = None - trainer.test(model, test_dataloaders=model.dataloader(train=False)) diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py new file mode 100644 index 00000000000000..9fccd9b36440ae --- /dev/null +++ b/tests/trainer/test_config_validator.py @@ -0,0 +1,149 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset + + +def test_wrong_train_setting(tmpdir): + """ + * Test that an error is thrown when no `train_dataloader()` is defined + * Test that an error is thrown when no `training_step()` is defined + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): + model = BoringModel() + model.train_dataloader = None + trainer.fit(model) + + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): + model = BoringModel() + model.training_step = None + trainer.fit(model) + + +def test_wrong_configure_optimizers(tmpdir): + """ Test that an error is thrown when no `configure_optimizers()` is defined """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): + model = BoringModel() + model.configure_optimizers = None + trainer.fit(model) + + +def test_fit_val_loop_config(tmpdir): + """" + When either val loop or val data are missing raise warning + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # no val data has val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() + model.validation_step = None + trainer.fit(model) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = BoringModel() + model.val_dataloader = None + trainer.fit(model) + + +def test_test_loop_config(tmpdir): + """" + When either test loop or test data are missing + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has test loop but no test data + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): + model = BoringModel() + model.test_dataloader = None + trainer.test(model) + + # has test data but no test loop + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): + model = BoringModel() + model.test_step = None + trainer.test(model) + + +def test_val_loop_config(tmpdir): + """" + When either validation loop or validation data are missing + """ + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = BoringModel() + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = BoringModel() + model.validation_step = None + trainer.validate(model) + + +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict_dataloader(self): + return self._dataloaders + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) diff --git a/tests/trainer/test_correct_freq_accumulation.py b/tests/trainer/test_correct_freq_accumulation.py new file mode 100644 index 00000000000000..77cc80407270ff --- /dev/null +++ b/tests/trainer/test_correct_freq_accumulation.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests to ensure that the training loop works with a dict +""" +import os +from unittest import mock + +from pytorch_lightning import Trainer +from tests.base.model_template import EvalModelTemplate + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_training_step_scalar(tmpdir): + """ + Tests that only training_step can be used + """ + + model = EvalModelTemplate() + model.validation_step = None + model.test_step = None + model.training_step_end = None + model.training_epoch_end = None + model.validation_step = model.validation_step__dp + model.validation_step_end = None + model.validation_epoch_end = None + model.test_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + log_every_n_steps=1, + weights_summary=None, + ) + trainer.fit(model) + + # epoch 0 + assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 0 + assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 1 + assert trainer.dev_debugger.logged_metrics[2]['global_step'] == 1 + + # epoch 1 + assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2 + assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3 + assert trainer.dev_debugger.logged_metrics[5]['global_step'] == 3 diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py new file mode 100644 index 00000000000000..ec7f020faa4c31 --- /dev/null +++ b/tests/trainer/test_data_loading.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from torch.utils.data import DataLoader +from torch.utils.data.sampler import BatchSampler, SequentialSampler + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf + + +class IndexedRandomDataset(RandomDataset): + + def __getitem__(self, index): + return self.data[index] + + +class CustomDataLoader(DataLoader): + + def __init__(self, num_features, dataset, *args, **kwargs): + self.num_features = num_features + super().__init__(dataset, *args, **kwargs) + + +class FailureCustomDataLoader(DataLoader): + + def __init__(self, num_features, dataset, *args, **kwargs): + super().__init__(dataset, *args, **kwargs) + + +class CustomBatchSampler(BatchSampler): + pass + + +class TestModel(BoringModel): + + def __init__(self, numbers_test_dataloaders, save_preds_on_dl_idx, mode): + super().__init__() + self._numbers_test_dataloaders = numbers_test_dataloaders + self._save_preds_on_dl_idx = save_preds_on_dl_idx + self._mode = mode + + def test_step(self, batch, batch_idx, dataloader_idx=None): + return super().test_step(batch, batch_idx) + + def create_dataset(self): + dataset = IndexedRandomDataset(32, 64) + batch_sampler = None + batch_size = 2 + if self._mode == 2: + batch_size = 1 + batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True) + dataloader_cls = CustomDataLoader + else: + dataloader_cls = FailureCustomDataLoader + return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler) + + def test_dataloader(self): + return [self.create_dataset()] * self._numbers_test_dataloaders + + +def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode): + num_processes = 2 + limit_test_batches = 2 + trainer_args = { + "default_root_dir": tmpdir, + "limit_test_batches": limit_test_batches, + "accelerator": accelerator, + } + + if accelerator == "ddp_cpu": + trainer_args["num_processes"] = num_processes + else: + trainer_args["gpus"] = gpus + + model = TestModel(num_dl_idx, save_preds_on_dl_idx, mode) + model.test_epoch_end = None + + trainer = Trainer(**trainer_args) + if mode == 1: + match = "DistributedSampler within" + with pytest.raises(MisconfigurationException, match=match): + trainer.test(model) + else: + trainer.test(model) + + +@RunIf(min_gpus=2, special=True) +@pytest.mark.parametrize("mode", [1, 2]) +def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode): + check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 92704a9040a9ec..505af173b79108 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1,208 +1,470 @@ -import platform +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock +from unittest.mock import patch import pytest import torch from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Subset - -import tests.base.utils as tutils -from pytorch_lightning import Trainer +from torch.utils.data.dataset import IterableDataset, Subset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import SequentialSampler + +import tests.helpers.pipelines as tpipes +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.runif import RunIf + + +def test_fit_train_loader_only(tmpdir): + model = EvalModelTemplate() + train_dataloader = model.train_dataloader() + + model.train_dataloader = None + model.val_dataloader = None + model.test_dataloader = None + + model.validation_step = None + model.validation_epoch_end = None + + model.test_step = None + model.test_epoch_end = None + + trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(model, train_dataloader=train_dataloader) + + +def test_fit_val_loader_only(tmpdir): + model = EvalModelTemplate() + train_dataloader = model.train_dataloader() + val_dataloader = model.val_dataloader() + + model.train_dataloader = None + model.val_dataloader = None + model.test_dataloader = None + + model.test_step = None + model.test_epoch_end = None + + trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader) @pytest.mark.parametrize("dataloader_options", [ - dict(train_percent_check=-0.1), - dict(train_percent_check=1.1), - dict(val_check_interval=1.1), dict(val_check_interval=10000), ]) -def test_dataloader_config_errors(tmpdir, dataloader_options): - - model = EvalModelTemplate(tutils.get_default_hparams()) - - # fit model +def test_dataloader_config_errors_runtime(tmpdir, dataloader_options): + model = EvalModelTemplate() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, **dataloader_options, ) - with pytest.raises(ValueError): + # fit model trainer.fit(model) +@pytest.mark.parametrize( + "dataloader_options", [ + dict(limit_train_batches=-0.1), + dict(limit_train_batches=1.2), + dict(limit_val_batches=-0.1), + dict(limit_val_batches=1.2), + dict(limit_test_batches=-0.1), + dict(limit_test_batches=1.2), + dict(val_check_interval=-0.1), + dict(val_check_interval=1.2), + dict(overfit_batches=-0.1), + dict(overfit_batches=1.2), + ] +) +def test_dataloader_config_errors_init(tmpdir, dataloader_options): + with pytest.raises(MisconfigurationException, match='passed invalid value'): + Trainer( + default_root_dir=tmpdir, + max_epochs=1, + **dataloader_options, + ) + + def test_multiple_val_dataloader(tmpdir): """Verify multiple val_dataloader.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=1.0, + limit_val_batches=0.1, + limit_train_batches=1.0, ) - result = trainer.fit(model) + trainer.fit(model) # verify training completed - assert result == 1 + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # verify there are 2 val loaders - assert len(trainer.val_dataloaders) == 2, \ - 'Multiple val_dataloaders not initiated properly' + assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly' # make sure predictions are good for each val set for dataloader in trainer.val_dataloaders: - tutils.run_prediction(dataloader, trainer.model) + tpipes.run_prediction_eval_model_template(trained_model=model, dataloader=dataloader) -def test_multiple_test_dataloader(tmpdir): - """Verify multiple test_dataloader.""" +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +def test_multiple_eval_dataloader(tmpdir, ckpt_path): + """Verify multiple evaluation dataloaders.""" - model = EvalModelTemplate(tutils.get_default_hparams()) - model.test_dataloader = model.test_dataloader__multiple - model.test_step = model.test_step__multiple_dataloaders + class MultipleTestDataloaderModel(EvalModelTemplate): + + def test_dataloader(self): + return [self.dataloader(train=False), self.dataloader(train=False)] + + def test_step(self, *args, **kwargs): + return super().test_step__multiple_dataloaders(*args, **kwargs) + + def val_dataloader(self): + return self.test_dataloader() + + def validation_step(self, *args, **kwargs): + output = self.test_step(*args, **kwargs) + return {k.replace("test_", "val_"): v for k, v in output.items()} + + model = MultipleTestDataloaderModel() # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_val_batches=10, + limit_train_batches=100, ) trainer.fit(model) - trainer.test() + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path - # verify there are 2 test loaders - assert len(trainer.test_dataloaders) == 2, \ - 'Multiple test_dataloaders not initiated properly' + trainer.validate(ckpt_path=ckpt_path, verbose=False) + # verify there are 2 loaders + assert len(trainer.val_dataloaders) == 2 + # make sure predictions are good for each dl + for dataloader in trainer.val_dataloaders: + tpipes.run_prediction_eval_model_template(trainer.model, dataloader) - # make sure predictions are good for each test set + trainer.test(ckpt_path=ckpt_path, verbose=False) + assert len(trainer.test_dataloaders) == 2 for dataloader in trainer.test_dataloaders: - tutils.run_prediction(dataloader, trainer.model) - - # run the test method - trainer.test() + tpipes.run_prediction_eval_model_template(trainer.model, dataloader) def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ # only train passed to fit - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_val_batches=0.1, + limit_train_batches=0.2, ) fit_options = dict(train_dataloader=model.dataloader(train=True)) - result = trainer.fit(model, **fit_options) + trainer.fit(model, **fit_options) + + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - assert result == 1 +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@pytest.mark.parametrize("n", (1, 2)) +def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): + """Verify that dataloaders can be passed.""" -def test_train_val_dataloaders_passed_to_fit(tmpdir): - """ Verify that train & val dataloader can be passed to fit """ + model = EvalModelTemplate() + if n == 1: + dataloaders = model.dataloader(train=False) + else: + dataloaders = [model.dataloader(train=False)] * 2 + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders - # train, val passed to fit - model = EvalModelTemplate(tutils.get_default_hparams()) + # train, multiple val and multiple test passed to fit trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_val_batches=0.1, + limit_train_batches=0.2, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), - val_dataloaders=model.dataloader(train=False)) + trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) - result = trainer.fit(model, **fit_options) - assert result == 1 - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert len(trainer.val_dataloaders) == n + + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path) + trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) -def test_all_dataloaders_passed_to_fit(tmpdir): - """Verify train, val & test dataloader(s) can be passed to fit and test method""" + assert len(trainer.val_dataloaders) == n + assert len(trainer.test_dataloaders) == n - model = EvalModelTemplate(tutils.get_default_hparams()) - # train, val and test passed to fit +@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (0.0, 0.0, 0.0), + (1.0, 1.0, 1.0), +]) +def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" + model = EvalModelTemplate() + model.train_dataloader = model.train_dataloader__infinite + model.val_dataloader = model.val_dataloader__infinite + model.test_dataloader = model.test_dataloader__infinite + trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + limit_test_batches=limit_test_batches, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), - val_dataloaders=model.dataloader(train=False)) - test_options = dict(test_dataloaders=model.dataloader(train=False)) - result = trainer.fit(model, **fit_options) - trainer.test(**test_options) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf')) + assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf')) - assert result == 1 - assert len(trainer.val_dataloaders) == 1, \ - f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 1, \ - f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' + trainer.test(ckpt_path=None) + assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf')) + + +@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (0, 0, 0), + (10, 10, 10), +]) +def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + model = EvalModelTemplate() + model.train_dataloader = model.train_dataloader__infinite + model.val_dataloader = model.val_dataloader__infinite + model.test_dataloader = model.test_dataloader__infinite + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + limit_test_batches=limit_test_batches, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == limit_train_batches + assert trainer.num_val_batches[0] == limit_val_batches + + trainer.test(ckpt_path=None) + assert trainer.num_test_batches[0] == limit_test_batches -def test_multiple_dataloaders_passed_to_fit(tmpdir): - """Verify that multiple val & test dataloaders can be passed to fit.""" - model = EvalModelTemplate(tutils.get_default_hparams()) +@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (0.0, 0.0, 0.0), + (0, 0, 0.5), + (1.0, 1.0, 1.0), + (0.2, 0.4, 0.4), +]) +def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): + """Verify num_batches for train, val & test dataloaders passed with batch limit in percent""" + model = EvalModelTemplate() + model.val_dataloader = model.val_dataloader__multiple_mixed_length + model.test_dataloader = model.test_dataloader__multiple_mixed_length model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders - # train, multiple val and multiple test passed to fit + # train, multiple val and multiple test passed with percent_check trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + limit_test_batches=limit_test_batches, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), - val_dataloaders=[model.dataloader(train=False), - model.dataloader(train=False)]) - test_options = dict(test_dataloaders=[model.dataloader(train=False), - model.dataloader(train=False)]) + trainer.fit(model) + expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches) + expected_val_batches = [int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders] + assert trainer.num_training_batches == expected_train_batches + assert trainer.num_val_batches == expected_val_batches + + trainer.test(ckpt_path=None) + expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders] + assert trainer.num_test_batches == expected_test_batches - trainer.fit(model, **fit_options) - trainer.test(**test_options) - assert len(trainer.val_dataloaders) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' +@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [ + (0, 0, 0), + (1, 2, 3), + (1, 2, 1e50), +]) +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): + """Verify num_batches for train, val & test dataloaders passed with batch limit as number""" + model = EvalModelTemplate() + model.val_dataloader = model.val_dataloader__multiple_mixed_length + model.test_dataloader = model.test_dataloader__multiple_mixed_length + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders + + # train, multiple val and multiple test passed with percent_check + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + limit_test_batches=limit_test_batches, + ) + trainer.fit(model) + + # ------------------------------------------- + # MAKE SURE THE TRAINER SET THE CORRECT VALUES + # ------------------------------------------- + assert trainer.num_training_batches == limit_train_batches + assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders) + trainer.test(ckpt_path=None) + + # when the limit is greater than the number of test batches it should be the num in loaders + test_dataloader_lengths = [len(x) for x in model.test_dataloader()] + if limit_test_batches > 1e10: + assert trainer.num_test_batches == test_dataloader_lengths + else: + assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders) + + # ------------------------------------------- + # make sure we actually saw the expected num of batches + # ------------------------------------------- + num_val_dataloaders = len(model.val_dataloader()) + num_test_dataloaders = len(model.test_dataloader()) + if limit_train_batches > 0: + + # make sure val batches are as expected + assert len(trainer.dev_debugger.num_seen_val_check_batches) == num_val_dataloaders + for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_val_check_batches.items(): + assert num_batches == limit_val_batches + + # make sure test batches are as expected + assert len(trainer.dev_debugger.num_seen_test_check_batches) == num_test_dataloaders + for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_test_check_batches.items(): + if limit_test_batches > 1e10: + assert num_batches == test_dataloader_lengths[dataloader_idx] + else: + assert num_batches == limit_test_batches + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize('fast_dev_run', [True, 1, 3, -1, 'temp']) +def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): + """ + Verify num_batches for train, val & test dataloaders passed with fast_dev_run + """ + model = EvalModelTemplate() + model.val_dataloader = model.val_dataloader__multiple_mixed_length + model.test_dataloader = model.test_dataloader__multiple_mixed_length + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders -def test_mixing_of_dataloader_options(tmpdir): + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=2, + fast_dev_run=fast_dev_run, + ) + + if fast_dev_run == 'temp': + with pytest.raises(MisconfigurationException, match='either a bool or an int'): + Trainer(**trainer_options) + elif fast_dev_run == -1: + with pytest.raises(MisconfigurationException, match='should be >= 0'): + Trainer(**trainer_options) + else: + trainer = Trainer(**trainer_options) + + # fast_dev_run is set to True when it is 1 + if fast_dev_run == 1: + fast_dev_run = True + + assert trainer.fast_dev_run is fast_dev_run + + if fast_dev_run is True: + fast_dev_run = 1 + + assert trainer.limit_train_batches == fast_dev_run + assert trainer.limit_val_batches == fast_dev_run + assert trainer.limit_test_batches == fast_dev_run + assert trainer.num_sanity_val_steps == 0 + assert trainer.max_epochs == 1 + + trainer.fit(model) + assert not trainer.disable_validation + assert trainer.num_training_batches == fast_dev_run + assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders) + + trainer.test(ckpt_path=None) + assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders) + + # verify sanity check batches match as expected + num_val_dataloaders = len(model.val_dataloader()) + assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders + + +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +def test_mixing_of_dataloader_options(tmpdir, ckpt_path): """Verify that dataloaders can be passed to fit""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() trainer_options = dict( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_val_batches=0.1, + limit_train_batches=0.2, ) # fit model trainer = Trainer(**trainer_options) - results = trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert results + trainer.fit(model, val_dataloaders=model.dataloader(train=False)) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # fit model trainer = Trainer(**trainer_options) - results = trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert results - trainer.test(test_dataloaders=model.dataloader(train=False)) + trainer.fit(model, val_dataloaders=model.dataloader(train=False)) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' @@ -212,34 +474,34 @@ def test_mixing_of_dataloader_options(tmpdir): def test_train_inf_dataloader_error(tmpdir): """Test inf train data loader (e.g. IterableDataset)""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.train_dataloader = model.train_dataloader__infinite trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5) - with pytest.raises(MisconfigurationException, match='infinite DataLoader'): + with pytest.raises(MisconfigurationException, match='using an IterableDataset'): trainer.fit(model) def test_val_inf_dataloader_error(tmpdir): """Test inf train data loader (e.g. IterableDataset)""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__infinite - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5) - with pytest.raises(MisconfigurationException, match='infinite DataLoader'): + with pytest.raises(MisconfigurationException, match='using an IterableDataset'): trainer.fit(model) def test_test_inf_dataloader_error(tmpdir): """Test inf train data loader (e.g. IterableDataset)""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.test_dataloader = model.test_dataloader__infinite - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, test_percent_check=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5) - with pytest.raises(MisconfigurationException, match='infinite DataLoader'): + with pytest.raises(MisconfigurationException, match='using an IterableDataset'): trainer.test(model) @@ -247,24 +509,24 @@ def test_test_inf_dataloader_error(tmpdir): def test_inf_train_dataloader(tmpdir, check_interval): """Test inf train data loader (e.g. IterableDataset)""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.train_dataloader = model.train_dataloader__infinite trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - train_check_interval=check_interval, + val_check_interval=check_interval, ) - result = trainer.fit(model) + trainer.fit(model) # verify training completed - assert result == 1 + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" @pytest.mark.parametrize('check_interval', [1.0]) def test_inf_val_dataloader(tmpdir, check_interval): """Test inf val data loader (e.g. IterableDataset)""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__infinite # logger file to get meta @@ -273,35 +535,16 @@ def test_inf_val_dataloader(tmpdir, check_interval): max_epochs=1, val_check_interval=check_interval, ) - result = trainer.fit(model) - - # verify training completed - assert result == 1 - - -@pytest.mark.parametrize('check_interval', [50, 1.0]) -def test_inf_test_dataloader(tmpdir, check_interval): - """Test inf test data loader (e.g. IterableDataset)""" - - model = EvalModelTemplate(tutils.get_default_hparams()) - model.test_dataloader = model.test_dataloader__infinite - - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - test_check_interval=check_interval, - ) - result = trainer.fit(model) + trainer.fit(model) # verify training completed - assert result == 1 + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" def test_error_on_zero_len_dataloader(tmpdir): """ Test that error is raised if a zero-length dataloader is defined """ - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.train_dataloader = model.train_dataloader__zero_length # fit model @@ -309,76 +552,242 @@ def test_error_on_zero_len_dataloader(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - test_percent_check=0.5 + limit_train_batches=0.1, + limit_val_batches=0.1, + limit_test_batches=0.1, ) trainer.fit(model) -@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') -def test_warning_with_few_workers(tmpdir): +@RunIf(skip_windows=True) +@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific')) +@pytest.mark.parametrize('stage', ('train', 'test', 'val')) +@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) +def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): """ Test that error is raised if dataloader with only a few workers is used """ - model = EvalModelTemplate(tutils.get_default_hparams()) + model = BoringModel() - # logger file to get meta - trainer_options = dict( + train_dl = model.train_dataloader() + train_dl.num_workers = 0 + + val_dl = model.val_dataloader() + val_dl.num_workers = 0 + + trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 + limit_val_batches=0.1, + limit_train_batches=0.2, ) - fit_options = dict(train_dataloader=model.dataloader(train=True), - val_dataloaders=model.dataloader(train=False)) - test_options = dict(test_dataloaders=model.dataloader(train=False)) + with pytest.warns( + UserWarning, + match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' + ): + if stage == 'test': + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path + trainer.test(model, test_dataloaders=train_dl, ckpt_path=ckpt_path) + else: + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize('ckpt_path', (None, 'best', 'specific')) +@pytest.mark.parametrize('stage', ('train', 'test', 'val')) +@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) +def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): + """ Test that error is raised if dataloader with only a few workers is used """ - trainer = Trainer(**trainer_options) + model = EvalModelTemplate() + model.training_step = model.training_step__multiple_dataloaders + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders - # fit model - with pytest.warns(UserWarning, match='train'): - trainer.fit(model, **fit_options) + val_dl = model.dataloader(train=False) + val_dl.num_workers = 0 + + train_dl = model.dataloader(train=False) + train_dl.num_workers = 0 + + train_multi_dl = {'a': train_dl, 'b': train_dl} + val_multi_dl = [val_dl, val_dl] + test_multi_dl = [train_dl, train_dl] + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + + with pytest.warns( + UserWarning, + match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers' + ): + if stage == 'test': + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path + trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) + else: + trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) - with pytest.warns(UserWarning, match='val'): - trainer.fit(model, **fit_options) - with pytest.warns(UserWarning, match='test'): - trainer.test(**test_options) +def test_warning_with_iterable_dataset_and_len(tmpdir): + """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ + model = EvalModelTemplate() + original_dataset = model.train_dataloader().dataset + class IterableWithLen(IterableDataset): -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') -def test_dataloader_reinit_for_subclass(): + def __iter__(self): + return iter(original_dataset) + + def __len__(self): + return len(original_dataset) + + dataloader = DataLoader(IterableWithLen(), batch_size=16) + assert has_len(dataloader) + assert has_iterable_dataset(dataloader) + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=3, + ) + with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): + trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) + with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): + trainer.test(model, test_dataloaders=[dataloader]) + + +@RunIf(min_gpus=2) +def test_dataloader_reinit_for_subclass(tmpdir): class CustomDataLoader(torch.utils.data.DataLoader): - def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, - batch_sampler=None, num_workers=0, collate_fn=None, - pin_memory=False, drop_last=False, timeout=0, - worker_init_fn=None, dummy_kwarg=None): - super().__init__(dataset, batch_size, shuffle, sampler, batch_sampler, - num_workers, collate_fn, pin_memory, drop_last, timeout, - worker_init_fn) + + def __init__( + self, + dataset, + batch_size=1, + shuffle=False, + sampler=None, + batch_sampler=None, + num_workers=0, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + dummy_kwarg=None, + **kwargs + ): + super().__init__( + dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, + timeout, worker_init_fn + ) self.dummy_kwarg = dummy_kwarg trainer = Trainer( gpus=[0, 1], num_nodes=1, - distributed_backend='ddp', + accelerator='ddp_spawn', + default_root_dir=tmpdir, ) class CustomDummyObj: sampler = None - result = trainer.auto_add_sampler(CustomDummyObj(), train=True) + result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True) assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader" - result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True) + dataset = list(range(1000)) + result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True) + assert isinstance(result, torch.utils.data.DataLoader) + assert isinstance(result, CustomDataLoader) + assert hasattr(result, 'dummy_kwarg') + + # Shuffled DataLoader should also work + result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True) assert isinstance(result, torch.utils.data.DataLoader) assert isinstance(result, CustomDataLoader) assert hasattr(result, 'dummy_kwarg') + class CustomSampler(torch.utils.data.Sampler): + pass + + # Should raise an error if existing sampler is being replaced + with pytest.raises(MisconfigurationException, match='DistributedSampler'): + trainer.auto_add_sampler( + CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True + ) + + +class DistribSamplerCallback(Callback): -@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs') -def test_batch_size_smaller_than_num_gpus(): + def on_train_start(self, trainer, pl_module): + train_sampler = trainer.train_dataloader.sampler + assert isinstance(train_sampler, DistributedSampler) + assert train_sampler.shuffle + + def on_validation_start(self, trainer, pl_module): + val_sampler = trainer.val_dataloaders[0].sampler + assert isinstance(val_sampler, DistributedSampler) + assert not val_sampler.shuffle + + def on_test_start(self, trainer, pl_module): + test_sampler = trainer.test_dataloaders[0].sampler + assert isinstance(test_sampler, DistributedSampler) + assert not test_sampler.shuffle + + +@RunIf(min_gpus=2, skip_windows=True) +def test_dataloader_distributed_sampler(tmpdir): + """ Test DistributedSampler and it's arguments for DDP backend """ + + model = EvalModelTemplate() + trainer = Trainer( + gpus=[0, 1], + num_nodes=1, + accelerator='ddp_spawn', + default_root_dir=tmpdir, + max_steps=1, + callbacks=[DistribSamplerCallback()], + ) + trainer.fit(model) + trainer.test(ckpt_path=None) + + +class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): + + def train_dataloader(self): + dataloader = super().train_dataloader() + dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True) + return DataLoader( + dataloader.dataset, batch_size=self.batch_size, drop_last=False, sampler=dist_sampler, shuffle=False + ) + + +@RunIf(min_gpus=2, skip_windows=True) +def test_dataloader_distributed_sampler_already_attached(tmpdir): + """ Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """ + + model = ModelWithDataLoaderDistributedSampler() + trainer = Trainer( + gpus=[0, 1], + num_nodes=1, + accelerator='ddp_spawn', + default_root_dir=tmpdir, + max_steps=100, + callbacks=[DistribSamplerCallback()], + replace_sampler_ddp=True, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, "DDP Training failed" + + +@RunIf(min_gpus=3) +def test_batch_size_smaller_than_num_gpus(tmpdir): # we need at least 3 gpus for this test num_gpus = 3 batch_size = 3 @@ -406,22 +815,415 @@ def train_dataloader(self): dataset = Subset(dataloader.dataset, range(size)) dataloader = DataLoader( dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, drop_last=False, ) return dataloader - hparams = tutils.get_default_hparams() - hparams.batch_size = batch_size - model = CurrentTestModel(hparams) + hparams = EvalModelTemplate.get_default_hparams() + hparams['batch_size'] = batch_size + model = CurrentTestModel(**hparams) trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0, + limit_train_batches=0.1, + limit_val_batches=0, gpus=num_gpus, ) # we expect the reduction for the metrics also to happen on the last batch # where we will get fewer metrics than gpus - result = trainer.fit(model) - assert 1 == result + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ + pytest.param("min_size", 5), + pytest.param("max_size_cycle", 10), +]) +def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches): + """Integration test for multple train loaders""" + model = EvalModelTemplate() + + model.train_dataloader = model.train_dataloader__multiple_mapping + # todo: add also `train_dataloader__multiple_sequence` + model.training_step = model.training_step__multiple_dataloaders + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + multiple_trainloader_mode=multiple_trainloader_mode, + ) + + assert 1 == trainer.fit(model) + # verify the num_training_batches according to the multiple_trainloader_mode + assert num_training_batches == trainer.num_training_batches + + +@pytest.mark.parametrize('check_interval', [1.0]) +def test_val_dataloader_not_implemented_error(tmpdir, check_interval): + """Test not_implemented_error data loader (e.g. IterableDataset)""" + + model = EvalModelTemplate() + model.val_dataloader = model.val_dataloader__not_implemented_error + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=5, + max_epochs=1, + val_check_interval=check_interval, + ) + trainer.fit(model) + # verify training completed + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@pytest.mark.parametrize('check_interval', [50, 1.0]) +def test_train_dataloader_not_implemented_error(tmpdir, check_interval): + """Test not_implemented_error train data loader (e.g. IterableDataset)""" + + model = EvalModelTemplate() + model.train_dataloader = model.train_dataloader__not_implemented_error + model.val_dataloader = model.val_dataloader__not_implemented_error + + trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=check_interval) + trainer.fit(model) + # verify training completed + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +def test_train_dataloader_not_implemented_error_failed(tmpdir): + """Test not_implemented_error train data loader (e.g. IterableDataset)""" + model = EvalModelTemplate() + model.train_dataloader = model.train_dataloader__not_implemented_error + + trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5) + + with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + trainer.fit(model) + + +def test_val_dataloader_not_implemented_error_failed(tmpdir): + """Test not_implemented_error train data loader (e.g. IterableDataset)""" + model = EvalModelTemplate() + model.val_dataloader = model.val_dataloader__not_implemented_error + + trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5) + + with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + trainer.fit(model) + + +def test_test_dataloader_not_implemented_error_failed(tmpdir): + """Test not_implemented_error train data loader (e.g. IterableDataset)""" + model = EvalModelTemplate() + model.test_dataloader = model.test_dataloader__not_implemented_error + + trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5) + + with pytest.raises(MisconfigurationException, match='using an IterableDataset'): + trainer.test(model) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_load_only_once(tmpdir): + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0.3, + limit_val_batches=0.3, + max_epochs=3, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert len(trainer.dev_debugger.val_dataloader_calls) == 1 + assert len(trainer.dev_debugger.test_dataloader_calls) == 0 + assert len(trainer.dev_debugger.train_dataloader_calls) == 1 + + # verify the sequence + calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ + 'val_dataloader', + 'train_dataloader', + ] + for call, expected in zip(calls, expected_sequence): + assert call['name'] == expected + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_load_only_once_val_interval(tmpdir): + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=10, + limit_val_batches=10, + val_check_interval=0.3, + reload_dataloaders_every_epoch=True, + max_epochs=3, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + trainer.test() + + assert len(trainer.dev_debugger.val_dataloader_calls) == 10 + assert len(trainer.dev_debugger.test_dataloader_calls) == 1 + assert len(trainer.dev_debugger.train_dataloader_calls) == 3 + + # verify the sequence + calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'val_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'val_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'val_dataloader', + 'val_dataloader', + 'test_dataloader', + ] + for call, expected in zip(calls, expected_sequence): + assert call['name'] == expected + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_load_only_once_no_sanity_check(tmpdir): + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0.3, + limit_val_batches=0.3, + num_sanity_val_steps=0, + max_epochs=3, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert len(trainer.dev_debugger.val_dataloader_calls) == 1 + assert len(trainer.dev_debugger.test_dataloader_calls) == 0 + assert len(trainer.dev_debugger.train_dataloader_calls) == 1 + + # verify the sequence + calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ + 'train_dataloader', + 'val_dataloader', + ] + for call, expected in zip(calls, expected_sequence): + assert call['name'] == expected + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_load_every_epoch(tmpdir): + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0.3, + limit_val_batches=0.3, + reload_dataloaders_every_epoch=True, + max_epochs=3, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + trainer.test() + + assert len(trainer.dev_debugger.val_dataloader_calls) == 4 + assert len(trainer.dev_debugger.train_dataloader_calls) == 3 + assert len(trainer.dev_debugger.test_dataloader_calls) == 1 + + # verify the sequence + calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'test_dataloader', + ] + for call, expected in zip(calls, expected_sequence): + assert call['name'] == expected + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir): + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0.3, + limit_val_batches=0.3, + num_sanity_val_steps=0, + reload_dataloaders_every_epoch=True, + max_epochs=3, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + trainer.test() + + assert len(trainer.dev_debugger.val_dataloader_calls) == 3 + assert len(trainer.dev_debugger.train_dataloader_calls) == 3 + assert len(trainer.dev_debugger.test_dataloader_calls) == 1 + + # verify the sequence + calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ + 'train_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'test_dataloader', + ] + for call, expected in zip(calls, expected_sequence): + assert call['name'] == expected + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_dataloaders_load_only_once_passed_loaders(tmpdir): + + model = EvalModelTemplate() + train_loader = model.train_dataloader() + model.train_dataloader = None + val_loader = model.val_dataloader() + model.val_dataloader = None + test_loader = model.test_dataloader() + model.test_dataloader = None + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0.3, + limit_val_batches=0.3, + max_epochs=3, + ) + trainer.fit(model, train_loader, val_loader) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + trainer.test(test_dataloaders=test_loader) + + assert len(trainer.dev_debugger.val_dataloader_calls) == 1 + assert len(trainer.dev_debugger.test_dataloader_calls) == 1 + assert len(trainer.dev_debugger.train_dataloader_calls) == 1 + + # verify the sequence + calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ + 'val_dataloader', + 'train_dataloader', + ] + for call, expected in zip(calls, expected_sequence): + assert call['name'] == expected + + +def test_replace_sampler_with_multiprocessing_context(tmpdir): + """ + This test verifies that replace_sampler conserves multiprocessing context + """ + train = RandomDataset(32, 64) + context = 'spawn' + train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True) + trainer = Trainer( + max_epochs=1, + progress_bar_refresh_rate=20, + overfit_batches=5, + ) + + new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset)) + assert (new_data_loader.multiprocessing_context == train.multiprocessing_context) + + +def test_request_dataloader(tmpdir): + """ + This test asserts dataloader can be modified and properly set to the trainer. + """ + + class DataLoaderWrapper: + + def __init__(self, loader): + self.loader = loader + self._iter = iter(self.loader) + + def __iter__(self): + self._iter = iter(self.loader) + return self._iter + + def __next__(self): + return next(self._iter) + + class DataLoaderFunc: + + def __init__(self, loader): + self.loader = loader + + def __call__(self): + return self.loader + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_train_dataloader_called = False + self.on_train_batch_start_called = False + self.on_val_dataloader_called = False + self.on_val_batch_start_called = False + + def on_train_dataloader(self) -> None: + loader = self.train_dataloader() + self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_train_dataloader_called = True + + def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) + self.on_train_batch_start_called = True + + def on_val_dataloader(self) -> None: + loader = self.val_dataloader() + self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_val_dataloader_called = True + + def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper) + self.on_val_batch_start_called = True + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + model = TestModel() + trainer.fit(model) + trainer.test(model) + assert model.on_train_dataloader_called + assert model.on_train_batch_start_called + assert model.on_val_dataloader_called + assert model.on_val_batch_start_called diff --git a/tests/trainer/test_evaluation_loop.py b/tests/trainer/test_evaluation_loop.py new file mode 100644 index 00000000000000..3fe58afde73417 --- /dev/null +++ b/tests/trainer/test_evaluation_loop.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook") +def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir): + """ + Tests that `call_on_evaluation_epoch_end_hook` is called + for `on_validation_epoch_end` and `on_test_epoch_end` hooks + """ + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + weights_summary=None, + ) + + trainer.fit(model) + # sanity + 2 epochs + assert eval_epoch_end_mock.call_count == 3 + + trainer.test() + # sanity + 2 epochs + called once for test + assert eval_epoch_end_mock.call_count == 4 diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py old mode 100755 new mode 100644 index ce9d3d3b1b0f3d..44510eb16184df --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -1,42 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy + import pytest import torch -import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.helpers import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.simple_models import ClassificationModel def test_error_on_more_than_1_optimizer(tmpdir): """ Check that error is thrown when more than 1 optimizer is passed """ - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.configure_optimizers = model.configure_optimizers__multiple_schedulers # logger file to get meta trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1 + default_root_dir=tmpdir, + max_epochs=1, ) with pytest.raises(MisconfigurationException): - trainer.lr_find(model) + trainer.tuner.lr_find(model) def test_model_reset_correctly(tmpdir): """ Check that model weights are correctly reset after lr_find() """ - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() # logger file to get meta trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1 + default_root_dir=tmpdir, + max_epochs=1, ) - before_state_dict = model.state_dict() + before_state_dict = deepcopy(model.state_dict()) - _ = trainer.lr_find(model, num_training=5) + _ = trainer.tuner.lr_find(model, num_training=5) after_state_dict = model.state_dict() @@ -44,27 +62,28 @@ def test_model_reset_correctly(tmpdir): assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \ 'Model was not reset correctly after learning rate finder' + assert not os.path.exists(tmpdir / 'lr_find_temp_model.ckpt') + def test_trainer_reset_correctly(tmpdir): """ Check that all trainer parameters are reset correctly after lr_find() """ - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() # logger file to get meta trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1 + default_root_dir=tmpdir, + max_epochs=1, ) - changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find', - 'progress_bar_refresh_rate', 'early_stop_callback', - 'accumulate_grad_batches', 'enable_early_stop', - 'checkpoint_callback'] + changed_attributes = [ + 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' + ] attributes_before = {} for ca in changed_attributes: attributes_before[ca] = getattr(trainer, ca) - _ = trainer.lr_find(model, num_training=5) + _ = trainer.tuner.lr_find(model, num_training=5) attributes_after = {} for ca in changed_attributes: @@ -74,62 +93,205 @@ def test_trainer_reset_correctly(tmpdir): assert attributes_before[key] == attributes_after[key], \ f'Attribute {key} was not reset correctly after learning rate finder' + assert model.trainer == trainer -def test_trainer_arg_bool(tmpdir): - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - before_lr = hparams.learning_rate +@pytest.mark.parametrize('use_hparams', [False, True]) +def test_trainer_arg_bool(tmpdir, use_hparams): + """ Test that setting trainer arg to bool works """ + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + before_lr = hparams.get('learning_rate') + if use_hparams: + del model.learning_rate + model.configure_optimizers = model.configure_optimizers__lr_from_hparams # logger file to get meta trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1, - auto_lr_find=True + default_root_dir=tmpdir, + max_epochs=2, + auto_lr_find=True, ) - trainer.fit(model) - after_lr = model.hparams.learning_rate + trainer.tune(model) + if use_hparams: + after_lr = model.hparams.learning_rate + else: + after_lr = model.learning_rate + + assert before_lr != after_lr, \ + 'Learning rate was not altered after running learning rate finder' + + +@pytest.mark.parametrize('use_hparams', [False, True]) +def test_trainer_arg_str(tmpdir, use_hparams): + """ Test that setting trainer arg to string works """ + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + model.my_fancy_lr = 1.0 # update with non-standard field + model.hparams['my_fancy_lr'] = 1.0 + before_lr = model.my_fancy_lr + if use_hparams: + del model.my_fancy_lr + model.configure_optimizers = model.configure_optimizers__lr_from_hparams + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + auto_lr_find='my_fancy_lr', + ) + + trainer.tune(model) + if use_hparams: + after_lr = model.hparams.my_fancy_lr + else: + after_lr = model.my_fancy_lr + assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' -def test_trainer_arg_str(tmpdir): +@pytest.mark.parametrize('optimizer', ['Adam', 'Adagrad']) +def test_call_to_trainer_method(tmpdir, optimizer): + """ Test that directly calling the trainer method works """ - hparams = tutils.get_default_hparams() - hparams.__dict__['my_fancy_lr'] = 1.0 # update with non-standard field - model = EvalModelTemplate(hparams) + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + if optimizer == 'adagrad': + model.configure_optimizers = model.configure_optimizers__adagrad - before_lr = hparams.my_fancy_lr + before_lr = hparams.get('learning_rate') # logger file to get meta trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1, - auto_lr_find='my_fancy_lr' + default_root_dir=tmpdir, + max_epochs=2, ) - trainer.fit(model) - after_lr = model.hparams.my_fancy_lr + lrfinder = trainer.tuner.lr_find(model, mode='linear') + after_lr = lrfinder.suggestion() + model.learning_rate = after_lr + trainer.tune(model) + assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' -def test_call_to_trainer_method(tmpdir): +def test_datamodule_parameter(tmpdir): + """ Test that the datamodule parameter works """ - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) + dm = ClassifDataModule() + model = ClassificationModel() - before_lr = hparams.learning_rate + before_lr = model.lr # logger file to get meta trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1, + default_root_dir=tmpdir, + max_epochs=2, ) - lrfinder = trainer.lr_find(model, mode='linear') + lrfinder = trainer.tuner.lr_find(model, datamodule=dm) after_lr = lrfinder.suggestion() - model.hparams.learning_rate = after_lr - trainer.fit(model) + model.lr = after_lr assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' + + +def test_accumulation_and_early_stopping(tmpdir): + """ Test that early stopping of learning rate finder works, and that + accumulation also works for this feature """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + before_lr = hparams.get('learning_rate') + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + accumulate_grad_batches=2, + ) + + lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None) + after_lr = lrfinder.suggestion() + + expected_num_lrs = 100 + expected_batch_idx = 200 - 1 + + assert before_lr != after_lr, \ + 'Learning rate was not altered after running learning rate finder' + assert len(lrfinder.results['lr']) == expected_num_lrs, \ + 'Early stopping for learning rate finder did not work' + assert lrfinder._total_batch_idx == expected_batch_idx, \ + 'Accumulation parameter did not work' + + +def test_suggestion_parameters_work(tmpdir): + """ Test that default skipping does not alter results in basic case """ + + dm = ClassifDataModule() + model = ClassificationModel() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + ) + + lrfinder = trainer.tuner.lr_find(model, datamodule=dm) + lr1 = lrfinder.suggestion(skip_begin=10) # default + lr2 = lrfinder.suggestion(skip_begin=150) # way too high, should have an impact + + assert lr1 != lr2, 'Skipping parameter did not influence learning rate' + + +def test_suggestion_with_non_finite_values(tmpdir): + """ Test that non-finite values does not alter results """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + ) + + lrfinder = trainer.tuner.lr_find(model) + before_lr = lrfinder.suggestion() + lrfinder.results['loss'][-1] = float('nan') + after_lr = lrfinder.suggestion() + + assert before_lr == after_lr, \ + 'Learning rate was altered because of non-finite loss values' + + +def test_lr_finder_fails_fast_on_bad_config(tmpdir): + """ Test that tune fails if the model does not have a lr BEFORE running lr find """ + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True) + with pytest.raises(MisconfigurationException, match='should have one of these fields'): + trainer.tune(BoringModel()) + + +def test_lr_find_with_bs_scale(tmpdir): + """ Test that lr_find runs with batch_size_scaling """ + + class BoringModelTune(BoringModel): + + def __init__(self, learning_rate=0.1, batch_size=2): + super().__init__() + self.save_hyperparameters() + + model = BoringModelTune() + before_lr = model.hparams.learning_rate + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + ) + bs = trainer.tuner.scale_batch_size(model) + lr = trainer.tuner.lr_find(model).suggestion() + + assert lr != before_lr + assert isinstance(bs, int) diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py deleted file mode 100644 index 665ba3cdfbc69c..00000000000000 --- a/tests/trainer/test_optimizers.py +++ /dev/null @@ -1,240 +0,0 @@ -import pytest -import torch - -import tests.base.utils as tutils -from pytorch_lightning import Trainer -from tests.base import EvalModelTemplate - - -def test_optimizer_with_scheduling(tmpdir): - """ Verify that learning rate scheduling is working """ - - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - model.configure_optimizers = model.configure_optimizers__single_scheduler - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - results = trainer.fit(model) - assert results == 1 - - init_lr = hparams.learning_rate - adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups] - - assert len(trainer.lr_schedulers) == 1, \ - 'lr scheduler not initialized properly, it has %i elements instread of 1' % len(trainer.lr_schedulers) - - assert all(a == adjusted_lr[0] for a in adjusted_lr), \ - 'Lr not equally adjusted for all param groups' - adjusted_lr = adjusted_lr[0] - - assert init_lr * 0.1 == adjusted_lr, \ - 'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr) - - -def test_multi_optimizer_with_scheduling(tmpdir): - """ Verify that learning rate scheduling is working """ - - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - model.configure_optimizers = model.configure_optimizers__multiple_schedulers - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - results = trainer.fit(model) - assert results == 1 - - init_lr = hparams.learning_rate - adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups] - adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups] - - assert len(trainer.lr_schedulers) == 2, \ - 'all lr scheduler not initialized properly, it has %i elements instread of 1' % len(trainer.lr_schedulers) - - assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \ - 'Lr not equally adjusted for all param groups for optimizer 1' - adjusted_lr1 = adjusted_lr1[0] - - assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \ - 'Lr not equally adjusted for all param groups for optimizer 2' - adjusted_lr2 = adjusted_lr2[0] - - assert init_lr * 0.1 == adjusted_lr1 and init_lr * 0.1 == adjusted_lr2, \ - 'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr1) - - -def test_multi_optimizer_with_scheduling_stepping(tmpdir): - - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - model.configure_optimizers = model.configure_optimizers__multiple_schedulers - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - results = trainer.fit(model) - assert results == 1 - - init_lr = hparams.learning_rate - adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups] - adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups] - - assert len(trainer.lr_schedulers) == 2, \ - 'all lr scheduler not initialized properly' - - assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \ - 'lr not equally adjusted for all param groups for optimizer 1' - adjusted_lr1 = adjusted_lr1[0] - - assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \ - 'lr not equally adjusted for all param groups for optimizer 2' - adjusted_lr2 = adjusted_lr2[0] - - # Called ones after end of epoch - assert init_lr * 0.1 ** 1 == adjusted_lr1, \ - 'lr for optimizer 1 not adjusted correctly' - # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times - assert init_lr * 0.1 == adjusted_lr2, \ - 'lr for optimizer 2 not adjusted correctly' - - -def test_reduce_lr_on_plateau_scheduling(tmpdir): - - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - model.configure_optimizers = model.configure_optimizers__reduce_lr_on_plateau - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - results = trainer.fit(model) - assert results == 1 - - assert trainer.lr_schedulers[0] == \ - dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss', - interval='epoch', frequency=1, reduce_on_plateau=True), \ - 'lr schduler was not correctly converted to dict' - - -def test_optimizer_return_options(): - - trainer = Trainer() - model = EvalModelTemplate(tutils.get_default_hparams()) - - # single optimizer - opt_a = torch.optim.Adam(model.parameters(), lr=0.002) - opt_b = torch.optim.SGD(model.parameters(), lr=0.002) - scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10) - scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10) - - # single optimizer - model.configure_optimizers = lambda: opt_a - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0 - - # opt tuple - model.configure_optimizers = lambda: (opt_a, opt_b) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b - assert len(lr_sched) == 0 and len(freq) == 0 - - # opt list - model.configure_optimizers = lambda: [opt_a, opt_b] - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b - assert len(lr_sched) == 0 and len(freq) == 0 - - # opt tuple of 2 lists - model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 - assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False, monitor='val_loss') - - # opt single dictionary - model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 - assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False, monitor='val_loss') - - # opt multiple dictionaries with frequencies - model.configure_optimizers = lambda: ( - {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1}, - {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, - ) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2 - assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False, monitor='val_loss') - assert freq == [1, 5] - - -def test_none_optimizer_warning(): - - trainer = Trainer() - - model = EvalModelTemplate(tutils.get_default_hparams()) - model.configure_optimizers = lambda: None - - with pytest.warns(UserWarning, match='will run with no optimizer'): - _, __, ___ = trainer.init_optimizers(model) - - -def test_none_optimizer(tmpdir): - - hparams = tutils.get_default_hparams() - model = EvalModelTemplate(hparams) - model.configure_optimizers = model.configure_optimizers__empty - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - result = trainer.fit(model) - - # verify training completed - assert result == 1 - - -def test_configure_optimizer_from_dict(tmpdir): - """Tests if `configure_optimizer` method could return a dictionary with `optimizer` field only.""" - - class CurrentModel(EvalModelTemplate): - def configure_optimizers(self): - config = { - 'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03) - } - return config - - hparams = tutils.get_default_hparams() - model = CurrentModel(hparams) - - # fit model - trainer = Trainer(default_save_path=tmpdir, max_epochs=1) - result = trainer.fit(model) - assert result == 1 diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py new file mode 100644 index 00000000000000..d2257a84f74db0 --- /dev/null +++ b/tests/trainer/test_states.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel + + +def test_initialize_state(tmpdir): + """ Tests that state is INITIALIZE after Trainer creation """ + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.state == TrainerState.INITIALIZING + + +@pytest.mark.parametrize( + "extra_params", [ + pytest.param(dict(fast_dev_run=True), id='Fast-Run'), + pytest.param(dict(max_steps=1), id='Single-Step'), + ] +) +def test_trainer_state_while_running(tmpdir, extra_params): + trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True) + + class TestModel(BoringModel): + + def __init__(self, expected_state): + super().__init__() + self.expected_state = expected_state + self.lr = 0.1 + + def on_batch_start(self, *_): + assert self.trainer.state == self.expected_state + + def on_train_batch_start(self, *_): + assert self.trainer.training + + def on_sanity_check_start(self, *_): + assert self.trainer.sanity_checking + + def on_validation_batch_start(self, *_): + assert self.trainer.validating or self.trainer.sanity_checking + + def on_test_batch_start(self, *_): + assert self.trainer.testing + + model = TestModel(TrainerState.TUNING) + trainer.tune(model) + assert trainer.state == TrainerState.FINISHED + + model = TestModel(TrainerState.FITTING) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED + + model = TestModel(TrainerState.TESTING) + trainer.test(model) + assert trainer.state == TrainerState.FINISHED + + +@pytest.mark.parametrize( + "extra_params", [ + pytest.param(dict(fast_dev_run=True), id='Fast-Run'), + pytest.param(dict(max_steps=1), id='Single-Step'), + ] +) +def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params): + """ Tests that state is set to INTERRUPTED on KeyboardInterrupt """ + model = BoringModel() + + class InterruptCallback(Callback): + + def on_batch_start(self, trainer, pl_module): + raise KeyboardInterrupt + + trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmpdir, **extra_params) + + trainer.fit(model) + assert trainer.state == TrainerState.INTERRUPTED diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py new file mode 100644 index 00000000000000..30b984dc896bed --- /dev/null +++ b/tests/trainer/test_supporters.py @@ -0,0 +1,239 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import Sequence + +import pytest +import torch +from torch.utils.data import TensorDataset + +from pytorch_lightning.trainer.supporters import ( + _nested_calc_num_data, + CombinedDataset, + CombinedLoader, + CombinedLoaderIterator, + CycleIterator, + TensorRunningAccum, +) +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def test_tensor_running_accum_reset(): + """ Test that reset would set all attributes to the initialization state """ + + window_length = 10 + + accum = TensorRunningAccum(window_length=window_length) + assert accum.last() is None + assert accum.mean() is None + + accum.append(torch.tensor(1.5)) + assert accum.last() == torch.tensor(1.5) + assert accum.mean() == torch.tensor(1.5) + + accum.reset() + assert accum.window_length == window_length + assert accum.memory is None + assert accum.current_idx == 0 + assert accum.last_idx is None + assert not accum.rotated + + +def test_cycle_iterator(): + """Test the cycling function of `CycleIterator`""" + iterator = CycleIterator(range(100), 1000) + assert len(iterator) == 1000 + for idx, item in enumerate(iterator): + assert item < 100 + + assert idx == len(iterator) - 1 + + +def test_none_length_cycle_iterator(): + """Test the infinite cycling function of `CycleIterator`""" + iterator = CycleIterator(range(100)) + assert iterator.__len__() == float("inf") + + # test infinite loop + for idx, item in enumerate(iterator): + if idx == 1000: + break + assert item == 0 + + +@pytest.mark.parametrize( + ["dataset_1", "dataset_2"], + [ + ([list(range(10)), list(range(20))]), + ([range(10), range(20)]), + ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]), + ([TensorDataset(torch.randn(10, 3, 2)), + TensorDataset(torch.randn(20, 5, 6))]), + ], +) +def test_combined_dataset(dataset_1, dataset_2): + """Verify the length of the CombinedDataset""" + datasets = [dataset_1, dataset_2] + combined_dataset = CombinedDataset(datasets) + + assert combined_dataset.max_len == 20 + assert combined_dataset.min_len == len(combined_dataset) == 10 + + +def test_combined_dataset_length_mode_error(): + with pytest.raises(MisconfigurationException, match="Invalid Mode"): + CombinedDataset._calc_num_data([range(10)], "test") + + +def test_combined_loader_iterator_dict_min_size(): + """Test `CombinedLoaderIterator` given mapping loaders""" + loaders = { + "a": torch.utils.data.DataLoader(range(10), batch_size=4), + "b": torch.utils.data.DataLoader(range(20), batch_size=5), + } + + combined_iter = CombinedLoaderIterator(loaders) + + for idx, item in enumerate(combined_iter): + assert isinstance(item, dict) + assert len(item) == 2 + assert "a" in item and "b" in item + + assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1 + + +def test_combined_loader_init_mode_error(): + """Test the ValueError when constructing `CombinedLoader`""" + with pytest.raises(MisconfigurationException, match="selected unsupported mode"): + CombinedLoader([range(10)], "testtt") + + +def test_combined_loader_loader_type_error(): + """Test the ValueError when wrapping the loaders""" + with pytest.raises(ValueError, match="Invalid Datatype"): + CombinedLoader(None, "max_size_cycle") + + +def test_combined_loader_calc_length_mode_error(): + """Test the ValueError when calculating the number of batches""" + with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"): + CombinedLoader._calc_num_batches(None) + + +def test_combined_loader_dict_min_size(): + """Test `CombinedLoader` of mode 'min_size' given mapping loaders""" + loaders = { + "a": torch.utils.data.DataLoader(range(10), batch_size=4), + "b": torch.utils.data.DataLoader(range(20), batch_size=5), + } + + combined_loader = CombinedLoader(loaders, "min_size") + + assert len(combined_loader) == min([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert "a" in item and "b" in item + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_dict_max_size_cycle(): + """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders""" + loaders = { + "a": torch.utils.data.DataLoader(range(10), batch_size=4), + "b": torch.utils.data.DataLoader(range(20), batch_size=5), + } + + combined_loader = CombinedLoader(loaders, "max_size_cycle") + + assert len(combined_loader) == max([len(v) for v in loaders.values()]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, dict) + assert len(item) == 2 + assert "a" in item and "b" in item + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_sequence_min_size(): + """Test `CombinedLoader` of mode 'min_size' given sequence loaders""" + loaders = [ + torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5), + ] + + combined_loader = CombinedLoader(loaders, "min_size") + + assert len(combined_loader) == min([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 + + +def test_combined_loader_sequence_max_size_cycle(): + """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders""" + loaders = [ + torch.utils.data.DataLoader(range(10), batch_size=4), + torch.utils.data.DataLoader(range(20), batch_size=5), + ] + + combined_loader = CombinedLoader(loaders, "max_size_cycle") + + assert len(combined_loader) == max([len(v) for v in loaders]) + + for idx, item in enumerate(combined_loader): + assert isinstance(item, Sequence) + assert len(item) == 2 + + assert idx == len(combined_loader) - 1 + + +@pytest.mark.parametrize( + ["input_data", "compute_func", "expected_length"], + [ + ([*range(10), list(range(1, 20))], min, 0), + ([*range(10), list(range(1, 20))], max, 19), + ([*range(10), {str(i): i + for i in range(1, 20)}], min, 0), + ([*range(10), {str(i): i + for i in range(1, 20)}], max, 19), + ({ + **{str(i): i + for i in range(10)}, "nested": {str(i): i + for i in range(1, 20)} + }, min, 0), + ({ + **{str(i): i + for i in range(10)}, "nested": {str(i): i + for i in range(1, 20)} + }, max, 19), + ({ + **{str(i): i + for i in range(10)}, "nested": list(range(20)) + }, min, 0), + ({ + **{str(i): i + for i in range(10)}, "nested": list(range(20)) + }, max, 19), + ], +) +def test_nested_calc_num_data(input_data, compute_func, expected_length): + calculated_length = _nested_calc_num_data(input_data, compute_func) + + assert calculated_length == expected_length diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8ab722d8886b28..ee93ca59eca768 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,204 +1,305 @@ -import glob +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math import os -import types +import pickle +import sys from argparse import Namespace +from copy import deepcopy +from pathlib import Path +from unittest.mock import ANY, call, patch +import cloudpickle import pytest import torch +from omegaconf import OmegaConf +from torch.optim import SGD +from torch.utils.data import DataLoader -import tests.base.utils as tutils -from pytorch_lightning import Callback, LightningModule -from pytorch_lightning import Trainer +import tests.helpers.utils as tutils +from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.core.lightning import load_hparams_from_tags_csv +from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.trainer.logging import TrainerLoggingMixin +from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf -def test_model_pickle(tmpdir): - import pickle - - model = EvalModelTemplate(tutils.get_default_hparams()) - pickle.dumps(model) - - -def test_hparams_save_load(tmpdir): - model = EvalModelTemplate(vars(tutils.get_default_hparams())) - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - # fit model - result = trainer.fit(model) - assert result == 1 - - # try to load the model now - pretrained_model = tutils.load_model_from_checkpoint( - trainer.checkpoint_callback.dirpath, - module_class=EvalModelTemplate - ) - assert pretrained_model - - -def test_no_val_module(tmpdir): +@pytest.mark.parametrize("url_ckpt", [True, False]) +def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv("TORCH_HOME", str(tmpdir)) - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() # logger file to get meta logger = tutils.get_default_logger(tmpdir) trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + callbacks=[ModelCheckpoint(dirpath=tmpdir)], ) # fit model - result = trainer.fit(model) + trainer.fit(model) # training complete - assert result == 1, 'amp + ddp model failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # save model - new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') + new_weights_path = os.path.join(tmpdir, "save_test.ckpt") trainer.save_checkpoint(new_weights_path) # assert ckpt has hparams ckpt = torch.load(new_weights_path) - assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints' + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), "hyper_parameters missing from checkpoints" - # won't load without hparams in the ckpt + # load new model + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, "hparams.yaml") + ckpt_path = ( + f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}" + if url_ckpt else new_weights_path + ) model_2 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path=new_weights_path, + checkpoint_path=ckpt_path, + hparams_file=hparams_path, ) model_2.eval() -def test_no_val_end_module(tmpdir): +@pytest.mark.parametrize("url_ckpt", [True, False]) +def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv("TORCH_HOME", tmpdir) - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() # logger file to get meta logger = tutils.get_default_logger(tmpdir) # fit model trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + callbacks=[ModelCheckpoint(dirpath=tmpdir)], ) - result = trainer.fit(model) + trainer.fit(model) # traning complete - assert result == 1, 'amp + ddp model failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # save model - new_weights_path = os.path.join(tmpdir, 'save_test.ckpt') + new_weights_path = os.path.join(tmpdir, "save_test.ckpt") trainer.save_checkpoint(new_weights_path) # load new model - tags_path = tutils.get_data_path(logger, path_dir=tmpdir) - tags_path = os.path.join(tags_path, 'meta_tags.csv') + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, "hparams.yaml") + ckpt_path = ( + f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}" + if url_ckpt else new_weights_path + ) model_2 = EvalModelTemplate.load_from_checkpoint( - checkpoint_path=new_weights_path, - tags_csv=tags_path + checkpoint_path=ckpt_path, + hparams_file=hparams_path, ) model_2.eval() -def test_gradient_accumulation_scheduling(tmpdir): - """ - Test grad accumulation by the freq of optimizer updates - """ +@pytest.mark.parametrize("url_ckpt", [True, False]) +def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt): + """Tests use case where trainer saves the model, and user loads it from tags independently.""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv("TORCH_HOME", tmpdir) - # test incorrect configs - with pytest.raises(IndexError): - assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6}) - assert Trainer(accumulate_grad_batches={-2: 3}) + model = EvalModelTemplate() + # Extra layer + model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim) - with pytest.raises(TypeError): - assert Trainer(accumulate_grad_batches={}) - assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]]) - assert Trainer(accumulate_grad_batches={1: 2, 3.: 4}) - assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5}) + # logger file to get meta + logger = tutils.get_default_logger(tmpdir) - # test optimizer call freq matches scheduler - def _optimizer_step(self, epoch, batch_idx, optimizer, - optimizer_idx, second_order_closure=None): - # only test the first 12 batches in epoch - if batch_idx < 12: - if epoch == 0: - # reset counter when starting epoch - if batch_idx == 0: - self.prev_called_batch_idx = 0 + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=logger, + callbacks=[ModelCheckpoint(dirpath=tmpdir)], + ) + trainer.fit(model) - # use this opportunity to test once - assert self.trainer.accumulate_grad_batches == 1 + # traning complete + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED - assert batch_idx == self.prev_called_batch_idx - self.prev_called_batch_idx += 1 + # save model + new_weights_path = os.path.join(tmpdir, "save_test.ckpt") + trainer.save_checkpoint(new_weights_path) - elif 1 <= epoch <= 2: - # reset counter when starting epoch - if batch_idx == 1: - self.prev_called_batch_idx = 1 + # load new model + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, "hparams.yaml") + ckpt_path = ( + f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}" + if url_ckpt else new_weights_path + ) + + try: + EvalModelTemplate.load_from_checkpoint( + checkpoint_path=ckpt_path, + hparams_file=hparams_path, + ) + # todo: specify the possible exception + except Exception: + failed = True + else: + failed = False + + assert failed, "Model should not been loaded since the extra layer added." + + failed = False + try: + EvalModelTemplate.load_from_checkpoint( + checkpoint_path=ckpt_path, + hparams_file=hparams_path, + strict=False, + ) + # todo: specify the possible exception + except Exception: + failed = True + + assert not failed, "Model should be loaded due to strict=False." + + +@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3)) +def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batches): + with patch("torch.optim.SGD.zero_grad") as sgd_zero_grad: + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=20, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + accumulate_grad_batches=accumulate_grad_batches, + ) + trainer.fit(model) - # use this opportunity to test once - assert self.trainer.accumulate_grad_batches == 2 + assert sgd_zero_grad.call_count == math.ceil(trainer.limit_train_batches / accumulate_grad_batches) + + +@pytest.mark.parametrize( + ["accumulate_grad_batches", "limit_train_batches"], + [ + ({ + 1: 2, + 3: 4 + }, 1.0), + ({ + 1: 2, + 3: 4 + }, 0.5), # not to be divisible by accumulate_grad_batches on purpose + (3, 1.0), + (3, 0.8), # not to be divisible by accumulate_grad_batches on purpose + (4, 1.0), + (4, 0.7), # not to be divisible by accumulate_grad_batches on purpose + ], +) +def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches): + """ Verify optimizer.step() applied to last batch while grad accumulation """ + + class CurrentModel(BoringModel): + + def on_batch_start(self, *_): + self.on_train_batch_start_state_dict = self.state_dict() + + def on_batch_end(self, outputs, batch, batch_idx, *_): + self.on_train_batch_start_end_dict = self.state_dict() + for key in self.on_train_batch_start_end_dict.keys(): + equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key]) + if (batch_idx + 1) == self.trainer.num_training_batches: + assert equal + else: + assert not equal - assert batch_idx == self.prev_called_batch_idx - self.prev_called_batch_idx += 2 + model = CurrentModel() - else: - if batch_idx == 3: - self.prev_called_batch_idx = 3 + trainer = Trainer( + accumulate_grad_batches=accumulate_grad_batches, + max_epochs=2, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + limit_test_batches=0, + default_root_dir=tmpdir, + ) + + trainer.fit(model) - # use this opportunity to test once - assert self.trainer.accumulate_grad_batches == 4 - assert batch_idx == self.prev_called_batch_idx - self.prev_called_batch_idx += 3 +def test_loading_meta_tags(tmpdir): + """ test for backward compatibility to meta_tags.csv """ + tutils.reset_seed() - optimizer.step() + hparams = EvalModelTemplate.get_default_hparams() - # clear gradients - optimizer.zero_grad() + # save tags + logger = tutils.get_default_logger(tmpdir) + logger.log_hyperparams(Namespace(some_str="a_str", an_int=1, a_float=2.0)) + logger.log_hyperparams(hparams) + logger.save() - model = EvalModelTemplate(tutils.get_default_hparams()) - schedule = {1: 2, 3: 4} + # load hparams + path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE) + hparams = load_hparams_from_yaml(hparams_path) - trainer = Trainer(accumulate_grad_batches=schedule, - train_percent_check=0.1, - val_percent_check=0.1, - max_epochs=2, - default_root_dir=tmpdir) + # save as legacy meta_tags.csv + tags_path = os.path.join(path_expt_dir, "meta_tags.csv") + save_hparams_to_tags_csv(tags_path, hparams) - # for the test - trainer.optimizer_step = _optimizer_step - model.prev_called_batch_idx = 0 + tags = load_hparams_from_tags_csv(tags_path) - trainer.fit(model) + assert hparams == tags -def test_loading_meta_tags(tmpdir): +def test_loading_yaml(tmpdir): + tutils.reset_seed() - hparams = tutils.get_default_hparams() + hparams = EvalModelTemplate.get_default_hparams() # save tags logger = tutils.get_default_logger(tmpdir) - logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0)) + logger.log_hyperparams(Namespace(some_str="a_str", an_int=1, a_float=2.0)) logger.log_hyperparams(hparams) logger.save() - # load tags + # load hparams path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir) - tags_path = os.path.join(path_expt_dir, 'meta_tags.csv') - tags = load_hparams_from_tags_csv(tags_path) + hparams_path = os.path.join(path_expt_dir, "hparams.yaml") + tags = load_hparams_from_yaml(hparams_path) - assert tags.batch_size == 32 and tags.hidden_dim == 1000 + assert tags["batch_size"] == 32 and tags["hidden_dim"] == 1000 def test_dp_output_reduce(): @@ -212,224 +313,265 @@ def test_dp_output_reduce(): assert mixin.reduce_distributed_output(out, num_gpus=2) == out.mean() # when we have a dict of vals - out = { - 'a': out, - 'b': { - 'c': out - } - } + out = {"a": out, "b": {"c": out}} reduced = mixin.reduce_distributed_output(out, num_gpus=3) - assert reduced['a'] == out['a'] - assert reduced['b']['c'] == out['b']['c'] - - -@pytest.mark.parametrize(["save_top_k", "file_prefix", "expected_files"], [ - pytest.param(-1, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'}, - id="CASE K=-1 (all)"), - pytest.param(1, 'test_prefix_', {'test_prefix_epoch=4.ckpt'}, - id="CASE K=1 (2.5, epoch 4)"), - pytest.param(2, '', {'epoch=4.ckpt', 'epoch=2.ckpt'}, - id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"), - pytest.param(4, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'}, - id="CASE K=4 (save all 4 base)"), - pytest.param(3, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'}, - id="CASE K=3 (save the 2nd, 3rd, 4th model)"), -]) -def test_model_checkpoint_options(tmpdir, save_top_k, file_prefix, expected_files): + assert reduced["a"] == out["a"] + assert reduced["b"]["c"] == out["b"]["c"] + + +@pytest.mark.parametrize( + "save_top_k,save_last,expected_files", + [ + pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"), + pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"), + pytest.param(2, False, [f"epoch={i}.ckpt" for i in (2, 4)], id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"), + pytest.param(4, False, [f"epoch={i}.ckpt" for i in range(1, 5)], id="CASE K=4 (save all 4 base)"), + pytest.param(3, False, [f"epoch={i}.ckpt" for i in range(2, 5)], id="CASE K=3 (save the 2nd, 3rd, 4th model)"), + pytest.param(1, True, {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"), + ], +) +def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files): """Test ModelCheckpoint options.""" - def mock_save_function(filepath): - open(filepath, 'a').close() + def mock_save_function(filepath, *args): + open(filepath, "a").close() # simulated losses losses = [10, 9, 2.8, 5, 2.5] - checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, prefix=file_prefix, verbose=1) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + monitor='checkpoint_on', + save_top_k=save_top_k, + save_last=save_last, + verbose=True + ) checkpoint_callback.save_function = mock_save_function trainer = Trainer() + trainer.state = TrainerState.FITTING # emulate callback's calls during the training for i, loss in enumerate(losses): trainer.current_epoch = i - trainer.callback_metrics = {'val_loss': loss} - checkpoint_callback.on_validation_end(trainer, trainer.get_model()) + trainer.global_step = i + trainer.logger_connector.callback_metrics = {"checkpoint_on": torch.tensor(loss)} + checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) file_lists = set(os.listdir(tmpdir)) - assert len(file_lists) == len(expected_files), \ - "Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k) + assert len(file_lists) == len( + expected_files + ), f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}" # verify correct naming for fname in expected_files: assert fname in file_lists -def test_model_freeze_unfreeze(): +def test_model_checkpoint_only_weights(tmpdir): + """Tests use case where ModelCheckpoint is configured to save only model weights, and + user tries to load checkpoint to resume training. + """ + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_weights_only=True)], + ) + # fit model + trainer.fit(model) + # training complete + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - model = EvalModelTemplate(tutils.get_default_hparams()) + checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0] + # assert saved checkpoint has no trainer data + checkpoint = torch.load(checkpoint_path) + assert "optimizer_states" not in checkpoint, "checkpoint should contain only model weights" + assert "lr_schedulers" not in checkpoint, "checkpoint should contain only model weights" + + # assert loading model works when checkpoint has only weights + assert EvalModelTemplate.load_from_checkpoint(checkpoint_path=checkpoint_path) + + # directly save model + new_weights_path = os.path.join(tmpdir, "save_test.ckpt") + trainer.save_checkpoint(new_weights_path, weights_only=True) + # assert saved checkpoint has no trainer data + checkpoint = torch.load(new_weights_path) + assert "optimizer_states" not in checkpoint, "checkpoint should contain only model weights" + assert "lr_schedulers" not in checkpoint, "checkpoint should contain only model weights" + + # assert restoring train state fails + with pytest.raises(KeyError, match="checkpoint contains only the model"): + trainer.checkpoint_connector.restore_training_state(checkpoint) + + +def test_model_freeze_unfreeze(): + model = EvalModelTemplate() model.freeze() model.unfreeze() -def test_resume_from_checkpoint_epoch_restored(tmpdir): +@pytest.mark.parametrize("url_ckpt", [True, False]) +def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Verify resuming from checkpoint runs the right number of epochs""" + # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir + monkeypatch.setenv("TORCH_HOME", tmpdir) - hparams = tutils.get_default_hparams() - - def _new_model(): - # Create a model that tracks epochs and batches seen - model = EvalModelTemplate(hparams) - model.num_epochs_seen = 0 - model.num_batches_seen = 0 - model.num_on_load_checkpoint_called = 0 + class TestModel(BoringModel): + # Model that tracks epochs and batches seen + num_epochs_end_seen = 0 + num_batches_seen = 0 + num_on_load_checkpoint_called = 0 - def increment_epoch(self): - self.num_epochs_seen += 1 + def on_epoch_end(self): + self.num_epochs_end_seen += 1 - def increment_batch(self, _): + def on_train_batch_start(self, *_): self.num_batches_seen += 1 - def increment_on_load_checkpoint(self, _): + def on_load_checkpoint(self, _): self.num_on_load_checkpoint_called += 1 - # Bind methods to keep track of epoch numbers, batch numbers it has seen - # as well as number of times it has called on_load_checkpoint() - model.on_epoch_end = types.MethodType(increment_epoch, model) - model.on_batch_start = types.MethodType(increment_batch, model) - model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model) - return model - - model = _new_model() - - trainer_options = dict( - progress_bar_refresh_rate=0, + model = TestModel() + trainer = Trainer( max_epochs=2, - train_percent_check=0.65, - val_percent_check=1, - checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), + limit_train_batches=0.65, + limit_val_batches=1, + callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=-1)], default_root_dir=tmpdir, - early_stop_callback=False, - val_check_interval=1., + val_check_interval=1.0, + progress_bar_refresh_rate=0, + logger=False, + weights_summary=None, ) - - trainer = Trainer(**trainer_options) - # fit model trainer.fit(model) - training_batches = trainer.num_training_batches - - assert model.num_epochs_seen == 2 - assert model.num_batches_seen == training_batches * 2 + # `on_epoch_end` will be called once for val_sanity, twice for train, twice for val + assert model.num_epochs_end_seen == 1 + 2 + 2 + assert model.num_batches_seen == trainer.num_training_batches * 2 assert model.num_on_load_checkpoint_called == 0 # Other checkpoints can be uncommented if/when resuming mid-epoch is supported - checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) + checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt") + if url_ckpt: + # transform local paths into url checkpoints + ip, port = tmpdir_server + checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints] - for check in checkpoints: - next_model = _new_model() - state = torch.load(check) + for ckpt in checkpoints: + next_model = TestModel() + state = pl_load(ckpt) # Resume training - trainer_options['max_epochs'] = 2 - new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check) + new_trainer = Trainer( + default_root_dir=tmpdir, + resume_from_checkpoint=ckpt, + max_epochs=2, + ) new_trainer.fit(next_model) - assert state['global_step'] + next_model.num_batches_seen == training_batches * trainer_options['max_epochs'] + assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs assert next_model.num_on_load_checkpoint_called == 1 -def _init_steps_model(): - """private method for initializing a model with 5% train epochs""" - model = EvalModelTemplate(tutils.get_default_hparams()) - - # define train epoch to 5% of data - train_percent = 0.5 - # get number of samples in 1 epoch - num_train_samples = math.floor(len(model.train_dataloader()) * train_percent) - - trainer_options = dict( - train_percent_check=train_percent, - ) - return model, trainer_options, num_train_samples - - def test_trainer_max_steps_and_epochs(tmpdir): """Verify model trains according to specified max steps""" - model, trainer_options, num_train_samples = _init_steps_model() + model = BoringModel() + num_train_samples = math.floor(len(model.train_dataloader()) * 0.5) # define less train steps than epochs - trainer_options.update( - default_root_dir=tmpdir, - max_epochs=3, - max_steps=num_train_samples + 10 - ) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result == 1, "Training did not complete" + trainer_kwargs = { + 'limit_train_batches': 0.5, + 'default_root_dir': tmpdir, + 'max_epochs': 3, + 'max_steps': num_train_samples + 10, + 'logger': False, + 'weights_summary': None, + 'progress_bar_refresh_rate': 0, + } + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) - # check training stopped at max_steps + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps" # define less train epochs than steps - trainer_options.update( - max_epochs=2, - max_steps=trainer_options['max_epochs'] * 2 * num_train_samples - ) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result == 1, "Training did not complete" + trainer_kwargs['max_epochs'] = 2 + trainer_kwargs['max_steps'] = 3 * 2 * num_train_samples + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) - # check training stopped at max_epochs + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.global_step == num_train_samples * trainer.max_epochs assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs" def test_trainer_min_steps_and_epochs(tmpdir): """Verify model trains according to specified min steps""" - model, trainer_options, num_train_samples = _init_steps_model() + model = EvalModelTemplate() + num_train_samples = math.floor(len(model.train_dataloader()) * 0.5) + + trainer_kwargs = { + 'limit_train_batches': 0.5, + 'default_root_dir': tmpdir, + # define callback for stopping the model + 'callbacks': [EarlyStopping(monitor="early_stop_on", min_delta=1.0)], + 'val_check_interval': 2, + 'min_epochs': 1, + 'max_epochs': 7, + # define less min steps than 1 epoch + 'min_steps': num_train_samples // 2, + 'logger': False, + 'weights_summary': None, + 'progress_bar_refresh_rate': 0, + } + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) - # define callback for stopping the model and default epochs - trainer_options.update( - default_root_dir=tmpdir, - early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0), - val_check_interval=2, - min_epochs=1, - max_epochs=5 - ) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED + assert trainer.current_epoch > 0 + assert trainer.global_step >= num_train_samples, "Model did not train for at least min_epochs" - # define less min steps than 1 epoch - trainer_options['min_steps'] = math.floor(num_train_samples / 2) + # define less epochs than min_steps + trainer_kwargs["min_steps"] = math.floor(num_train_samples * 1.5) + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result == 1, "Training did not complete" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.current_epoch > 0 + assert trainer.global_step >= math.floor(num_train_samples * 1.5), "Model did not train for at least min_steps" - # check model ran for at least min_epochs - assert trainer.global_step >= num_train_samples and \ - trainer.current_epoch > 0, "Model did not train for at least min_epochs" - # define less epochs than min_steps - trainer_options['min_steps'] = math.floor(num_train_samples * 1.5) +def test_trainer_max_steps_accumulate_batches(tmpdir): + """Verify model trains according to specified max steps with grad accumulated batches""" + model = BoringModel() + num_train_samples = math.floor(len(model.train_dataloader()) * 0.5) - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result == 1, "Training did not complete" + # define less train steps than epochs + trainer = Trainer( + limit_train_batches=0.5, + default_root_dir=tmpdir, + max_steps=num_train_samples + 10, + accumulate_grad_batches=10, + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + trainer.fit(model) - # check model ran for at least num_train_samples*1.5 - assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \ - trainer.current_epoch > 0, "Model did not train for at least min_steps" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.state == TrainerState.FINISHED + assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps" def test_benchmark_option(tmpdir): """Verify benchmark option.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple # verify torch.backends.cudnn.benchmark is not turned on @@ -441,42 +583,134 @@ def test_benchmark_option(tmpdir): max_epochs=1, benchmark=True, ) - result = trainer.fit(model) + trainer.fit(model) # verify training completed - assert result == 1 + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" # verify torch.backends.cudnn.benchmark is not turned off assert torch.backends.cudnn.benchmark -def test_testpass_overrides(tmpdir): - # todo: check duplicated tests against trainer_checks - hparams = tutils.get_default_hparams() +@pytest.mark.parametrize("ckpt_path", (None, "best", "specific")) +@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2)) +@pytest.mark.parametrize("fn", ("validate", "test")) +def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn): + + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx): + self.log("foo", -batch_idx) + return super().validation_step(batch, batch_idx) + + model = TestModel() + trainer = Trainer( + max_epochs=2, + progress_bar_refresh_rate=0, + default_root_dir=tmpdir, + callbacks=[ModelCheckpoint(monitor="foo", save_top_k=save_top_k)], + ) + trainer.fit(model) + + test_or_validate = getattr(trainer, fn) + if ckpt_path == "best": + # ckpt_path is 'best', meaning we load the best weights + if save_top_k == 0: + with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): + test_or_validate(ckpt_path=ckpt_path) + else: + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + else: + assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path + elif ckpt_path is None: + # ckpt_path is None, meaning we don't load any checkpoints and + # use the weights from the end of training + test_or_validate(ckpt_path=ckpt_path) + assert trainer.tested_ckpt_path is None + assert trainer.validated_ckpt_path is None + else: + # specific checkpoint, pick one from saved ones + if save_top_k == 0: + with pytest.raises(FileNotFoundError): + test_or_validate(ckpt_path="random.ckpt") + else: + ckpt_path = str( + list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir() + )[0].absolute() + ) + test_or_validate(ckpt_path=ckpt_path) + if fn == "test": + assert trainer.tested_ckpt_path == ckpt_path + else: + assert trainer.validated_ckpt_path == ckpt_path + + +def test_disabled_training(tmpdir): + """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" + + class CurrentModel(BoringModel): + + training_step_invoked = False + training_epoch_end_invoked = False + + def training_step(self, *args, **kwargs): + self.training_step_invoked = True + return super().training_step(*args, **kwargs) + + def training_epoch_end(self, *args, **kwargs): + self.training_epoch_end_invoked = True + return super().training_epoch_end(*args, **kwargs) + + model = CurrentModel() + + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=2, + limit_train_batches=0.0, + limit_val_batches=0.2, + fast_dev_run=False, + ) + + before_state_dict = deepcopy(model.state_dict()) + + trainer = Trainer(**trainer_options) + trainer.fit(model) + + after_state_dict = model.state_dict() + + for key in before_state_dict.keys(): + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])) + + # check that limit_train_batches=0 turns off training + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.current_epoch == 0 + assert not model.training_step_invoked, "`training_step` should not run when `limit_train_batches=0`" + assert not model.training_epoch_end_invoked, "`training_epoch_end` should not run when `limit_train_batches=0`" + + # check that limit_train_batches has no influence when fast_dev_run is turned on + model = CurrentModel() + trainer_options.update(fast_dev_run=True) + before_state_dict = deepcopy(model.state_dict()) - # Misconfig when neither test_step or test_end is implemented - with pytest.raises(MisconfigurationException, match='.*not implement `test_dataloader`.*'): - model = EvalModelTemplate(hparams) - model.test_dataloader = model.test_dataloader__empty - Trainer().test(model) + trainer = Trainer(**trainer_options) + trainer.fit(model) - # Misconfig when neither test_step or test_end is implemented - with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) - model.test_step = LightningModule.test_step - Trainer().test(model) + after_state_dict = model.state_dict() - # No exceptions when one or both of test_step or test_end are implemented - model = EvalModelTemplate(hparams) - model.test_step_end = LightningModule.test_step_end - Trainer().test(model) + for key in before_state_dict.keys(): + assert not torch.all(torch.eq(before_state_dict[key], after_state_dict[key])) - model = EvalModelTemplate(hparams) - Trainer().test(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.current_epoch == 0 + assert model.training_step_invoked, "did not run `training_step` with `fast_dev_run=True`" + assert model.training_epoch_end_invoked, "did not run `training_epoch_end` with `fast_dev_run=True`" -def test_disabled_validation(): - """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`.""" +def test_disabled_validation(tmpdir): + """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`.""" class CurrentModel(EvalModelTemplate): @@ -491,40 +725,38 @@ def validation_epoch_end(self, *args, **kwargs): self.validation_epoch_end_invoked = True return super().validation_epoch_end(*args, **kwargs) - hparams = tutils.get_default_hparams() - model = CurrentModel(hparams) + hparams = EvalModelTemplate.get_default_hparams() + model = CurrentModel(**hparams) trainer_options = dict( + default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=2, - train_percent_check=0.4, - val_percent_check=0.0, + limit_train_batches=0.4, + limit_val_batches=0.0, fast_dev_run=False, ) trainer = Trainer(**trainer_options) result = trainer.fit(model) - # check that val_percent_check=0 turns off validation - assert result == 1, 'training failed to complete' + # check that limit_val_batches=0 turns off validation + assert result == 1, "training failed to complete" + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.current_epoch == 1 - assert not model.validation_step_invoked, \ - '`validation_step` should not run when `val_percent_check=0`' - assert not model.validation_epoch_end_invoked, \ - '`validation_epoch_end` should not run when `val_percent_check=0`' + assert not model.validation_step_invoked, "`validation_step` should not run when `limit_val_batches=0`" + assert not model.validation_epoch_end_invoked, "`validation_epoch_end` should not run when `limit_val_batches=0`" - # check that val_percent_check has no influence when fast_dev_run is turned on - model = CurrentModel(hparams) + # check that limit_val_batches has no influence when fast_dev_run is turned on + model = CurrentModel(**hparams) trainer_options.update(fast_dev_run=True) trainer = Trainer(**trainer_options) - result = trainer.fit(model) + trainer.fit(model) - assert result == 1, 'training failed to complete' + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert trainer.current_epoch == 0 - assert model.validation_step_invoked, \ - 'did not run `validation_step` with `fast_dev_run=True`' - assert model.validation_epoch_end_invoked, \ - 'did not run `validation_epoch_end` with `fast_dev_run=True`' + assert model.validation_step_invoked, "did not run `validation_step` with `fast_dev_run=True`" + assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`" def test_nan_loss_detection(tmpdir): @@ -536,21 +768,21 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = super().training_step(batch, batch_idx, optimizer_idx) if batch_idx == self.test_batch_inf_loss: if isinstance(output, dict): - output['loss'] *= torch.tensor(math.inf) # make loss infinite + output["loss"] *= torch.tensor(math.inf) # make loss infinite else: output /= 0 return output - model = CurrentModel(tutils.get_default_hparams()) + model = CurrentModel() # fit model trainer = Trainer( default_root_dir=tmpdir, max_steps=(model.test_batch_inf_loss + 1), - terminate_on_nan=True + terminate_on_nan=True, ) - with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'): + with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is nan or inf.*"): trainer.fit(model) assert trainer.global_step == model.test_step_inf_loss @@ -568,14 +800,14 @@ def on_after_backward(self): # simulate parameter that became nan torch.nn.init.constant_(self.c_d1.bias, math.nan) - model = CurrentModel(tutils.get_default_hparams()) + model = CurrentModel() trainer = Trainer( default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), - terminate_on_nan=True + terminate_on_nan=True, ) - with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'): + with pytest.raises(ValueError, match=r".*Detected nan and/or inf values in `c_d1.bias`.*"): trainer.fit(model) assert trainer.global_step == model.test_batch_nan @@ -587,60 +819,120 @@ def on_after_backward(self): def test_trainer_interrupted_flag(tmpdir): """Test the flag denoting that a user interrupted training.""" - model = EvalModelTemplate(tutils.get_default_hparams()) + model = EvalModelTemplate() class InterruptCallback(Callback): + def __init__(self): super().__init__() - def on_batch_start(self, trainer, pl_module): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): raise KeyboardInterrupt + class HandleInterruptCallback(Callback): + + def __init__(self): + super().__init__() + self.exc_info = None + + def on_keyboard_interrupt(self, trainer, pl_module): + self.exc_info = sys.exc_info() + interrupt_callback = InterruptCallback() + handle_interrupt_callback = HandleInterruptCallback() trainer = Trainer( - callbacks=[interrupt_callback], + callbacks=[interrupt_callback, handle_interrupt_callback], max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2, + limit_val_batches=0.1, + limit_train_batches=0.2, progress_bar_refresh_rate=0, logger=False, default_root_dir=tmpdir, ) assert not trainer.interrupted + assert handle_interrupt_callback.exc_info is None trainer.fit(model) assert trainer.interrupted + assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) def test_gradient_clipping(tmpdir): """ Test gradient clipping """ + tutils.reset_seed() + + model = EvalModelTemplate() + + trainer = Trainer( + max_steps=1, + max_epochs=1, + gradient_clip_val=1.0, + default_root_dir=tmpdir, + ) - model = EvalModelTemplate(tutils.get_default_hparams()) + trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward - # test that gradient is clipped correctly - def _optimizer_step(*args, **kwargs): + def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + # test that gradient is clipped correctly + ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) parameters = model.parameters() grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm) - trainer = Trainer(max_steps=1, - max_epochs=1, - gradient_clip_val=1.0, - default_root_dir=tmpdir) + return ret_val + trainer.train_loop.training_step_and_backward = training_step_and_backward # for the test - model.optimizer_step = _optimizer_step model.prev_called_batch_idx = 0 trainer.fit(model) -def test_gpu_choice(tmpdir): - trainer_options = dict( - default_save_path=tmpdir, +@RunIf(min_gpus=1, amp_native=True) +def test_gradient_clipping_fp16(tmpdir): + """ + Test gradient clipping with fp16 + """ + tutils.reset_seed() + + model = EvalModelTemplate() + + trainer = Trainer( + max_steps=1, + max_epochs=1, + precision=16, + gpus=1, + gradient_clip_val=1.0, + default_root_dir=tmpdir, ) + + trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + + def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + # test that gradient is clipped correctly + ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + parameters = model.parameters() + grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm) + + return ret_val + + trainer.train_loop.training_step_and_backward = training_step_and_backward + model.prev_called_batch_idx = 0 + + trainer.fit(model) + + +def test_gpu_choice(tmpdir): + trainer_options = dict(default_root_dir=tmpdir) # Only run if CUDA is available if not torch.cuda.is_available(): return @@ -648,95 +940,849 @@ def test_gpu_choice(tmpdir): num_gpus = torch.cuda.device_count() Trainer(**trainer_options, gpus=num_gpus, auto_select_gpus=True) - with pytest.raises(RuntimeError, match=r'.*No GPUs available.*'): + with pytest.raises(RuntimeError, match=r".*No GPUs available.*"): Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True) -@pytest.mark.parametrize("trainer_kwargs,expected", [ - pytest.param( - dict(distributed_backend=None, gpus=None), - dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1) - ), - pytest.param( - dict(distributed_backend="dp", gpus=None), - dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1) - ), - pytest.param( - dict(distributed_backend="dp", gpus=None), - dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1) - ), - pytest.param( - dict(distributed_backend="ddp", gpus=None), - dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1) - ), - pytest.param( - dict(distributed_backend="ddp", num_processes=2, gpus=None), - dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=2) - ), - pytest.param( - dict(distributed_backend="ddp", num_nodes=2, gpus=None), - dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1) - ), - pytest.param( - dict(distributed_backend="ddp_cpu", num_processes=2, gpus=None), - dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=2) - ), - pytest.param( - dict(distributed_backend="ddp2", gpus=None), - dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=1) - ), - pytest.param( - dict(distributed_backend=None, gpus=1), - dict(use_dp=False, use_ddp=False, use_ddp2=False, num_gpus=1, on_gpu=True, single_gpu=True, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")] - ), - pytest.param( - dict(distributed_backend="dp", gpus=1), - dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=1, on_gpu=True, single_gpu=True, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")] - ), - pytest.param( - dict(distributed_backend="ddp", gpus=1), - dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=1, on_gpu=True, single_gpu=True, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")] - ), - pytest.param( - dict(distributed_backend="ddp_cpu", num_processes=2, gpus=1), - dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=0, on_gpu=False, single_gpu=False, num_processes=2), - marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")] - ), - pytest.param( - dict(distributed_backend="ddp2", gpus=1), - dict(use_dp=False, use_ddp=False, use_ddp2=True, num_gpus=1, on_gpu=True, single_gpu=False, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")] - ), - pytest.param( - dict(distributed_backend=None, gpus=2), - dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")] - ), - pytest.param( - dict(distributed_backend="dp", gpus=2), - dict(use_dp=True, use_ddp=False, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")] - ), - pytest.param( - dict(distributed_backend="ddp", gpus=2), - dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=2), - marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")] - ), - pytest.param( - dict(distributed_backend="ddp2", gpus=2), - dict(use_dp=False, use_ddp=False, use_ddp2=True, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1), - marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")] - ), -]) -def test_trainer_config(trainer_kwargs, expected): +@pytest.mark.parametrize( + ["limit_val_batches"], + [ + pytest.param(0.0), # this should run no sanity checks + pytest.param(1), + pytest.param(1.0), + pytest.param(0.5), + pytest.param(5), + ], +) +def test_num_sanity_val_steps(tmpdir, limit_val_batches): + """ + Test that the number of sanity check batches is clipped to `limit_val_batches`. + """ + model = EvalModelTemplate() + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + num_sanity_val_steps = 4 + + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=num_sanity_val_steps, + limit_val_batches=limit_val_batches, + max_steps=1, + ) + assert trainer.num_sanity_val_steps == num_sanity_val_steps + + with patch.object( + trainer.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_step + ) as mocked: + val_dataloaders = model.val_dataloader__multiple_mixed_length() + trainer.fit(model, val_dataloaders=val_dataloaders) + + assert mocked.call_count == sum( + min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches + ) + + +@pytest.mark.parametrize( + ["limit_val_batches"], + [ + pytest.param(0.0), # this should run no sanity checks + pytest.param(1), + pytest.param(1.0), + pytest.param(0.3), + ], +) +def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): + """ + Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as + limited by `limit_val_batches` Trainer argument. + """ + model = EvalModelTemplate() + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=-1, + limit_val_batches=limit_val_batches, + max_steps=1, + ) + assert trainer.num_sanity_val_steps == float("inf") + + with patch.object( + trainer.evaluation_loop, "evaluation_step", wraps=trainer.evaluation_loop.evaluation_step + ) as mocked: + val_dataloaders = model.val_dataloader__multiple() + trainer.fit(model, val_dataloaders=val_dataloaders) + + assert mocked.call_count == sum(trainer.num_val_batches) + + +@pytest.mark.parametrize( + "trainer_kwargs,expected", + [ + ( + dict(accelerator=None, gpus=None), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator="dp", gpus=None), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator="dp", gpus=None), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator="ddp", gpus=None), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator="ddp", num_processes=2, gpus=None), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=2, + ), + ), + ( + dict(accelerator="ddp", num_nodes=2, gpus=None), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator="ddp_cpu", num_processes=2, gpus=None), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=2, + ), + ), + ( + dict(accelerator="ddp2", gpus=None), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator=None, gpus=1), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=False, + num_gpus=1, + on_gpu=True, + use_single_gpu=True, + num_processes=1, + ), + ), + ( + dict(accelerator="dp", gpus=1), + dict( + use_dp=True, + use_ddp=False, + use_ddp2=False, + num_gpus=1, + on_gpu=True, + use_single_gpu=True, + num_processes=1, + ), + ), + ( + dict(accelerator="ddp", gpus=1), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=1, + on_gpu=True, + use_single_gpu=True, + num_processes=1, + ), + ), + ( + dict(accelerator="ddp_cpu", num_processes=2, gpus=1), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=0, + on_gpu=False, + use_single_gpu=False, + num_processes=2, + ), + ), + ( + dict(accelerator="ddp2", gpus=1), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=True, + num_gpus=1, + on_gpu=True, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator=None, gpus=2), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=2, + on_gpu=True, + use_single_gpu=False, + num_processes=2, + ), + ), + ( + dict(accelerator="dp", gpus=2), + dict( + use_dp=True, + use_ddp=False, + use_ddp2=False, + num_gpus=2, + on_gpu=True, + use_single_gpu=False, + num_processes=1, + ), + ), + ( + dict(accelerator="ddp", gpus=2), + dict( + use_dp=False, + use_ddp=True, + use_ddp2=False, + num_gpus=2, + on_gpu=True, + use_single_gpu=False, + num_processes=2, + ), + ), + ( + dict(accelerator="ddp2", gpus=2), + dict( + use_dp=False, + use_ddp=False, + use_ddp2=True, + num_gpus=2, + on_gpu=True, + use_single_gpu=False, + num_processes=1, + ), + ), + ], +) +def test_trainer_config(trainer_kwargs, expected, monkeypatch): + if trainer_kwargs["gpus"] is not None: + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: trainer_kwargs["gpus"]) trainer = Trainer(**trainer_kwargs) - assert trainer.use_dp is expected["use_dp"] - assert trainer.use_ddp is expected["use_ddp"] - assert trainer.use_ddp2 is expected["use_ddp2"] - assert trainer.num_gpus == expected["num_gpus"] - assert trainer.on_gpu is expected["on_gpu"] - assert trainer.single_gpu is expected["single_gpu"] - assert trainer.num_processes == expected["num_processes"] + assert len(expected) == 7 + for k, v in expected.items(): + assert getattr(trainer, k) == v, f"Failed {k}: {v}" + + +def test_trainer_subclassing(): + model = EvalModelTemplate() + + # First way of pulling out args from signature is to list them + class TrainerSubclass(Trainer): + + def __init__(self, custom_arg, *args, custom_kwarg="test", **kwargs): + super().__init__(*args, **kwargs) + self.custom_arg = custom_arg + self.custom_kwarg = custom_kwarg + + trainer = TrainerSubclass(123, custom_kwarg="custom", fast_dev_run=True) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.custom_arg == 123 + assert trainer.custom_kwarg == "custom" + assert trainer.fast_dev_run + + # Second way is to pop from the dict + # It's a special case because Trainer does not have any positional args + class TrainerSubclass(Trainer): + + def __init__(self, **kwargs): + self.custom_arg = kwargs.pop("custom_arg", 0) + self.custom_kwarg = kwargs.pop("custom_kwarg", "test") + super().__init__(**kwargs) + + trainer = TrainerSubclass(custom_kwarg="custom", fast_dev_run=True) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + assert trainer.custom_kwarg == "custom" + assert trainer.fast_dev_run + + # when we pass in an unknown arg, the base class should complain + with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'abcdefg'"): + TrainerSubclass(abcdefg="unknown_arg") + + +@pytest.mark.parametrize( + "trainer_params", + [ + OmegaConf.create({ + "max_epochs": 1, + "gpus": 1 + }), + OmegaConf.create({ + "max_epochs": 1, + "gpus": [0] + }), + ], +) +@RunIf(min_gpus=1) +def test_trainer_omegaconf(trainer_params): + Trainer(**trainer_params) + + +def test_trainer_pickle(tmpdir): + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + ) + pickle.dumps(trainer) + cloudpickle.dumps(trainer) + + +@pytest.mark.parametrize("stage", ("fit", "validate", "test")) +def test_trainer_setup_call(tmpdir, stage): + """Test setup call gets the correct stage""" + + class CurrentModel(BoringModel): + + def setup(self, stage): + self.stage = stage + + class TrainerSubclass(Trainer): + + def setup(self, model, stage): + assert model is not None + self.stage = stage + + model = CurrentModel() + + # fit model + trainer = TrainerSubclass(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False) + + if stage == "fit": + trainer.fit(model) + elif stage == "validate": + trainer.validate(model, ckpt_path=None) + else: + trainer.test(model, ckpt_path=None) + + assert trainer.stage == stage + assert trainer.lightning_module.stage == stage + + +@pytest.mark.parametrize( + "train_batches, max_steps, log_interval", + [ + (10, 10, 1), + (3, 10, 1), + (3, 10, 5), + ], +) +@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") +def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval): + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + log_every_n_steps=log_interval, + flush_logs_every_n_steps=log_interval, + limit_train_batches=train_batches, + limit_val_batches=0, + max_steps=max_steps, + ) + trainer.fit(model) + expected_calls = [call(metrics=ANY, step=s) for s in range(log_interval - 1, max_steps, log_interval)] + log_metrics_mock.assert_has_calls(expected_calls) + + +@pytest.mark.parametrize(['profiler', 'expected'], [ + (None, PassThroughProfiler), + (SimpleProfiler(), SimpleProfiler), + (AdvancedProfiler(), AdvancedProfiler), + ('simple', SimpleProfiler), + ('Simple', SimpleProfiler), + ('advanced', AdvancedProfiler), + ('pytorch', PyTorchProfiler), +]) +def test_trainer_profiler_correct_args(profiler, expected): + kwargs = {'profiler': profiler} if profiler is not None else {} + trainer = Trainer(**kwargs) + assert isinstance(trainer.profiler, expected) + + +def test_trainer_profiler_incorrect_str_arg(): + with pytest.raises(ValueError, match=r".*can only be 'simple' or 'advanced'"): + Trainer(profiler="unknown_profiler") + + +@pytest.mark.parametrize('profiler', ( + 42, + [42], + dict(a=42), + torch.tensor(42), + Trainer(), +)) +def test_trainer_profiler_incorrect_arg_type(profiler): + with pytest.raises( + MisconfigurationException, + match="Only None, str and subclasses of `BaseProfiler`" + r" are valid values for `Trainer`'s `profiler` parameter. *" + ): + Trainer(profiler=profiler) + + +class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict_dataloader(self): + return self._dataloaders + + +def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True): + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = model or BoringModel() + dm = TestLightningDataModule(dataloaders) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + accelerator=accelerator, + gpus=gpus, + num_processes=num_processes, + plugins=plugins, + ) + if datamodule: + results = trainer.predict(model, datamodule=dm) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + # todo: address this in another PR + num_samples = 1 if accelerator in ["ddp", "ddp_cpu", "ddp_spawn"] else 2 + assert len(results) == 2 + assert len(results[0]) == num_samples + assert results[0][0].shape == torch.Size([1, 2]) + + +def test_trainer_predict_no_return(tmpdir): + """ + Test trainer.predict warns when nothing is returned + """ + + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + if (batch_idx + 1) % 2 == 0: + return + + return super().predict_step(batch, batch_idx, dataloader_idx) + + with pytest.warns(UserWarning, match='predict returned None'): + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + +def test_trainer_predict_grad(tmpdir): + + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.expand_as(batch).grad_fn is None + return super().predict_step(batch, batch_idx, dataloader_idx) + + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + x = torch.zeros(1, requires_grad=True) + assert x.expand_as(x).grad_fn is not None + + +@pytest.mark.parametrize('datamodule', [False, True]) +def test_trainer_predict_cpu(tmpdir, datamodule): + predict(tmpdir, None, None, 1, datamodule=datamodule) + + +@RunIf(min_gpus=2, special=True) +@pytest.mark.parametrize('num_gpus', [1, 2]) +def test_trainer_predict_dp(tmpdir, num_gpus): + predict(tmpdir, "dp", num_gpus, None) + + +@RunIf(min_gpus=2, special=True) +def test_trainer_predict_ddp(tmpdir): + predict(tmpdir, "ddp", 2, None, plugins=["ddp_sharded"]) + + +@RunIf(min_gpus=2, skip_windows=True, special=True) +def test_trainer_predict_ddp_spawn(tmpdir): + predict(tmpdir, "ddp_spawn", 2, None) + + +@RunIf(min_gpus=2, special=True) +def test_trainer_predict_1_gpu(tmpdir): + predict(tmpdir, None, 1, None) + + +@RunIf(skip_windows=True, special=True) +def test_trainer_predict_ddp_cpu(tmpdir): + predict(tmpdir, "ddp_cpu", 0, 2) + + +@pytest.mark.parametrize( + ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], + [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], +) +def test_disabled_training_for_insufficient_limit_train_batches( + tmpdir, limit_train_batches, global_step, num_training_batches, current_epoch, should_train +): + """ + Verify when `limit_train_batches` is float & between [0.0, 1.0] and + `int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled. + """ + + class CurrentModel(BoringModel): + + training_step_invoked = False + training_epoch_end_invoked = False + + def training_step(self, *args, **kwargs): + self.training_step_invoked = True + return super().training_step(*args, **kwargs) + + def training_epoch_end(self, *args, **kwargs): + self.training_epoch_end_invoked = True + return super().training_epoch_end(*args, **kwargs) + + dataset_len = 100 + batch_size = 25 + + train = RandomDataset(32, length=dataset_len) + train_loader = DataLoader(train, batch_size=batch_size) + + model = CurrentModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=5, + limit_train_batches=limit_train_batches, + ) + result = trainer.fit(model, train_loader) + + params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}` + & `batch_size={batch_size}` as + `num_training_batches={num_training_batches}`""" + if should_train: + error_string = f"should run with {params_string}" + else: + error_string = f"should not run with {params_string}" + + assert result == 1, "training failed to complete" + assert trainer.state == TrainerState.FINISHED + assert trainer.global_step == global_step + assert trainer.num_training_batches == num_training_batches + assert trainer.current_epoch == current_epoch + assert model.training_step_invoked == should_train, f"`training_step` {error_string}" + assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}" + + +@pytest.mark.parametrize(["max_steps", "max_epochs", "global_step"], [(10, 5, 10), (20, None, 20)]) +def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step): + """ + Ensure that the training loop is bound by `max_steps` and + `max_epochs` for repeated calls of `trainer.fit`, and + disabled if the limit is reached + """ + + dataset_len = 200 + batch_size = 10 + + train_data = DataLoader(RandomDataset(32, dataset_len), batch_size=batch_size) + + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=max_steps, + max_epochs=max_epochs, + ) + trainer.fit(model, train_data) + assert trainer.global_step == global_step + trainer.fit(model, train_data) + assert trainer.global_step == global_step + + +def test_trainer_access_in_configure_optimizers(tmpdir): + """ + Verify that the configure optimizer function can reference the trainer. + """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + assert self.trainer is not None, "Expect to have access to the trainer within `configure_optimizers`" + + train_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_data) + + +@RunIf(min_gpus=1) +def test_setup_hook_move_to_device_correctly(tmpdir): + """ + Verify that if a user defines a layer in the setup hook function, this is moved to the correct device. + """ + + class TestModel(BoringModel): + + def setup(self, stage: str) -> None: + self.new_layer = torch.nn.Linear(2, 2) + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + # will crash if not moved to correct device + output = self.new_layer(output) + loss = self.loss(batch, output) + return {"loss": loss} + + # fake data + train_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + + # model + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1) + trainer.fit(model, train_data) + + +def test_train_loop_system(tmpdir): + """ + Test the following methods are called in the order in automatic optimization. + 1. optimizer.step (skip when gradient accumulation) + 2. model.training_step + 3. optimizer.zero_grad (run when the first batch of gradient accumulation) + 4. model.backward + + Note that the order is NOT `training_step`->`zero_grad`->`backward`->`step`. + This is because `optimizer.step(closure)` calls `closure()` which then calls + the three remaining methods `training_step`, `zero_grad` and `backward` inside. + """ + called_methods = [] + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=1, + limit_test_batches=1, + progress_bar_refresh_rate=0, + ) + + class TestOptimizer(SGD): + + def step(self, *args, **kwargs): + called_methods.append("step") + return super().step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + called_methods.append("zero_grad") + return super().zero_grad(*args, **kwargs) + + class TestModel(BoringModel): + + def configure_optimizers(self): + return TestOptimizer(self.parameters(), lr=0.1) + + def training_step(self, *args, **kwargs): + called_methods.append("training_step") + return super().training_step(*args, **kwargs) + + def backward(self, *args, **kwargs): + called_methods.append("backward") + return super().backward(*args, **kwargs) + + model = TestModel() + trainer = Trainer(**trainer_options) + + # No methods are called yet. + assert called_methods == [] + + trainer.fit(model) + assert called_methods == [ + "step", + "training_step", + "zero_grad", + "backward", + ] * trainer.limit_train_batches + + called_methods.clear() + trainer = Trainer(**trainer_options, accumulate_grad_batches=3) + + # No methods are called yet. + assert called_methods == [] + + trainer.fit(model) + assert called_methods == [ + # 0 + "training_step", + "zero_grad", + "backward", + # 1 + "training_step", + "backward", + # 2 + "step", + "training_step", + "backward", + # 3 + "training_step", + "zero_grad", + "backward", + # 4 + "step", + "training_step", + "backward", + ] + + +def test_init_optimizers_resets_lightning_optimizers(tmpdir): + """ Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """ + + def compare_optimizers(): + assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0] + + model = BoringModel() + model.lr = 0.2 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_lr_find=True, + ) + + trainer.tune(model) + compare_optimizers() + + trainer.fit(model) + compare_optimizers() + + trainer.max_epochs = 2 # simulate multiple fit calls + trainer.fit(model) + compare_optimizers() + + +def test_check_val_every_n_epoch_exception(tmpdir): + + with pytest.raises(MisconfigurationException, match="should be an integer."): + Trainer( + default_root_dir=tmpdir, + max_epochs=1, + check_val_every_n_epoch=1.2, + ) + + +def test_trainer_attach_data_pipeline_to_model(tmpdir): + + class DataPipeline: + + pass + + class TestDataModule(LightningDataModule): + + data_pipeline = DataPipeline() + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + class TestCallback(Callback): + + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + """Called when fit begins""" + assert isinstance(pl_module.data_pipeline, DataPipeline) + + model = BoringModel() + dm = TestDataModule() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()]) + trainer.fit(model, datamodule=dm) + + +def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + + with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): + trainer.validate() + with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"): + trainer.test() diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index c4c23d0fff4ed3..32da5d2b2fa99e 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -1,18 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect +import pickle from argparse import ArgumentParser, Namespace from unittest import mock -import pickle import pytest -import tests.base.utils as tutils +import tests.helpers.utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.utilities import argparse +from tests.helpers.runif import RunIf -@mock.patch('argparse.ArgumentParser.parse_args', - return_value=Namespace(**Trainer.default_attributes())) -def test_default_args(tmpdir): +@mock.patch('argparse.ArgumentParser.parse_args') +def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer""" + mock_argparse.return_value = Namespace(**Trainer.default_attributes()) # logger file to get meta logger = tutils.get_default_logger(tmpdir) @@ -28,12 +43,8 @@ def test_default_args(tmpdir): assert trainer.max_epochs == 5 -@pytest.mark.parametrize('cli_args', [ - ['--accumulate_grad_batches=22'], - ['--print_nan_grads=1', '--weights_save_path=./'], - [] -]) -def test_add_argparse_args_redefined(cli_args): +@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []]) +def test_add_argparse_args_redefined(cli_args: list): """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness. """ @@ -55,9 +66,23 @@ def test_add_argparse_args_redefined(cli_args): assert isinstance(trainer, Trainer) +@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []]) +def test_add_argparse_args(cli_args: list): + """Simple test ensuring Trainer.add_argparse_args works.""" + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args(cli_args) + assert Trainer.from_argparse_args(args) + + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parser, use_argument_group=False) + args = parser.parse_args(cli_args) + assert Trainer.from_argparse_args(args) + + def test_get_init_arguments_and_types(): """Asserts a correctness of the `get_init_arguments_and_types` Trainer classmethod.""" - args = Trainer.get_init_arguments_and_types() + args = argparse.get_init_arguments_and_types(Trainer) parameters = inspect.signature(Trainer).parameters assert len(parameters) == len(args) for arg in args: @@ -68,11 +93,8 @@ def test_get_init_arguments_and_types(): assert isinstance(trainer, Trainer) -@pytest.mark.parametrize('cli_args', [ - ['--callbacks=1', '--logger'], - ['--foo', '--bar=1'] -]) -def test_add_argparse_args_redefined_error(cli_args, monkeypatch): +@pytest.mark.parametrize('cli_args', [['--callbacks=1', '--logger'], ['--foo', '--bar=1']]) +def test_add_argparse_args_redefined_error(cli_args: list, monkeypatch): """Asserts thar an error raised in case of passing not default cli arguments.""" class _UnkArgError(Exception): @@ -88,3 +110,105 @@ def _raise(): with pytest.raises(_UnkArgError): parser.parse_args(cli_args) + + +@pytest.mark.parametrize( + ['cli_args', 'expected'], + [ + pytest.param( + '--auto_lr_find --auto_scale_batch_size power', { + 'auto_lr_find': True, + 'auto_scale_batch_size': 'power' + } + ), + pytest.param( + '--auto_lr_find any_string --auto_scale_batch_size', { + 'auto_lr_find': 'any_string', + 'auto_scale_batch_size': True + } + ), + pytest.param( + '--auto_lr_find TRUE --auto_scale_batch_size FALSE', { + 'auto_lr_find': True, + 'auto_scale_batch_size': False + } + ), + pytest.param( + '--auto_lr_find t --auto_scale_batch_size ON', { + 'auto_lr_find': True, + 'auto_scale_batch_size': True + } + ), + pytest.param( + '--auto_lr_find 0 --auto_scale_batch_size n', { + 'auto_lr_find': False, + 'auto_scale_batch_size': False + } + ), + pytest.param( + "", + { + # These parameters are marked as Optional[...] in Trainer.__init__, with None as default. + # They should not be changed by the argparse interface. + "min_steps": None, + "max_steps": None, + "log_gpu_memory": None, + "accelerator": None, + "weights_save_path": None, + "truncated_bptt_steps": None, + "resume_from_checkpoint": None, + "profiler": None, + } + ), + ] +) +def test_argparse_args_parsing(cli_args, expected): + """Test multi type argument with bool.""" + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parent_parser=parser) + args = Trainer.parse_argparser(parser) + + for k, v in expected.items(): + assert getattr(args, k) == v + assert Trainer.from_argparse_args(args) + + +@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [ + pytest.param('--gpus 1', [0]), + pytest.param('--gpus 0,', [0]), +]) +@RunIf(min_gpus=1) +def test_argparse_args_parsing_gpus(cli_args, expected_gpu): + """Test multi type argument with bool.""" + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parent_parser=parser) + args = Trainer.parse_argparser(parser) + + trainer = Trainer.from_argparse_args(args) + assert trainer.data_parallel_device_ids == expected_gpu + + +@RunIf(min_python="3.7.0") +@pytest.mark.parametrize(['cli_args', 'extra_args'], [ + pytest.param({}, {}), + pytest.param({'logger': False}, {}), + pytest.param({'logger': False}, {'logger': True}), + pytest.param({'logger': False}, {'checkpoint_callback': True}), +]) +def test_init_from_argparse_args(cli_args, extra_args): + unknown_args = dict(unknown_arg=0) + + # unkown args in the argparser/namespace should be ignored + with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init: + trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args) + expected = dict(cli_args) + expected.update(extra_args) # extra args should override any cli arg + init.assert_called_with(trainer, **expected) + + # passing in unknown manual args should throw an error + with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"): + Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py new file mode 100644 index 00000000000000..7206d225ab5cd3 --- /dev/null +++ b/tests/trainer/test_trainer_tricks.py @@ -0,0 +1,370 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy + +import pytest +import torch +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + +import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate +from tests.helpers import BoringModel +from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf + + +def test_num_training_batches(tmpdir): + """ + Tests that the correct number of batches are allocated + """ + # when we have fewer batches in the dataloader we should use those instead of the limit + model = EvalModelTemplate() + trainer = Trainer( + limit_val_batches=100, + limit_train_batches=100, + max_epochs=1, + default_root_dir=tmpdir, + ) + trainer.fit(model) + + assert len(model.train_dataloader()) == 10 + assert len(model.val_dataloader()) == 10 + assert isinstance(trainer.num_val_batches, list) + assert trainer.num_val_batches[0] == 10 + assert trainer.num_training_batches == 10 + + # when we have more batches in the dataloader we should limit them + model = EvalModelTemplate() + trainer = Trainer( + limit_val_batches=7, + limit_train_batches=7, + max_epochs=1, + default_root_dir=tmpdir, + ) + trainer.fit(model) + + assert len(model.train_dataloader()) == 10 + assert len(model.val_dataloader()) == 10 + assert isinstance(trainer.num_val_batches, list) + assert trainer.num_val_batches[0] == 7 + assert trainer.num_training_batches == 7 + + +def test_overfit_batch_limits(tmpdir): + # ------------------------------------------------------ + # Make sure shuffle is correct across loaders initially + # ------------------------------------------------------ + model = EvalModelTemplate() + model.train_dataloader() + + # original train loader which should be replaced in all methods + train_loader = model.train_dataloader() + + # make sure the val and tests are not shuffled + assert isinstance(train_loader.sampler, RandomSampler) + assert isinstance(model.val_dataloader().sampler, SequentialSampler) + assert isinstance(model.test_dataloader().sampler, SequentialSampler) + + # ------------------------------------------------------ + # get the training loader and batch + # ------------------------------------------------------ + # Create a reference train dataloader without shuffling. + train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False) + (xa, ya) = next(iter(train_loader)) + train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True) + full_train_samples = len(train_loader) + num_train_samples = int(0.11 * full_train_samples) + + # ------------------------------------------------------ + # set VAL and Test loaders + # ------------------------------------------------------ + val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False) + test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False) + + # set the model loaders + model.train_dataloader = lambda: train_loader + model.val_dataloader = lambda: val_loader + model.test_dataloader = lambda: test_loader + + # ------------------------------------------------------ + # test train loader applies correct limits + # ------------------------------------------------------ + trainer = Trainer(overfit_batches=4) + trainer.reset_train_dataloader(model) + assert trainer.num_training_batches == 4 + + # make sure the loaders are the same + (xb, yb) = next(iter(trainer.train_dataloader)) + assert torch.eq(xa, xb).all() + assert torch.eq(ya, yb).all() + + trainer = Trainer(overfit_batches=0.11) + trainer.reset_train_dataloader(model) + # The dataloader should have been overwritten with a Sequential sampler. + assert trainer.train_dataloader is not train_loader + assert trainer.num_training_batches == num_train_samples + + # make sure the loaders are the same + (xb, yb) = next(iter(trainer.train_dataloader)) + assert torch.eq(xa, xb).all() + assert torch.eq(ya, yb).all() + + # ------------------------------------------------------ + # run tests for both val and test + # ------------------------------------------------------ + for split in ['val', 'test']: + + # ------------------------------------------------------ + # test overfit_batches as percent + # ------------------------------------------------------ + loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == num_train_samples + + # make sure we turned off shuffle for the user + assert isinstance(dataloaders[0].sampler, SequentialSampler) + + # make sure the loaders are the same + (xb, yb) = next(iter(dataloaders[0])) + assert torch.eq(xa, xb).all() + assert torch.eq(ya, yb).all() + + # ------------------------------------------------------ + # test overfit_batches as int + # ------------------------------------------------------ + loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 1 + loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 5 + + # ------------------------------------------------------ + # test limit_xxx_batches as percent AND int + # ------------------------------------------------------ + if split == 'val': + loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == int(0.1 * len(val_loader)) + + loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 10 + else: + loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == int(0.1 * len(test_loader)) + + loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 10 + + +def test_model_reset_correctly(tmpdir): + """ Check that model weights are correctly reset after scaling batch size. """ + tutils.reset_seed() + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + before_state_dict = deepcopy(model.state_dict()) + + trainer.tuner.scale_batch_size(model, max_trials=5) + + after_state_dict = model.state_dict() + + for key in before_state_dict.keys(): + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \ + 'Model was not reset correctly after scaling batch size' + + +def test_trainer_reset_correctly(tmpdir): + """ Check that all trainer parameters are reset correctly after scaling batch size. """ + tutils.reset_seed() + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + changed_attributes = [ + 'max_steps', + 'weights_summary', + 'logger', + 'callbacks', + 'checkpoint_callback', + 'limit_train_batches', + 'current_epoch', + ] + + attributes_before = {} + for ca in changed_attributes: + attributes_before[ca] = getattr(trainer, ca) + + trainer.tuner.scale_batch_size(model, max_trials=5) + + attributes_after = {} + for ca in changed_attributes: + attributes_after[ca] = getattr(trainer, ca) + + for key in changed_attributes: + assert attributes_before[key] == attributes_after[key], \ + f'Attribute {key} was not reset correctly after learning rate finder' + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True]) +def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg): + """ Test possible values for 'batch size auto scaling' Trainer argument. """ + tutils.reset_seed() + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + before_batch_size = hparams.get('batch_size') + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_scale_batch_size=scale_arg, + gpus=1, + ) + trainer.tune(model) + after_batch_size = model.batch_size + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + assert not os.path.exists(tmpdir / 'scale_batch_size_temp_model.ckpt') + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('use_hparams', [True, False]) +def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams): + """ Test that new batch size gets written to the correct hyperparameter attribute. """ + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + before_batch_size = hparams.get('batch_size') + + class HparamsEvalModelTemplate(EvalModelTemplate): + + def dataloader(self, *args, **kwargs): + # artificially set batch_size so we can get a dataloader + # remove it immediately after, because we want only self.hparams.batch_size + setattr(self, "batch_size", before_batch_size) + dataloader = super().dataloader(*args, **kwargs) + del self.batch_size + return dataloader + + datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored! + datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) + + model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate + model = model_class(**hparams) + model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_scale_batch_size=True, + gpus=1, + ) + trainer.tune(model, datamodule_fit) + after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size + assert trainer.datamodule == datamodule_fit + assert before_batch_size != after_batch_size + assert after_batch_size <= len(trainer.train_dataloader.dataset) + assert datamodule_fit.batch_size == after_batch_size + # should be left unchanged, since it was not passed to .tune() + assert datamodule_model.batch_size == 111 + + +def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): + """ Test for a warning when model.batch_size and model.hparams.batch_size both present. """ + + class TestModel(BoringModel): + + def __init__(self, batch_size=1): + super().__init__() + # now we have model.batch_size and model.hparams.batch_size + self.batch_size = 1 + self.save_hyperparameters() + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True) + expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!" + with pytest.warns(UserWarning, match=expected_message): + trainer.tune(model) + + +@pytest.mark.parametrize('scale_method', ['power', 'binsearch']) +def test_call_to_trainer_method(tmpdir, scale_method): + """ Test that calling the trainer method itself works. """ + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + before_batch_size = hparams.get('batch_size') + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) + model.batch_size = after_batch_size + trainer.fit(model) + + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + +def test_error_on_dataloader_passed_to_fit(tmpdir): + """Verify that when the auto scale batch size feature raises an error + if a train dataloader is passed to fit """ + + # only train passed to fit + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + auto_scale_batch_size='power', + ) + fit_options = dict(train_dataloader=model.dataloader(train=True)) + + with pytest.raises(MisconfigurationException): + trainer.tune(model, **fit_options) + + +@RunIf(min_gpus=1, amp_native=True) +def test_auto_scale_batch_size_with_amp(tmpdir): + model = EvalModelTemplate() + batch_size_before = model.batch_size + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + auto_scale_batch_size=True, + gpus=1, + precision=16, + ) + trainer.tune(model) + batch_size_after = model.batch_size + assert trainer.amp_backend == AMPType.NATIVE + assert trainer.scaler is not None + assert batch_size_after != batch_size_before diff --git a/tests/tuner/__init__.py b/tests/tuner/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/tuner/test_auto_gpu_select.py b/tests/tuner/test_auto_gpu_select.py new file mode 100644 index 00000000000000..32ec0282c8ce47 --- /dev/null +++ b/tests/tuner/test_auto_gpu_select.py @@ -0,0 +1,69 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.runif import RunIf + + +@RunIf(min_gpus=2) +@pytest.mark.parametrize( + ["auto_select_gpus", "gpus", "expected_error"], + [ + (True, 0, MisconfigurationException), + (True, -1, None), + (False, 0, None), + (False, -1, None), + ], +) +def test_trainer_with_gpus_options_combination_at_available_gpus_env(auto_select_gpus, gpus, expected_error): + if expected_error: + with pytest.raises( + expected_error, + match=re.escape( + r"auto_select_gpus=True, gpus=0 is not a valid configuration.\ + Please select a valid number of GPU resources when using auto_select_gpus." + ), + ): + Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus) + else: + Trainer(auto_select_gpus=auto_select_gpus, gpus=gpus) + + +@RunIf(min_gpus=2) +@pytest.mark.parametrize( + ["nb", "expected_gpu_idxs", "expected_error"], + [ + (0, [], MisconfigurationException), + (-1, [i for i in range(torch.cuda.device_count())], None), + (1, [0], None), + ], +) +def test_pick_multiple_gpus(nb, expected_gpu_idxs, expected_error): + if expected_error: + with pytest.raises( + expected_error, + match=re.escape( + r"auto_select_gpus=True, gpus=0 is not a valid configuration.\ + Please select a valid number of GPU resources when using auto_select_gpus." + ), + ): + pick_multiple_gpus(nb) + else: + assert expected_gpu_idxs == pick_multiple_gpus(nb) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py new file mode 100644 index 00000000000000..ad7fc57092f321 --- /dev/null +++ b/tests/tuner/test_scale_batch_size.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning import Trainer +from pytorch_lightning.tuner.tuning import Tuner +from tests.helpers import BoringDataModule, BoringModel + + +class BatchSizeDataModule(BoringDataModule): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1)) + + +class BatchSizeModel(BoringModel): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + +@pytest.mark.parametrize( + "model,datamodule", [ + (BatchSizeModel(2), None), + (BatchSizeModel(2), BatchSizeDataModule(2)), + (BatchSizeModel(2), BatchSizeDataModule(None)), + (BatchSizeModel(None), BatchSizeDataModule(2)), + ] +) +def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule): + """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """ + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=1, + ) + tuner = Tuner(trainer) + new_batch_size = tuner.scale_batch_size( + model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule + ) + assert new_batch_size == 16 + if hasattr(model, "batch_size"): + assert model.batch_size == 16 + if datamodule is not None and hasattr(datamodule, "batch_size"): + assert datamodule.batch_size == 16 diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/utilities/distributed.py b/tests/utilities/distributed.py new file mode 100644 index 00000000000000..80c0246ce6c577 --- /dev/null +++ b/tests/utilities/distributed.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import subprocess +import sys +from pathlib import Path +from subprocess import TimeoutExpired + +import pytorch_lightning + + +def call_training_script(module_file, cli_args, method, tmpdir, timeout=60): + file = Path(module_file.__file__).absolute() + cli_args = cli_args.split(' ') if cli_args else [] + cli_args += ['--tmpdir', str(tmpdir)] + cli_args += ['--trainer_method', method] + command = [sys.executable, str(file)] + cli_args + + # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment + env = os.environ.copy() + env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:' + + # for running in ddp mode, we need to lauch it's own process or pytest will get stuck + p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) + try: + std, err = p.communicate(timeout=timeout) + err = str(err.decode("utf-8")) + if 'Exception' in err: + raise Exception(err) + except TimeoutExpired: + p.kill() + std, err = p.communicate() + return std, err diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py new file mode 100644 index 00000000000000..6bad31634ce830 --- /dev/null +++ b/tests/utilities/test_all_gather_grad.py @@ -0,0 +1,120 @@ +import os +import sys + +import numpy as np +import torch + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.utilities import AllGatherGrad +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +def setup_ddp(rank, world_size): + """ Setup ddp enviroment """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" + + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _test_all_gather_ddp(rank, world_size): + setup_ddp(rank, world_size) + + tensor1 = torch.ones(8, requires_grad=True) + tensor2 = torch.ones((8, 16, 32), requires_grad=True) + + tensor1_gathered = AllGatherGrad.apply(tensor1) + tensor2_gathered = AllGatherGrad.apply(tensor2) + + tensor1_gathered = tensor1_gathered * rank + tensor2_gathered = tensor2_gathered * rank + + tensor1_gathered.sum().backward() + tensor2_gathered.sum().backward() + + grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float()) + grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float()) + + assert torch.allclose(grad1, tensor1.grad) + assert torch.allclose(grad2, tensor2.grad) + + +@RunIf(skip_windows=True) +def test_all_gather_ddp(): + world_size = 3 + torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size, ), nprocs=world_size) + + +@RunIf(min_gpus=2, skip_windows=True, special=True) +def test_all_gather_collection(tmpdir): + + class TestModel(BoringModel): + + training_epoch_end_called = False + + def training_epoch_end(self, outputs) -> None: + losses = torch.stack([x["loss"] for x in outputs]) + gathered_loss = self.all_gather({ + "losses_tensor_int": torch.rand(2, 2).int().t(), + "losses_tensor_float": torch.rand(2, 2).t(), + "losses_np_ndarray": np.array([1, 2, 3]), + "losses_bool": [True, False], + "losses_float": [0., 1., 2.], + "losses_int": [0, 1, 2], + "losses": losses, + "losses_list": [losses, losses] + }) + assert gathered_loss["losses_tensor_int"][0].dtype == torch.int32 + assert gathered_loss["losses_tensor_float"][0].dtype == torch.float + assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64 + # torch.bool can't be all_gathered + assert gathered_loss["losses_bool"][0].dtype == torch.uint8 + assert gathered_loss["losses_float"][0].dtype == torch.float + assert gathered_loss["losses_int"][0].dtype == torch.int + assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) + assert gathered_loss["losses"].numel() == 2 * len(losses) + self.training_epoch_end_called = True + + seed_everything(42) + + model = TestModel() + + limit_train_batches = 8 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + gpus=2, + accelerator="ddp", + ) + + trainer.fit(model) + assert model.training_epoch_end_called + + +@RunIf(min_gpus=2, skip_windows=True, special=True) +def test_all_gather_sync_grads(tmpdir): + + class TestModel(BoringModel): + + training_step_called = False + + def training_step(self, batch, batch_idx): + self.training_step_called = True + tensor = torch.rand(2, 2, requires_grad=True, device=self.device) + gathered_tensor = self.all_gather(tensor, sync_grads=True) + assert gathered_tensor.shape == torch.Size([2, 2, 2]) + + loss = gathered_tensor.sum() + + return loss + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp") + trainer.fit(model) + assert model.training_step_called diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py new file mode 100644 index 00000000000000..a7eea3a749f26d --- /dev/null +++ b/tests/utilities/test_apply_func.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +from collections import namedtuple + +import numpy as np +import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def test_recursive_application_to_collection(): + ntc = namedtuple('Foo', ['bar']) + + to_reduce = { + 'a': torch.tensor([1.]), # Tensor + 'b': [torch.tensor([2.])], # list + 'c': (torch.tensor([100.]), ), # tuple + 'd': ntc(bar=5.), # named tuple + 'e': np.array([10.]), # numpy array + 'f': 'this_is_a_dummy_str', # string + 'g': 12. # number + } + + expected_result = { + 'a': torch.tensor([2.]), + 'b': [torch.tensor([4.])], + 'c': (torch.tensor([200.]), ), + 'd': ntc(bar=torch.tensor([10.])), + 'e': np.array([20.]), + 'f': 'this_is_a_dummy_str', + 'g': 24. + } + + reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2) + + assert isinstance(reduced, dict), ' Type Consistency of dict not preserved' + assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved' + assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \ + 'At least one type was not correctly preserved' + + assert isinstance(reduced['a'], torch.Tensor), 'Reduction Result of a Tensor should be a Tensor' + assert torch.allclose(expected_result['a'], reduced['a']), \ + 'Reduction of a tensor does not yield the expected value' + + assert isinstance(reduced['b'], list), 'Reduction Result of a list should be a list' + assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \ + 'At least one value of list reduction did not come out as expected' + + assert isinstance(reduced['c'], tuple), 'Reduction Result of a tuple should be a tuple' + assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \ + 'At least one value of tuple reduction did not come out as expected' + + assert isinstance(reduced['d'], ntc), 'Type Consistency for named tuple not given' + assert isinstance(reduced['d'].bar, numbers.Number), \ + 'Failure in type promotion while reducing fields of named tuples' + assert reduced['d'].bar == expected_result['d'].bar + + assert isinstance(reduced['e'], np.ndarray), 'Type Promotion in reduction of numpy arrays failed' + assert reduced['e'] == expected_result['e'], \ + 'Reduction of numpy array did not yield the expected result' + + assert isinstance(reduced['f'], str), 'A string should not be reduced' + assert reduced['f'] == expected_result['f'], 'String not preserved during reduction' + + assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' + assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py new file mode 100644 index 00000000000000..fa5b8222715334 --- /dev/null +++ b/tests/utilities/test_apply_func_torchtext.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from pytorch_lightning.utilities.apply_func import move_data_to_device +from tests.helpers.imports import Dataset, Example, Field, Iterator +from tests.helpers.runif import RunIf + + +def _get_torchtext_data_iterator(include_lengths=False): + text_field = Field( + sequential=True, + pad_first=False, # nosec + init_token="", + eos_token="", # nosec + include_lengths=include_lengths + ) # nosec + + example1 = Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)}) + example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)}) + example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)}) + + dataset = Dataset( + [example1, example2, example3], + {"text": text_field}, + ) + text_field.build_vocab(dataset) + + iterator = Iterator( + dataset, + batch_size=3, + sort_key=None, + device=None, + batch_size_fn=None, + train=True, + repeat=False, + shuffle=None, + sort=None, + sort_within_batch=None + ) + return iterator, text_field + + +@pytest.mark.parametrize('include_lengths', [False, True]) +@pytest.mark.parametrize(['device'], [pytest.param(torch.device('cuda', 0))]) +@RunIf(min_gpus=1) +def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device): + data_iterator, _ = _get_torchtext_data_iterator(include_lengths=include_lengths) + data_iter = iter(data_iterator) + batch = next(data_iter) + batch_on_device = move_data_to_device(batch, device) + + if include_lengths: + # tensor with data + assert (batch_on_device.text[0].device == device) + # tensor with length of data + assert (batch_on_device.text[1].device == device) + else: + assert (batch_on_device.text.device == device) + + +@pytest.mark.parametrize('include_lengths', [False, True]) +def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths): + test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device('cpu')) diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py new file mode 100644 index 00000000000000..f13af4362364ca --- /dev/null +++ b/tests/utilities/test_argparse.py @@ -0,0 +1,215 @@ +import io +from argparse import ArgumentParser, Namespace +from typing import List +from unittest.mock import MagicMock + +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.argparse import ( + _gpus_arg_default, + _int_or_float_type, + add_argparse_args, + from_argparse_args, + get_abbrev_qualified_cls_name, + parse_argparser, + parse_args_from_docstring, +) + + +class ArgparseExample: + + def __init__(self, a: int = 0, b: str = '', c: bool = False): + self.a = a + self.b = b + self.c = c + + +def test_from_argparse_args(): + args = Namespace(a=1, b='test', c=True, d='not valid') + my_instance = from_argparse_args(ArgparseExample, args) + assert my_instance.a == 1 + assert my_instance.b == 'test' + assert my_instance.c + + parser = ArgumentParser() + mock_trainer = MagicMock() + _ = from_argparse_args(mock_trainer, parser) + mock_trainer.parse_argparser.assert_called_once_with(parser) + + +def test_parse_argparser(): + args = Namespace(a=1, b='test', c=None, d='not valid') + new_args = parse_argparser(ArgparseExample, args) + assert new_args.a == 1 + assert new_args.b == 'test' + assert new_args.c + assert new_args.d == 'not valid' + + +def test_parse_args_from_docstring_normal(): + args_help = parse_args_from_docstring( + """Constrain image dataset + + Args: + root: Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + normalize: mean and std deviation of the MNIST dataset. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + num_samples: number of examples per selected class/digit + digits: list selected MNIST digits/classes + + Examples: + >>> dataset = TrialMNIST(download=True) + >>> len(dataset) + 300 + >>> sorted(set([d.item() for d in dataset.targets])) + [0, 1, 2] + >>> torch.bincount(dataset.targets) + tensor([100, 100, 100]) + """ + ) + + expected_args = ['root', 'train', 'normalize', 'download', 'num_samples', 'digits'] + assert len(args_help.keys()) == len(expected_args) + assert all([x == y for x, y in zip(args_help.keys(), expected_args)]) + assert args_help['root'] == 'Root directory of dataset where ``MNIST/processed/training.pt``' \ + ' and ``MNIST/processed/test.pt`` exist.' + assert args_help['normalize'] == 'mean and std deviation of the MNIST dataset.' + + +def test_parse_args_from_docstring_empty(): + args_help = parse_args_from_docstring( + """Constrain image dataset + + Args: + + Returns: + + Examples: + """ + ) + assert len(args_help.keys()) == 0 + + +def test_get_abbrev_qualified_cls_name(): + assert get_abbrev_qualified_cls_name(Trainer) == "pl.Trainer" + + class NestedClass: + pass + + assert not __name__.startswith("pytorch_lightning.") + expected_name = f"{__name__}.test_get_abbrev_qualified_cls_name..NestedClass" + assert get_abbrev_qualified_cls_name(NestedClass) == expected_name + + +class AddArgparseArgsExampleClass: + """ + Args: + my_parameter: A thing. + """ + + def __init__(self, my_parameter: int = 0): + pass + + @staticmethod + def get_deprecated_arg_names() -> List[str]: + return [] + + +class AddArgparseArgsExampleClassViaInit: + + def __init__(self, my_parameter: int = 0): + """ + Args: + my_parameter: A thing. + """ + pass + + +class AddArgparseArgsExampleClassNoDoc: + + def __init__(self, my_parameter: int = 0): + pass + + +def extract_help_text(parser): + help_str_buffer = io.StringIO() + parser.print_help(file=help_str_buffer) + help_str_buffer.seek(0) + return help_str_buffer.read() + + +@pytest.mark.parametrize(["cls", "name"], [ + [AddArgparseArgsExampleClass, "AddArgparseArgsExampleClass"], + [AddArgparseArgsExampleClassViaInit, "AddArgparseArgsExampleClassViaInit"], + [AddArgparseArgsExampleClassNoDoc, "AddArgparseArgsExampleClassNoDoc"], +]) +def test_add_argparse_args(cls, name): + """ + Tests that ``add_argparse_args`` handles argument groups correctly, and + can be parsed. + """ + parser = ArgumentParser() + parser_main = parser.add_argument_group("main") + parser_main.add_argument("--main_arg", type=str, default="") + parser_old = parser # For testing. + parser = add_argparse_args(cls, parser) + assert parser is parser_old + + # Check nominal argument groups. + help_text = extract_help_text(parser) + assert "main:" in help_text + assert "--main_arg" in help_text + assert f"{name}:" in help_text + assert "--my_parameter" in help_text + if cls is not AddArgparseArgsExampleClassNoDoc: + assert "A thing" in help_text + + fake_argv = ["--main_arg=abc", "--my_parameter=2"] + args = parser.parse_args(fake_argv) + assert args.main_arg == "abc" + assert args.my_parameter == 2 + + +def test_negative_add_argparse_args(): + with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."): + parser = ArgumentParser() + add_argparse_args(AddArgparseArgsExampleClass, parser.add_argument_group("bad workflow")) + + +def test_add_argparse_args_no_argument_group(): + """ + Tests that ``add_argparse_args(..., use_argument_group=False)`` (old + workflow) handles argument groups correctly, and can be parsed. + """ + parser = ArgumentParser() + parser.add_argument("--main_arg", type=str, default="") + parser_old = parser # For testing. + parser = add_argparse_args(AddArgparseArgsExampleClass, parser, use_argument_group=False) + assert parser is not parser_old + + # Check arguments. + help_text = extract_help_text(parser) + assert "--main_arg" in help_text + assert "--my_parameter" in help_text + assert "AddArgparseArgsExampleClass:" not in help_text + + fake_argv = ["--main_arg=abc", "--my_parameter=2"] + args = parser.parse_args(fake_argv) + assert args.main_arg == "abc" + assert args.my_parameter == 2 + + +def test_gpus_arg_default(): + assert _gpus_arg_default('1,2') == '1,2' + assert _gpus_arg_default('1') == 1 + + +def test_int_or_float_type(): + assert isinstance(_int_or_float_type('0.0'), float) + assert isinstance(_int_or_float_type('0'), int) diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py new file mode 100644 index 00000000000000..aba35c95e2a6ce --- /dev/null +++ b/tests/utilities/test_dtype_device_mixin.py @@ -0,0 +1,122 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +import torch.nn as nn + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class SubSubModule(DeviceDtypeModuleMixin): + pass + + +class SubModule(nn.Module): + + def __init__(self): + super().__init__() + self.module = SubSubModule() + + +class TopModule(BoringModel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.module = SubModule() + + +class DeviceAssertCallback(Callback): + + def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx): + rank = trainer.local_rank + assert isinstance(model, TopModule) + # index = None also means first device + assert (model.device.index is None and rank == 0) or model.device.index == rank + assert model.device == model.module.module.device + + +@pytest.mark.parametrize(['dst_dtype'], [ + pytest.param(torch.float), + pytest.param(torch.double), + pytest.param(torch.half), +]) +@pytest.mark.parametrize(['dst_device'], [ + pytest.param(torch.device('cpu')), + pytest.param(torch.device('cuda', 0)), +]) +@RunIf(min_gpus=1) +def test_submodules_device_and_dtype(dst_device, dst_dtype): + """ + Test that the device and dtype property updates propagate through mixed nesting of regular + nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule). + """ + + model = TopModule() + assert model.device == torch.device('cpu') + model = model.to(device=dst_device, dtype=dst_dtype) + # nn.Module does not have these attributes + assert not hasattr(model.module, '_device') + assert not hasattr(model.module, '_dtype') + # device and dtype change should propagate down into all children + assert model.device == model.module.module.device == dst_device + assert model.dtype == model.module.module.dtype == dst_dtype + + +@RunIf(min_gpus=2) +def test_submodules_multi_gpu_dp(tmpdir): + model = TopModule() + trainer = Trainer( + default_root_dir=tmpdir, + accelerator='dp', + gpus=2, + callbacks=[DeviceAssertCallback()], + max_steps=1, + ) + trainer.fit(model) + + +@RunIf(min_gpus=2) +def test_submodules_multi_gpu_ddp_spawn(tmpdir): + model = TopModule() + trainer = Trainer( + default_root_dir=tmpdir, + accelerator='ddp_spawn', + gpus=2, + callbacks=[DeviceAssertCallback()], + max_steps=1, + ) + trainer.fit(model) + + +@pytest.mark.parametrize( + ['device'], + [ + pytest.param(None), # explicitly call without an index to see if the returning device contains an index + pytest.param(0), + pytest.param(torch.device('cuda', 0)), + ] +) +@RunIf(min_gpus=1) +def test_gpu_cuda_device(device): + model = TopModule() + + model.cuda(device) + + device = model.device + assert device.type == 'cuda' + assert device.index is not None + assert device.index == torch.cuda.current_device() diff --git a/tests/utilities/test_imports.py b/tests/utilities/test_imports.py new file mode 100644 index 00000000000000..e1c494fe4754bb --- /dev/null +++ b/tests/utilities/test_imports.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities import _module_available + + +def test_module_exists(): + """Test if the some 3rd party libs are available""" + assert _module_available("torch") + assert _module_available("torch.nn.parallel") + assert not _module_available("torch.nn.asdf") + assert not _module_available("asdf") + assert not _module_available("asdf.bla.asdf") diff --git a/tests/utilities/test_memory.py b/tests/utilities/test_memory.py new file mode 100644 index 00000000000000..1c90423a27c83f --- /dev/null +++ b/tests/utilities/test_memory.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.utilities.memory import recursive_detach + + +def test_recursive_detach(): + device = "cuda" if torch.cuda.is_available() else "cpu" + x = {"foo": torch.tensor(0, device=device), "bar": {"baz": torch.tensor(1.0, device=device, requires_grad=True)}} + y = recursive_detach(x, to_cpu=True) + + assert x["foo"].device.type == device + assert x["bar"]["baz"].device.type == device + assert x["bar"]["baz"].requires_grad + + assert y["foo"].device.type == "cpu" + assert y["bar"]["baz"].device.type == "cpu" + assert not y["bar"]["baz"].requires_grad diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py new file mode 100644 index 00000000000000..6ea10adf3d6966 --- /dev/null +++ b/tests/utilities/test_parsing.py @@ -0,0 +1,318 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + +import pytest + +from pytorch_lightning.utilities.parsing import ( + AttributeDict, + clean_namespace, + collect_init_args, + flatten_dict, + get_init_args, + is_picklable, + lightning_getattr, + lightning_hasattr, + lightning_setattr, + parse_class_init_keys, + str_to_bool, + str_to_bool_or_str, +) + +unpicklable_function = lambda: None + + +@pytest.fixture(scope="module") +def model_cases(): + + class TestHparamsNamespace: + learning_rate = 1 + + def __contains__(self, item): + return item == "learning_rate" + + TestHparamsDict = {'learning_rate': 2} + + class TestModel1: # test for namespace + learning_rate = 0 + + model1 = TestModel1() + + class TestModel2: # test for hparams namespace + hparams = TestHparamsNamespace() + + model2 = TestModel2() + + class TestModel3: # test for hparams dict + hparams = TestHparamsDict + + model3 = TestModel3() + + class TestModel4: # fail case + batch_size = 1 + + model4 = TestModel4() + + class DataModule: + batch_size = 8 + + class Trainer: + datamodule = DataModule + + class TestModel5: # test for datamodule + trainer = Trainer + + model5 = TestModel5() + + class TestModel6: # test for datamodule w/ hparams w/o attribute (should use datamodule) + trainer = Trainer + hparams = TestHparamsDict + + model6 = TestModel6() + + TestHparamsDict2 = {'batch_size': 2} + + class TestModel7: # test for datamodule w/ hparams w/ attribute (should use datamodule) + trainer = Trainer + hparams = TestHparamsDict2 + + model7 = TestModel7() + + return model1, model2, model3, model4, model5, model6, model7 + + +def test_lightning_hasattr(tmpdir, model_cases): + """Test that the lightning_hasattr works in all cases""" + model1, model2, model3, model4, model5, model6, model7 = models = model_cases + assert lightning_hasattr(model1, 'learning_rate'), \ + 'lightning_hasattr failed to find namespace variable' + assert lightning_hasattr(model2, 'learning_rate'), \ + 'lightning_hasattr failed to find hparams namespace variable' + assert lightning_hasattr(model3, 'learning_rate'), \ + 'lightning_hasattr failed to find hparams dict variable' + assert not lightning_hasattr(model4, 'learning_rate'), \ + 'lightning_hasattr found variable when it should not' + assert lightning_hasattr(model5, 'batch_size'), \ + 'lightning_hasattr failed to find batch_size in datamodule' + assert lightning_hasattr(model6, 'batch_size'), \ + 'lightning_hasattr failed to find batch_size in datamodule w/ hparams present' + assert lightning_hasattr(model7, 'batch_size'), \ + 'lightning_hasattr failed to find batch_size in hparams w/ datamodule present' + + for m in models: + assert not lightning_hasattr(m, "this_attr_not_exist") + + +def test_lightning_getattr(tmpdir, model_cases): + """Test that the lightning_getattr works in all cases""" + models = model_cases + for i, m in enumerate(models[:3]): + value = lightning_getattr(m, 'learning_rate') + assert value == i, 'attribute not correctly extracted' + + model5, model6, model7 = models[4:] + assert lightning_getattr(model5, 'batch_size') == 8, \ + 'batch_size not correctly extracted' + assert lightning_getattr(model6, 'batch_size') == 8, \ + 'batch_size not correctly extracted' + assert lightning_getattr(model7, 'batch_size') == 8, \ + 'batch_size not correctly extracted' + + for m in models: + with pytest.raises( + AttributeError, + match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule." + ): + lightning_getattr(m, "this_attr_not_exist") + + +def test_lightning_setattr(tmpdir, model_cases): + """Test that the lightning_setattr works in all cases""" + models = model_cases + for m in models[:3]: + lightning_setattr(m, 'learning_rate', 10) + assert lightning_getattr(m, 'learning_rate') == 10, \ + 'attribute not correctly set' + + model5, model6, model7 = models[4:] + lightning_setattr(model5, 'batch_size', 128) + lightning_setattr(model6, 'batch_size', 128) + lightning_setattr(model7, 'batch_size', 128) + assert lightning_getattr(model5, 'batch_size') == 128, \ + 'batch_size not correctly set' + assert lightning_getattr(model6, 'batch_size') == 128, \ + 'batch_size not correctly set' + assert lightning_getattr(model7, 'batch_size') == 128, \ + 'batch_size not correctly set' + + for m in models: + with pytest.raises( + AttributeError, + match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule." + ): + lightning_setattr(m, "this_attr_not_exist", None) + + +def test_str_to_bool_or_str(tmpdir): + true_cases = ['y', 'yes', 't', 'true', 'on', '1'] + false_cases = ['n', 'no', 'f', 'false', 'off', '0'] + other_cases = ['yyeess', 'noooo', 'lightning'] + + for case in true_cases: + assert str_to_bool_or_str(case) is True + + for case in false_cases: + assert str_to_bool_or_str(case) is False + + for case in other_cases: + assert str_to_bool_or_str(case) == case + + +def test_str_to_bool(tmpdir): + true_cases = ['y', 'yes', 't', 'true', 'on', '1'] + false_cases = ['n', 'no', 'f', 'false', 'off', '0'] + other_cases = ['yyeess', 'noooo', 'lightning'] + + for case in true_cases: + assert str_to_bool(case) is True + + for case in false_cases: + assert str_to_bool(case) is False + + for case in other_cases: + with pytest.raises(ValueError): + str_to_bool(case) + + +def test_is_picklable(tmpdir): + # See the full list of picklable types at + # https://docs.python.org/3/library/pickle.html#pickle-picklable + class UnpicklableClass: + # Only classes defined at the top level of a module are picklable. + pass + + true_cases = [None, True, 123, "str", (123, "str"), max] + false_cases = [unpicklable_function, UnpicklableClass] + + for case in true_cases: + assert is_picklable(case) is True + + for case in false_cases: + assert is_picklable(case) is False + + +def test_clean_namespace(tmpdir): + # See the full list of picklable types at + # https://docs.python.org/3/library/pickle.html#pickle-picklable + class UnpicklableClass: + # Only classes defined at the top level of a module are picklable. + pass + + test_case = { + "1": None, + "2": True, + "3": 123, + "4": unpicklable_function, + "5": UnpicklableClass, + } + + clean_namespace(test_case) + + assert test_case == {"1": None, "2": True, "3": 123} + + +def test_parse_class_init_keys(tmpdir): + + class Class: + + def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): + pass + + assert parse_class_init_keys(Class) == ("self", "my_args", "my_kwargs") + + +def test_get_init_args(tmpdir): + + class AutomaticArgsModel: + + def __init__(self, anyarg, anykw=42, **kwargs): + super().__init__() + + self.get_init_args_wrapper() + + def get_init_args_wrapper(self): + frame = inspect.currentframe().f_back + self.result = get_init_args(frame) + + my_class = AutomaticArgsModel("test", anykw=32, otherkw=123) + assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123} + + my_class.get_init_args_wrapper() + assert my_class.result == {} + + +def test_collect_init_args(): + + class AutomaticArgsParent: + + def __init__(self, anyarg, anykw=42, **kwargs): + super().__init__() + self.get_init_args_wrapper() + + def get_init_args_wrapper(self): + frame = inspect.currentframe() + self.result = collect_init_args(frame, []) + + class AutomaticArgsChild(AutomaticArgsParent): + + def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs): + super().__init__(anyarg, anykw=anykw, **kwargs) + + my_class = AutomaticArgsChild("test1", "test2", anykw=32, childkw=22, otherkw=123) + assert my_class.result[0] == {"anyarg": "test1", "anykw": 32, "otherkw": 123} + assert my_class.result[1] == {"anyarg": "test1", "childarg": "test2", "anykw": 32, "childkw": 22, "otherkw": 123} + + +def test_attribute_dict(tmpdir): + # Test initialization + inputs = { + 'key1': 1, + 'key2': 'abc', + } + ad = AttributeDict(inputs) + for key, value in inputs.items(): + assert getattr(ad, key) == value + + # Test adding new items + ad = AttributeDict() + ad.update({'key1': 1}) + assert ad.key1 == 1 + + # Test updating existing items + ad = AttributeDict({'key1': 1}) + ad.key1 = 123 + assert ad.key1 == 123 + + +def test_flatten_dict(tmpdir): + d = {'1': 1, '_': {'2': 2, '_': {'3': 3, '4': 4}}} + + expected = { + '1': 1, + '2': 2, + '3': 3, + '4': 4, + } + + assert flatten_dict(d) == expected diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py new file mode 100644 index 00000000000000..74c6674eec7931 --- /dev/null +++ b/tests/utilities/test_seed.py @@ -0,0 +1,55 @@ +import os +from unittest import mock + +import pytest + +import pytorch_lightning.utilities.seed as seed_utils + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_seed_stays_same_with_multiple_seed_everything_calls(): + """ + Ensure that after the initial seed everything, + the seed stays the same for the same run. + """ + with pytest.warns(UserWarning, match="No correct seed found"): + seed_utils.seed_everything() + initial_seed = os.environ.get("PL_GLOBAL_SEED") + + with pytest.warns(None) as record: + seed_utils.seed_everything() + assert not record # does not warn + seed = os.environ.get("PL_GLOBAL_SEED") + + assert initial_seed == seed + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) +def test_correct_seed_with_environment_variable(): + """ + Ensure that the PL_GLOBAL_SEED environment is read + """ + assert seed_utils.seed_everything() == 2020 + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) +@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) +def test_invalid_seed(): + """ + Ensure that we still fix the seed even if an invalid seed is given + """ + with pytest.warns(UserWarning, match="No correct seed found"): + seed = seed_utils.seed_everything() + assert seed == 123 + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) +@pytest.mark.parametrize("seed", (10e9, -10e9)) +def test_out_of_bounds_seed(seed): + """ + Ensure that we still fix the seed even if an out-of-bounds seed is given + """ + with pytest.warns(UserWarning, match="is not in bounds"): + actual = seed_utils.seed_everything(seed) + assert actual == 123 diff --git a/tests/utilities/test_upgrade_checkpoint.py b/tests/utilities/test_upgrade_checkpoint.py new file mode 100644 index 00000000000000..82801cb27c407e --- /dev/null +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -0,0 +1,99 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import torch + +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint + + +@pytest.mark.parametrize( + "old_checkpoint, new_checkpoint", + [ + ( + { + "epoch": 1, + "global_step": 23, + "checkpoint_callback_best": 0.34 + }, + { + "epoch": 1, + "global_step": 23, + "callbacks": { + ModelCheckpoint: { + "best_model_score": 0.34 + } + } + }, + ), + ( + { + "epoch": 1, + "global_step": 23, + "checkpoint_callback_best_model_score": 0.99 + }, + { + "epoch": 1, + "global_step": 23, + "callbacks": { + ModelCheckpoint: { + "best_model_score": 0.99 + } + } + }, + ), + ( + { + "epoch": 1, + "global_step": 23, + "checkpoint_callback_best_model_path": 'path' + }, + { + "epoch": 1, + "global_step": 23, + "callbacks": { + ModelCheckpoint: { + "best_model_path": 'path' + } + } + }, + ), + ( + { + "epoch": 1, + "global_step": 23, + "early_stop_callback_wait": 2, + "early_stop_callback_patience": 4 + }, + { + "epoch": 1, + "global_step": 23, + "callbacks": { + EarlyStopping: { + "wait_count": 2, + "patience": 4 + } + } + }, + ), + ], +) +def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): + filepath = os.path.join(tmpdir, "model.ckpt") + torch.save(old_checkpoint, filepath) + upgrade_checkpoint(filepath) + updated_checkpoint = torch.load(filepath) + assert updated_checkpoint == new_checkpoint diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py new file mode 100644 index 00000000000000..edca2777b578ab --- /dev/null +++ b/tests/utilities/test_xla_device_utils.py @@ -0,0 +1,53 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from unittest.mock import patch + +import pytest + +import pytorch_lightning.utilities.xla_device as xla_utils +from pytorch_lightning.utilities import _XLA_AVAILABLE +from tests.helpers.runif import RunIf + + +@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") +def test_tpu_device_absence(): + """Check tpu_device_exists returns False when torch_xla is not available""" + assert not xla_utils.XLADeviceUtils.tpu_device_exists() + + +@RunIf(tpu=True) +def test_tpu_device_presence(): + """Check tpu_device_exists returns True when TPU is available""" + assert xla_utils.XLADeviceUtils.tpu_device_exists() + + +def sleep_fn(sleep_time: float) -> bool: + time.sleep(sleep_time) + return True + + +@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3) +@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present") +def test_result_returns_within_timeout_seconds(): + """Check that pl_multi_process returns within 3 seconds""" + fn = xla_utils.pl_multi_process(sleep_fn) + + start = time.time() + result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5) + end = time.time() + elapsed_time = int(end - start) + + assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT + assert result