diff --git a/.circleci/build.sh b/.circleci/build.sh index 25cc78f3c6b..68e81a5436c 100755 --- a/.circleci/build.sh +++ b/.circleci/build.sh @@ -41,14 +41,17 @@ apply_patches python -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)" +# We always build PyTorch without CUDA support. +export USE_CUDA=0 python setup.py install sccache --show-stats source $XLA_DIR/xla_env export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json" -export SILO_NAME='cache-silo-ci-gcc-11' # cache bucket for CI +export SILO_NAME='cache-silo-ci-dev-3.8_cuda_12.1' # cache bucket for CI export BUILD_CPP_TESTS='1' +export TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_70,sm_75,compute_80,$TF_CUDA_COMPUTE_CAPABILITIES" build_torch_xla $XLA_DIR popd diff --git a/.circleci/common.sh b/.circleci/common.sh index 88c5fed8efc..235086cba41 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -92,27 +92,6 @@ function install_deps_pytorch_xla() { sudo ln -s "$(command -v bazelisk)" /usr/bin/bazel - # Install gcc-11 - sudo apt-get update - # Update ppa for GCC - sudo apt-get install -y software-properties-common - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test - sudo apt update -y - sudo apt install -y gcc-11 - sudo apt install -y g++-11 - sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 100 - - export NVCC_PREPEND_FLAGS='-ccbin /usr/bin/g++-11' - - # Hack similar to https://github.com/pytorch/pytorch/pull/105227/files#diff-9e59213240d3b55d2ddc53c8c096db9eece0665d64f46473454f9dc0c10fd804 - sudo rm /opt/conda/lib/libstdc++.so* - - # Update gcov for test coverage - sudo update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-11 100 - sudo update-alternatives --install /usr/bin/gcov-dump gcov-dump /usr/bin/gcov-dump-11 100 - sudo update-alternatives --install /usr/bin/gcov-tool gcov-tool /usr/bin/gcov-tool-11 100 - # Symnlink the missing cuda headers if exists CUBLAS_PATTERN="/usr/include/cublas*" if ls $CUBLAS_PATTERN 1> /dev/null 2>&1; then @@ -148,16 +127,18 @@ function run_torch_xla_python_tests() { else ./test/run_tests.sh - # GPU tests + # CUDA tests if [ -x "$(command -v nvidia-smi)" ]; then - # These tests fail on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) + # These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) + PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 + # TODO(xiowei replace gpu with cuda): remove the test below with PJRT_DEVICE=GPU because PJRT_DEVICE=GPU is being deprecated. PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 - PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 - XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 + PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 + XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 # Syncfree SGD optimizer tests if [ -d ./torch_xla/amp/syncfree ]; then echo "Running Syncfree Optimizer Test" - PJRT_DEVICE=GPU python test/test_syncfree_optimizers.py + PJRT_DEVICE=CUDA python test/test_syncfree_optimizers.py # Following test scripts are mainly useful for # performance evaluation & comparison among different @@ -192,9 +173,9 @@ function run_torch_xla_cpp_tests() { if [ "$USE_COVERAGE" != "0" ]; then # TODO(yeounoh) shard the coverage testing if [ -x "$(command -v nvidia-smi)" ]; then - PJRT_DEVICE=GPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" + PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L"" cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov1.dat - PJRT_DEVICE=GPU test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS + PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov2.dat lcov --add-tracefile /tmp/cov1.dat -a /tmp/cov2.dat -o /tmp/merged.dat else @@ -206,8 +187,8 @@ function run_torch_xla_cpp_tests() { else # Shard GPU testing if [ -x "$(command -v nvidia-smi)" ]; then - PJRT_DEVICE=GPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" - PJRT_DEVICE=GPU test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS + PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L"" + PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS else PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" fi diff --git a/.circleci/docker/Dockerfile b/.circleci/docker/Dockerfile index 89f6dbc08db..f0cd196511c 100644 --- a/.circleci/docker/Dockerfile +++ b/.circleci/docker/Dockerfile @@ -1,13 +1,13 @@ # This requires cuda & cudnn packages pre-installed in the base image. # Other available cuda images are listed at https://hub.docker.com/r/nvidia/cuda -ARG base_image="nvidia/cuda:11.7.0-cudnn8-devel-ubuntu18.04" +ARG base_image="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1" FROM "${base_image}" ARG python_version="3.8" ARG cuda="1" ARG cuda_compute="5.2,7.5" -ARG cc="clang-8" -ARG cxx="clang++-8" +ARG cc="clang" +ARG cxx="clang++" ARG cxx_abi="1" ARG tpuvm="" @@ -37,38 +37,15 @@ ENV CXX "${cxx}" # Whether to build for TPUVM mode ENV TPUVM_MODE "${tpuvm}" -# Rotate nvidia repo public key (last updated: 04/27/2022) -# Unfortunately, nvidia/cuda image is shipped with invalid public key -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub - -# Install base system packages -RUN apt-get clean && apt-get update -RUN apt-get upgrade -y -RUN apt-get install --fix-missing -y python-pip python3-pip git curl libopenblas-dev vim jq \ - apt-transport-https ca-certificates procps openssl sudo wget libssl-dev libc6-dbg - -# Install clang & llvm -ADD ./install_llvm_clang.sh install_llvm_clang.sh -RUN bash ./install_llvm_clang.sh - +# Install clang as upstream CI forces clang +RUN apt-get install -y clang # Install valgrind -ADD ./install_valgrind.sh install_valgrind.sh +COPY ./install_valgrind.sh install_valgrind.sh RUN bash ./install_valgrind.sh -# Sets up jenkins user. -RUN useradd jenkins && \ - mkdir /home/jenkins && \ - chown jenkins /home/jenkins -RUN echo 'jenkins ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers - -RUN mkdir -p /opt/conda /opt/cargo /opt/rustup /workspace /var/lib/jenkins && \ - chown jenkins /opt/conda /opt/cargo /opt/rustup /workspace /var/lib/jenkins -USER jenkins -WORKDIR /workspace - # Install openmpi for CUDA -run sudo apt-get install -y ssh -run sudo apt-get install -y --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev +run apt-get install -y ssh +run apt-get install -y --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev # Builds and configure sccache ENV OPENSSL_INCLUDE_DIR /usr/include/openssl @@ -87,6 +64,25 @@ RUN . $CARGO_HOME/env && \ ENV PATH $CARGO_HOME/bin:$PATH +# Upstream CI requires jq +RUN apt-get install -y jq + +# TODO: Add exec permisson for all users in base image. +RUN chmod a+x /usr/local/bin/bazel +# TODO: move sudo installation in base image. +RUN apt-get install -y sudo + +RUN useradd jenkins && \ + mkdir /home/jenkins && \ + chown jenkins /home/jenkins +RUN echo 'jenkins ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers + +RUN mkdir -p /opt/conda /opt/cargo /opt/rustup /workspace /var/lib/jenkins && \ + chown jenkins /opt/conda /opt/cargo /opt/rustup /workspace /var/lib/jenkins +ENV PATH /home/jenkins/.local/bin:$PATH +USER jenkins +WORKDIR /workspace + # Installs and configures Conda. ADD ./install_conda.sh install_conda.sh RUN sudo chown jenkins ./install_conda.sh @@ -95,6 +91,7 @@ RUN bash ./install_conda.sh "${python_version}" /opt/conda RUN echo "conda activate base" >> ~/.bashrc RUN echo "export TF_CPP_LOG_THREAD_ID=1" >> ~/.bashrc ENV PATH /opt/conda/bin:$PATH +ENV LD_LIBRARY_PATH /lib/x86_64-linux-gnu/:/usr/lib/x86_64-linux-gnu/:/opt/conda/lib/:$LD_LIBRARY_PATH RUN bash -c "source ~/.bashrc" CMD ["bash"] diff --git a/.circleci/docker/install_conda.sh b/.circleci/docker/install_conda.sh index b0fc17c73ec..15e2c541b25 100644 --- a/.circleci/docker/install_conda.sh +++ b/.circleci/docker/install_conda.sh @@ -4,7 +4,7 @@ set -ex PYTHON_VERSION=$1 CONDA_PREFIX=$2 -DEFAULT_PYTHON_VERSION=3.7 +DEFAULT_PYTHON_VERSION=3.8 function install_and_setup_conda() { @@ -30,7 +30,7 @@ function install_and_setup_conda() { conda update -y -n base conda conda install -y python=$PYTHON_VERSION - conda install -y nomkl numpy=1.18.5 pyyaml setuptools cmake \ + conda install -y nomkl numpy=1.18.5 pyyaml setuptools \ cffi typing tqdm coverage hypothesis dataclasses cython /usr/bin/yes | pip install mkl==2022.2.1 @@ -41,9 +41,6 @@ function install_and_setup_conda() { /usr/bin/yes | pip install --upgrade numba /usr/bin/yes | pip install cloud-tpu-client /usr/bin/yes | pip install expecttest==0.1.3 - /usr/bin/yes | pip install ninja # Install ninja to speedup the build - # Using Ninja requires CMake>=3.13, PyTorch requires CMake>=3.18 - /usr/bin/yes | pip install "cmake>=3.18" --upgrade /usr/bin/yes | pip install absl-py # Additional PyTorch requirements /usr/bin/yes | pip install scikit-image scipy==1.6.3 diff --git a/.circleci/docker/install_valgrind.sh b/.circleci/docker/install_valgrind.sh old mode 100644 new mode 100755 index e235d36609b..08e45fd0e28 --- a/.circleci/docker/install_valgrind.sh +++ b/.circleci/docker/install_valgrind.sh @@ -9,7 +9,7 @@ tar -xjf valgrind-${VALGRIND_VERSION}.tar.bz2 cd valgrind-${VALGRIND_VERSION} ./configure --prefix=/usr/local make -j6 -sudo make install +make install cd ../../ rm -rf valgrind_build alias valgrind="/usr/local/bin/valgrind" diff --git a/.circleci/test.sh b/.circleci/test.sh index 914d56d206f..127c7f497a1 100755 --- a/.circleci/test.sh +++ b/.circleci/test.sh @@ -26,5 +26,5 @@ function install_torchvision() { install_torchvision export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json" -export SILO_NAME='cache-silo-ci-gcc-11' # cache bucket for CI +export SILO_NAME='cache-silo-ci-dev-3.8_cuda_12.1' # cache bucket for CI run_torch_xla_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml index 74b1b00397a..879594476ef 100644 --- a/.github/workflows/_build.yml +++ b/.github/workflows/_build.yml @@ -73,12 +73,12 @@ jobs: # if image layers are not present in the repo. # Note: disable the following 2 lines while testing a new image, so we do not # push to the upstream. - docker tag "${GCR_DOCKER_IMAGE}" "${ECR_DOCKER_IMAGE_BASE}:v1.0" >/dev/null - docker push "${ECR_DOCKER_IMAGE_BASE}:v1.0" >/dev/null + docker tag "${GCR_DOCKER_IMAGE}" "${ECR_DOCKER_IMAGE_BASE}:v1.1-lite" >/dev/null + docker push "${ECR_DOCKER_IMAGE_BASE}:v1.1-lite" >/dev/null - name: Start the container shell: bash run: | - pid=$(docker run -t -d -w "$WORKDIR" "${GCR_DOCKER_IMAGE}") + pid=$(docker run --privileged -t -d -w "$WORKDIR" "${GCR_DOCKER_IMAGE}") docker exec -u jenkins "${pid}" sudo chown -R jenkins "${WORKDIR}" docker cp "${GITHUB_WORKSPACE}/." "$pid:$WORKDIR" echo "pid=${pid}" >> "${GITHUB_ENV}" @@ -87,7 +87,6 @@ jobs: shell: bash run: | echo "declare -x SCCACHE_BUCKET=${SCCACHE_BUCKET}" | docker exec -i "${pid}" sh -c "cat >> env" - echo "declare -x CC=clang-8 CXX=clang++-8" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x DISABLE_XRT=${DISABLE_XRT}" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x XLA_CUDA=${XLA_CUDA}" | docker exec -i "${pid}" sh -c "cat >> xla_env" echo "declare -x BAZEL_REMOTE_CACHE=1" | docker exec -i "${pid}" sh -c "cat >> xla_env" @@ -96,8 +95,7 @@ jobs: - name: Build shell: bash run: | - docker exec -u jenkins "${pid}" bash -c ". ~/.bashrc && .circleci/build.sh" - + docker exec --privileged -u jenkins "${pid}" bash -c ".circleci/build.sh" - name: Cleanup build env shell: bash run: | diff --git a/.github/workflows/_coverage.yml b/.github/workflows/_coverage.yml index 4643e225314..e114074bb7e 100644 --- a/.github/workflows/_coverage.yml +++ b/.github/workflows/_coverage.yml @@ -94,7 +94,7 @@ jobs: - name: Test shell: bash run: | - docker exec -u jenkins "${pid}" bash -c '. ~/.bashrc && .circleci/${{ inputs.test-script }}' + docker exec -u jenkins "${pid}" bash -c '.circleci/${{ inputs.test-script }}' - name: Upload coverage results if: ${{ inputs.collect-coverage }} shell: bash diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index d1bba0962e9..ed9a4ab0ea9 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -43,7 +43,7 @@ jobs: echo "pid=${pid}" >> "${GITHUB_ENV}" - name: Build & publish docs shell: bash - run: docker exec -u jenkins "${pid}" bash -c '. ~/.bashrc && .circleci/doc_push.sh' + run: docker exec -u jenkins "${pid}" bash -c '.circleci/doc_push.sh' - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 7c7215a573a..3f0aa8acd47 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -116,7 +116,7 @@ jobs: - name: Test shell: bash run: | - docker exec -u jenkins "${pid}" bash -c '. ~/.bashrc && .circleci/${{ inputs.test-script }}' + docker exec --privileged -u jenkins "${pid}" bash -c '.circleci/${{ inputs.test-script }}' - name: Upload coverage results if: ${{ inputs.collect-coverage }} shell: bash diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 83277c8c96a..31b415c503e 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -19,8 +19,7 @@ jobs: uses: ./.github/workflows/_build.yml with: ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base - gcr-docker-image: gcr.io/tpu-pytorch/xla_base:latest - disable_xrt: 1 + gcr-docker-image: gcr.io/tpu-pytorch/xla_base:dev-3.8_cuda_12.1 cuda: 1 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} @@ -43,7 +42,7 @@ jobs: with: docker-image: ${{ needs.build.outputs.docker-image }} runner: linux.8xlarge.nvidia.gpu - timeout-minutes: 300 + timeout-minutes: 180 disable-xrt: 1 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} diff --git a/.github/workflows/build_and_test_xrt.yml b/.github/workflows/build_and_test_xrt.yml index dd3f95b7100..79f96e0c19c 100644 --- a/.github/workflows/build_and_test_xrt.yml +++ b/.github/workflows/build_and_test_xrt.yml @@ -18,7 +18,7 @@ jobs: uses: ./.github/workflows/_build.yml with: ecr-docker-image-base: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base - gcr-docker-image: gcr.io/tpu-pytorch/xla_base:latest + gcr-docker-image: gcr.io/tpu-pytorch/xla_base:dev-3.8_cuda_12.1 disable_xrt: 0 cuda: 1 secrets: @@ -42,7 +42,7 @@ jobs: with: docker-image: ${{ needs.build.outputs.docker-image }} runner: linux.8xlarge.nvidia.gpu - timeout-minutes: 300 + timeout-minutes: 180 disable-xrt: 0 secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} diff --git a/.github/workflows/tpu-ci.yml b/.github/workflows/tpu-ci.yml new file mode 100644 index 00000000000..3cc48221a46 --- /dev/null +++ b/.github/workflows/tpu-ci.yml @@ -0,0 +1,27 @@ +name: TPU Test +run-name: CI Testing +on: + workflow_dispatch: + schedule: + - cron: '0 16,20,0 * * 1-5' +jobs: + tpu-test: + runs-on: v4-runner-set + steps: + - run: | + git clone --recursive https://github.com/pytorch/pytorch + cd pytorch/ + python3 setup.py install --user + git clone --recursive https://github.com/mbzomowski/xla.git + - env: + BAZEL_VERBOSE: 1 + BUNDLE_LIBTPU: 1 + TPUVM_MODE: 1 + run: | + cd pytorch/xla + python3 setup.py install --user + - env: + PJRT_DEVICE: TPU + run: | + cd pytorch/xla + python3 -u test/test_operations.py -v diff --git a/.kokoro/Dockerfile b/.kokoro/Dockerfile index 32cc499477d..40210aba1f3 100644 --- a/.kokoro/Dockerfile +++ b/.kokoro/Dockerfile @@ -47,7 +47,7 @@ ARG SCCACHE="$(which sccache)" WORKDIR /pytorch/xla ARG GCLOUD_SERVICE_KEY_FILE="/pytorch/xla/default_credentials.json" -ARG SILO_NAME='cache-silo-ci-gcc-11' # cache bucket for CI +ARG SILO_NAME='cache-silo-ci-dev-3.8_cuda_12.1' # cache bucket for CI RUN time pip install -e . # Run tests diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4e9c5372880..9045d1238c2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -94,7 +94,7 @@ To run the tests, follow __one__ of the options below: * Run on GPU: ```Shell - export PJRT_DEVICE=GPU GPU_NUM_DEVICES=${NUM_GPU} + export PJRT_DEVICE=CUDA GPU_NUM_DEVICES=${NUM_GPU} ``` For more detail on configuring the runtime, please refer to [this doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#quickstart) diff --git a/README.md b/README.md index bfda642b2f0..68a67e96c82 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ If you're using `DistributedDataParallel`, make the following changes: Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at [PyTorch.org](http://pytorch.org/xla/). See the [API Guide](API_GUIDE.md) for best practices when writing networks that run on -XLA devices (TPU, GPU, CPU and...). +XLA devices (TPU, CUDA, CPU and...). Our comprehensive user guides are available at: diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index 842deabd186..ca3ddf702c2 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -3,7 +3,52 @@ Note that the information in this section is subject to be removed in future releases of the _PyTorch/XLA_ software, since many of them are peculiar to a given internal implementation which might change. -To diagnose issues, we can use the execution metrics and counters provided by _PyTorch/XLA_ +## Sanity Check +Before performing any in depth debugging, we want to do a sanity check on the installed PyTorch/XLA. + +### Check PyTorch/XLA Version +PyTorch and PyTorch/XLA version should match. Check out our [README](https://github.com/pytorch/xla#getting-started) for more detials on versions available. +``` +vm:~$ python +>>> import torch +>>> import torch_xla +>>> print(torch.__version__) +2.1.0+cu121 +>>> print(torch_xla.__version__) +2.1.0 +``` + +### Perform A Simple Calculation +``` +vm:~$ export PJRT_DEVICE=TPU +vm:~$ python3 +>>> import torch +>>> import torch_xla.core.xla_model as xm +>>> t1 = torch.tensor(100, device=xm.xla_device()) +>>> t2 = torch.tensor(200, device=xm.xla_device()) +>>> print(t1 + t2) +tensor(300, device='xla:0') +``` + +### Run Resnet With Fake Data +For nightly +``` +vm:~$ git clone https://github.com/pytorch/xla.git +vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data +``` + +For release version `x.y`, you want to use the branch `rx.y`. For example if you installed 2.1 release, you should do +``` +vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git +vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data +``` + +If you can get the resnet to run we can conclude that torch_xla is installed correctly. + + +## Performance Debugging + +To diagnose performance issues, we can use the execution metrics and counters provided by _PyTorch/XLA_ The **first thing** to check when model is slow is to generate a metrics report. Metrics report is extremely helpful in diagnosing issues. Please try to include it in your bug @@ -76,7 +121,7 @@ Counter: aten::nonzero If you see `aten::` ops other than `nonzero` and `_local_scalar_dense`, that usually means a missing lowering in PyTorch/XLA. Feel free to open a feature request for it on [GitHub issues](https://github.com/pytorch/xla/issues). -## Clar The Metrics Report +## Clear The Metrics Report If you want to clear the metrics between steps/epochs, you can use ```Python import torch_xla.debug.metrics as met diff --git a/WORKSPACE b/WORKSPACE index a4e4027a67e..ace66355416 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -39,15 +39,14 @@ http_archive( patch_tool = "patch", patches = [ "//openxla_patches:cache_urls.diff", - "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:constexpr_return.diff", - "//openxla_patches:pjrt_api_tsl_logging.diff", - "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff", + "//openxla_patches:gpu_race_condition.diff", + "//openxla_patches:f16_abi_clang.diff", + "//openxla_patches:gpu_topk_rewriter.diff", ], - strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18", + strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", urls = [ - "https://github.com/openxla/xla/archive/97a5f819faf9ff793b7ba68ff1f31f74f9459c18.tar.gz", + "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", ], ) diff --git a/codegen/lazy_tensor_generator.py b/codegen/lazy_tensor_generator.py index e285994b10e..54a51e8dfb6 100644 --- a/codegen/lazy_tensor_generator.py +++ b/codegen/lazy_tensor_generator.py @@ -115,7 +115,7 @@ def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str: get_tensorlist="GetTensorList", get_tensor_or_wrap_number="bridge::GetXlaTensorOrCreateForWrappedNumber", try_get_tensor="bridge::TryGetXlaTensor", - metrics_counter='TORCH_LAZY_FN_COUNTER("xla::")', + metrics_counter='TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::")', create_tensor="XLATensor::Create", create_aten_from_ltc_tensor="torch_xla::bridge::AtenFromXlaTensor", tuple_aten_from_ltc_tensors="torch_xla::bridge::TupleAtenFromXlaTensors", diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index bdb6c38e8cc..e0d917db684 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -192,6 +192,7 @@ supported: - floor_divide - fmod.Scalar - fmod.Tensor + - full - gather - gelu - gelu_backward @@ -267,6 +268,7 @@ supported: - pow.Tensor_Scalar - pow.Tensor_Tensor - _prelu_kernel + - _prelu_kernel_backward - prod - prod.dim_int - _propagate_xla_data diff --git a/configuration.yaml b/configuration.yaml index 7d0a86e38aa..b65ed089fce 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -4,7 +4,7 @@ variables: PJRT_DEVICE: description: - Indicates which device is being used with PJRT. It can be either CPU, - TPU, or GPU + TPU, or CUDA type: string PJRT_SELECT_DEFAULT_DEVICE: description: @@ -41,7 +41,7 @@ variables: - Build the xla client with CUDA enabled. type: bool default_value: false - VERSIONED_XLA_BUILD: + GIT_VERSIONED_XLA_BUILD: description: - Creates a versioned build. In particular, appends a git sha to the version number string diff --git a/docs/gpu.md b/docs/gpu.md index 595bcb43bba..02785ce7470 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -18,7 +18,7 @@ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker -sudo docker run --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8 bin/bash +sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8 bin/bash sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash ``` @@ -59,7 +59,7 @@ pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xl In order to run below examples, you need to clone the pytorch/xla repo to access the imagenet example(We already clone it in our docker). ``` -(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1 PJRT_DEVICE=GPU +(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA (pytorch) root@20ab2c7a2d06:/# git clone --recursive https://github.com/pytorch/xla.git (pytorch) root@20ab2c7a2d06:/# python xla/test/test_train_mp_imagenet.py --fake_data ==> Preparing data.. @@ -87,7 +87,7 @@ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit sudo systemctl restart docker -sudo docker run --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_11.8 +sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_11.8 sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash ``` diff --git a/docs/pjrt.md b/docs/pjrt.md index 265c0abcce5..fca27cca683 100644 --- a/docs/pjrt.md +++ b/docs/pjrt.md @@ -194,15 +194,61 @@ for more information. *Warning: GPU support is still highly experimental!* -To use GPUs with PJRT, simply set `PJRT_DEVICE=GPU` and configure +### Single-node GPU training + +To use GPUs with PJRT, simply set `PJRT_DEVICE=CUDA` and configure `GPU_NUM_DEVICES` to the number of devices on the host. For example: ``` -PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1 +PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1 +``` + +You can also use `torchrun` to initiate the single-node multi-GPU training. For example, + +``` +PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 +``` + +In the above example, `--nnodes` means how many machines (physical machines or VMs) to be used (it is 1 since we do single-node training). `--nproc-per-node` means how many GPU devices to be used. + +### Multi-node GPU training + +**Note that this feature only works for cuda 12+**. Similar to how PyTorch uses multi-node training, you can run the command as below: + +``` +PJRT_DEVICE=CUDA torchrun \ +--nnodes=${NUMBER_GPU_VM} \ +--node_rank=${CURRENT_NODE_RANK} \ +--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \ +--rdzv_endpoint= multinode_training.py +``` + +- `--nnodes`: how many GPU machines to be used. +- `--node_rank`: the index of the current GPU machines. The value can be 0, 1, ..., ${NUMBER_GPU_VM}-1. +- `--nproc_per_node`: the number of GPU devices to be used on the current machine. +- `--rdzv_endpoint`: the endpoint of the GPU machine with node_rank==0, in the form :. The `host` will be the internal IP address. The port can be any available port on the machine. + +For example, if you want to train on 2 GPU machines: machine_0 and machine_1, on the first GPU machine machine_0, run + +``` +# PJRT_DEVICE=CUDA torchrun \ +--nnodes=2 \ +--node_rank=0 \ +--nproc_per_node=4 \ +--rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 +``` + +On the second GPU machine, run + +``` +# PJRT_DEVICE=CUDA torchrun \ +--nnodes=2 \ +--node_rank=1 \ +--nproc_per_node=4 \ +--rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 ``` -Currently, only a single host is supported, and multi-host GPU cluster support -will be added in an future release. +the difference between the 2 commands above are `--node_rank` and potentially `--nproc_per_node` if you want to use different number of GPU devices on each machine. All the rest are identical. For more information about `torchrun`, please refer to this [page](https://pytorch.org/docs/stable/elastic/run.html). ## Differences from XRT diff --git a/docs/spmd.md b/docs/spmd.md index 8337bbd5af9..61afba530b0 100644 --- a/docs/spmd.md +++ b/docs/spmd.md @@ -33,7 +33,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou ### Simple Example & Sharding Aannotation API -Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output. +Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L452)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output. ```python def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor @@ -46,8 +46,8 @@ import numpy as np import torch import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -import torch_xla.experimental.xla_sharding as xs -from torch_xla.experimental.xla_sharding import Mesh +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import Mesh # Enable XLA SPMD execution mode. xr.use_spmd() @@ -100,11 +100,11 @@ We derive a logical mesh based on this topology to create sub-groups of devices ![alt_text](assets/mesh_spmd2.png "image_tooltip") -We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028df4da388468fa9a41b1f98ea08bfce13b4c63/torch_xla/experimental/xla_sharding.py#L16). The axes of the logical Mesh can be named. Here is an example: +We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L17). The axes of the logical Mesh can be named. Here is an example: ```python import torch_xla.runtime as xr -from torch_xla.experimental.xla_sharding import Mesh +from torch_xla.distributed.spmd import Mesh # Assuming you are running on a TPU host that has 8 devices attached num_devices = xr.global_runtime_device_count() @@ -130,7 +130,7 @@ In general, SPMD programs should create a single mesh and reuse it for all shard Mesh nicely abstracts how the physical device mesh is constructed. Users can arrange devices in any shape and order using the logical mesh. However, one can define a more performant mesh based on the physical topology, especially when it involves Data Center Network (DCN) cross slice connections. HybridMesh creates a mesh which gives good performance out of the box for such multislice environments. It accepts ici\_mesh\_shape and dcn\_mesh\_shape which denote logical mesh shapes of inner and outer network. ```python -from torch_xla.experimental.xla_sharding import HybridMesh +from torch_xla.distributed.spmd import HybridMesh # This example is assuming 2 slices of v4-8. # - ici_mesh_shape: shape of the logical mesh for inner connected devices. @@ -198,10 +198,24 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i * `XLAShardedTensor` is a `torch.Tensor` subclass and works directly with native torch ops and `module.layers`. We use `__torch_dispatch__` to send `XLAShardedTensor` to the XLA backend. PyTorch/XLA retrieves attached sharding annotations to trace the graph and invokes XLA SPMDPartitioner. * Internally, `XLAShardedTensor` (and its global\_tensor input) is backed by `XLATensor` with a special data structure holding references to the sharded device data. * The sharded tensor after lazy execution may be gathered and materialized back to the host as global\_tensor when requested on the host (e.g., printing the value of the global tensor. -* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L111) to return the local shards on addressable devices as List[[XLAShard](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L12)]. +* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L117) to return the local shards on addressable devices as List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]. There is also an ongoing effort to integrate XLAShardedTensor into DistributedTensor API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)]. +### DTensor Integration +PyTorch has prototype-released [DTensor](https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md) in 2.1. +We are integrating PyTorch/XLA SPMD into DTensor API [RFC](https://github.com/pytorch/pytorch/issues/92909). We have a proof-of-concept integration for `distribute_tensor`, which calls `mark_sharding` annotation API to shard a tensor and its computation using XLA: +```python +import torch +from torch.distributed import DeviceMesh, Shard, distribute_tensor + +# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD. +mesh = DeviceMesh("xla", list(range(world_size))) +big_tensor = torch.randn(100000, 88) +my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) +``` + +This feature is experimental and stay tuned for more updates, examples and tutorials in the upcoming releases. ### Sharding-Aware Host-to-Device Data Loading diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index cecb0ec482d..83d4ccdd03a 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -34,6 +34,7 @@ build_env: SILO_NAME: "cache-silo-{{ arch }}-{{ accelerator }}-{{ clang_version }}" DISABLE_XRT: "{{ disable_xrt }}" _GLIBCXX_USE_CXX11_ABI: 0 + GIT_VERSIONED_XLA_BUILD: "{{ nightly_release }}" amd64: ARCH: amd64 diff --git a/infra/tpu-ci/README.md b/infra/tpu-ci/README.md new file mode 100644 index 00000000000..d4c57ad83fb --- /dev/null +++ b/infra/tpu-ci/README.md @@ -0,0 +1,3 @@ +# TPU CI for PyTorch/XLA + +This Terraform configuration will allow this repository to test PR changes on v4 TPUs run on GKE node pools. These v4 TPUs are created whenever required by a workflow job. diff --git a/infra/tpu-ci/main.tf b/infra/tpu-ci/main.tf new file mode 100644 index 00000000000..ad695b2004b --- /dev/null +++ b/infra/tpu-ci/main.tf @@ -0,0 +1,14 @@ +module "v4_arc_cluster" { + source = "./modules/google-arc-v4-container-cluster" + project_id = "tpu-pytorch" + cluster_name = "bzmarke-test-cluster" + cpu_nodepool_name = "cpu-nodepool" + cpu_node_count = 1 + tpu_nodepool_name = "tpu-nodepool" + max_tpu_nodes = 1 + # Don't include `www.` in the URL + # Should be formatted as: "https://github.com/..." + github_repo_url = "https://github.com/mbzomowski/xla" + runner_image = "gcr.io/tpu-pytorch/bzmarke-image:latest" +} +#test diff --git a/infra/tpu-ci/modules/google-arc-v4-container-cluster/README.md b/infra/tpu-ci/modules/google-arc-v4-container-cluster/README.md new file mode 100644 index 00000000000..3892c05bcc0 --- /dev/null +++ b/infra/tpu-ci/modules/google-arc-v4-container-cluster/README.md @@ -0,0 +1,7 @@ +# Cluster creation for TPU CI for PyTorch/XLA + +This module configures: +* A regional GKE cluster +* A CPU node pool +* An autoscaling v4 TPU node pool +* The installation of Actions Runner Controller (ARC) on the GKE cluster diff --git a/infra/tpu-ci/modules/google-arc-v4-container-cluster/arc-values.yaml b/infra/tpu-ci/modules/google-arc-v4-container-cluster/arc-values.yaml new file mode 100644 index 00000000000..b1a3d6caf3b --- /dev/null +++ b/infra/tpu-ci/modules/google-arc-v4-container-cluster/arc-values.yaml @@ -0,0 +1,18 @@ +githubConfigUrl: ${github_repo_url} +githubConfigSecret: github-pat +minRunners: 0 +maxRunners: ${max_tpu_nodes} +template: + spec: + containers: + - name: runner + image: ${runner_image} + command: ["/home/runner/run.sh"] + resources: + limits: + google.com/tpu: 4 + requests: + google.com/tpu: 4 + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice + cloud.google.com/gke-tpu-topology: 2x2x1 diff --git a/infra/tpu-ci/modules/google-arc-v4-container-cluster/main.tf b/infra/tpu-ci/modules/google-arc-v4-container-cluster/main.tf new file mode 100644 index 00000000000..3a51f8aa3ff --- /dev/null +++ b/infra/tpu-ci/modules/google-arc-v4-container-cluster/main.tf @@ -0,0 +1,97 @@ +provider "google" { + project = var.project_id +} + +provider "helm" { + kubernetes { + host = "https://${google_container_cluster.arc_v4_cluster.endpoint}" + token = data.google_client_config.default.access_token + cluster_ca_certificate = base64decode(google_container_cluster.arc_v4_cluster.master_auth.0.cluster_ca_certificate) + } +} + +data "google_client_config" "default" {} + +resource "google_container_cluster" "arc_v4_cluster" { + name = var.cluster_name + location = "us-central2" + + remove_default_node_pool = true + initial_node_count = 1 + + release_channel { + channel = "RAPID" + } + + min_master_version = 1.28 +} + +resource "google_container_node_pool" "arc_v4_cpu_nodes" { + name = var.cpu_nodepool_name + location = "us-central2" + cluster = google_container_cluster.arc_v4_cluster.name + node_count = var.cpu_node_count + + node_config { + oauth_scopes = [ + "https://www.googleapis.com/auth/logging.write", + "https://www.googleapis.com/auth/monitoring", + ] + + machine_type = "n1-standard-1" + } + + management { + auto_upgrade = true + auto_repair = true + } +} + +resource "google_container_node_pool" "arc_v4_tpu_nodes" { + name = var.tpu_nodepool_name + location = "us-central2" + node_locations = ["us-central2-b"] + cluster = google_container_cluster.arc_v4_cluster.name + initial_node_count = 0 + autoscaling { + total_min_node_count = 0 + total_max_node_count = var.max_tpu_nodes + location_policy = "ANY" + } + node_config { + oauth_scopes = [ + "https://www.googleapis.com/auth/logging.write", + "https://www.googleapis.com/auth/monitoring", + ] + machine_type = "ct4p-hightpu-4t" + } + management { + auto_upgrade = true + auto_repair = true + } +} + +resource "helm_release" "arc" { + name = "actions-runner-controller" + chart = "oci://ghcr.io/actions/actions-runner-controller-charts/gha-runner-scale-set-controller" + namespace = var.arc_namespace + create_namespace = true +} + +resource "helm_release" "arc_runner_set" { + name = "v4-runner-set" + depends_on = [ + helm_release.arc + ] + chart = "oci://ghcr.io/actions/actions-runner-controller-charts/gha-runner-scale-set" + namespace = var.runner_namespace + create_namespace = true + + values = [ + templatefile("modules/google-arc-v4-container-cluster/arc-values.yaml", { + github_repo_url = var.github_repo_url + max_tpu_nodes = var.max_tpu_nodes + runner_image = var.runner_image + }) + ] +} diff --git a/infra/tpu-ci/modules/google-arc-v4-container-cluster/variables.tf b/infra/tpu-ci/modules/google-arc-v4-container-cluster/variables.tf new file mode 100644 index 00000000000..dbba690409f --- /dev/null +++ b/infra/tpu-ci/modules/google-arc-v4-container-cluster/variables.tf @@ -0,0 +1,51 @@ +variable "cluster_name" { + description = "Name of the Container Cluster containing the v4 node pool" + type = string +} + +variable "cpu_nodepool_name" { + description = "Name of the CPU Nodepool" + type = string +} + +variable "cpu_node_count" { + description = "Number of CPU nodes" + type = number +} + +variable "tpu_nodepool_name" { + description = "Name of the TPU Nodepool" + type = string +} + +variable "max_tpu_nodes" { + description = "Maximum number of TPU nodes and runners" + type = number +} + +variable "arc_namespace" { + description = "The namespace where ARC will reside" + default = "arc-systems" + type = string +} + +variable "runner_namespace" { + description = "The namespace where the ARC runners will reside" + default = "arc-runners" + type = string +} + +variable "github_repo_url" { + description = "The full URL of the repository which will be utilizing the self-hosted runners in ARC" + type = string +} + +variable "project_id" { + description = "The project ID" + type = string +} + +variable "runner_image" { + description = "The Docker image used in the self-hosted runner" + type = string +} diff --git a/infra/tpu-ci/versions.tf b/infra/tpu-ci/versions.tf new file mode 100644 index 00000000000..4da6455aafb --- /dev/null +++ b/infra/tpu-ci/versions.tf @@ -0,0 +1,14 @@ +terraform { + required_providers { + google = { + source = "hashicorp/google" + version = "5.6.0" + } + } + + required_version = ">= 0.14" + backend "gcs" { + bucket = "bzmarke-tfstate" + prefix = "terraform/state" + } +} diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index bd5edb6d02a..e06c9670d5f 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -35,30 +35,39 @@ xrt_versioned_builds = [ { accelerator = "tpu" python_version = "3.10" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0+xrt" }, { accelerator = "cuda" python_version = "3.10" cuda_version = "12.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0+xrt" }, ] # Built on push to specific tag. versioned_builds = [ + # Remove libtpu from PyPI builds { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "tpu" bundle_libtpu = "0" }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" + package_version = "2.1.0" + accelerator = "tpu" + python_version = "3.9" + bundle_libtpu = "0" + }, + { + git_tag = "v2.1.0" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "tpu" python_version = "3.10" @@ -66,7 +75,16 @@ versioned_builds = [ }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" + package_version = "2.1.0" + accelerator = "tpu" + python_version = "3.11" + bundle_libtpu = "0" + }, + # Bundle libtpu for Kaggle + { + git_tag = "v2.1.0" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0+libtpu" accelerator = "tpu" python_version = "3.10" @@ -84,26 +102,41 @@ versioned_builds = [ }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0", accelerator = "cuda" cuda_version = "12.0" }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "cuda" cuda_version = "11.8" }, { git_tag = "v2.1.0" - pytorch_git_rev = "v2.1.0-rc6" + pytorch_git_rev = "v2.1.0" + package_version = "2.1.0" + accelerator = "cuda" + cuda_version = "12.1" + }, + { + git_tag = "v2.1.0" + pytorch_git_rev = "v2.1.0" package_version = "2.1.0" accelerator = "cuda" cuda_version = "11.8" python_version = "3.10" }, + { + git_tag = "v2.1.0" + pytorch_git_rev = "v2.1.0" + package_version = "2.1.0" + accelerator = "cuda" + cuda_version = "12.1" + python_version = "3.10" + }, { git_tag = "v2.0.0" package_version = "2.0" diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff new file mode 100644 index 00000000000..47ee3fa0f0a --- /dev/null +++ b/openxla_patches/gpu_topk_rewriter.diff @@ -0,0 +1,184 @@ +diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc +index da872d962..1b7141055 100644 +--- a/xla/service/topk_rewriter.cc ++++ b/xla/service/topk_rewriter.cc +@@ -196,6 +196,8 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { + return std::nullopt; + } + const int64_t sort_dim = sort->sort_dimension(); ++ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; ++ const bool has_batch = data->shape().rank() == 2; + + bool supported = true; + std::optional k; +@@ -220,15 +222,10 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { + supported = false; + break; + } +- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) { +- if (i != sort_dim && +- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) { +- // Slicing along a non-sort dimension isn't supported. +- supported = false; +- break; +- } +- } +- if (!supported) { ++ if (has_batch && slice->slice_limits(batch_dim) != ++ slice->operand(0)->shape().dimensions(batch_dim)) { ++ // Slicing along the batch dimension isn't supported. ++ supported = false; + break; + } + if (k == std::nullopt) { +@@ -260,57 +257,29 @@ StatusOr TopkRewriter::TransformToCustomCall( + HloSortInstruction* sort = DynCast(inst); + HloInstruction* data = sort->mutable_operand(0); + const PrimitiveType element_type = data->shape().element_type(); +- const Shape data_shape = data->shape(); + +- if (element_type != F32 && element_type != BF16) { ++ if ((data->shape().rank() != 1 && data->shape().rank() != 2) || ++ (element_type != F32 && element_type != BF16)) { + continue; + } + +- // Sort dimension must be the first or last dimension. + const int64_t sort_dim = sort->sort_dimension(); +- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) { +- continue; +- } ++ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; ++ const bool has_batch = data->shape().rank() == 2; + + // Profitability check. + if (!is_profitable_to_convert_(sort, *k)) { + continue; + } + +- HloInstruction* input = data; +- const bool has_batch = data_shape.rank() >= 2; +- const int64_t input_size = data_shape.dimensions(sort_dim); +- int64_t batch_size = 1; +- Shape topk_input_shape; +- +- if (has_batch) { +- // The TopK custom call expects either a 1d tensor or a 2d tensor with +- // the last dimension being the sort dimension. An input with rank > 2 +- // is reshaped into a 2d tensor by combining non-sort dimensions into a +- // single batch dimension. The original non-sort dimensions are +- // restored for the outputs with another reshape after the custom call. +- batch_size = +- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim); +- topk_input_shape = +- ShapeUtil::MakeShape(element_type, {batch_size, input_size}); +- +- if (data_shape.rank() > 2) { +- // Reshape to 2d. +- input = comp->AddInstruction(HloInstruction::CreateReshape( +- sort_dim == 0 +- ? ShapeUtil::MakeShape(element_type, {input_size, batch_size}) +- : ShapeUtil::MakeShape(element_type, +- {batch_size, input_size}), +- input)); +- } +- +- if (sort_dim == 0) { +- // Transpose for the custom call when sorting the first dimension. +- input = comp->AddInstruction( +- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0})); +- } +- } else { +- topk_input_shape = data_shape; ++ const int64_t batch_size = ++ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; ++ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim); ++ HloInstruction* input = sort->mutable_operand(0); ++ if (has_batch && sort_dim == 0) { ++ input = comp->AddInstruction(HloInstruction::CreateTranspose( ++ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, ++ {1, 0})); + } + + Shape topk_shape = +@@ -331,26 +300,13 @@ StatusOr TopkRewriter::TransformToCustomCall( + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + +- if (has_batch) { +- if (sort_dim == 0) { +- // Transpose back. +- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( +- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), +- value_gte, {1, 0})); +- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( +- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, +- {1, 0})); +- } +- if (data_shape.rank() > 2) { +- // Reshape back. +- std::vector shape_dim(data_shape.dimensions().begin(), +- data_shape.dimensions().end()); +- shape_dim[sort_dim] = k.value(); +- value_gte = comp->AddInstruction(HloInstruction::CreateReshape( +- ShapeUtil::MakeShape(element_type, shape_dim), value_gte)); +- index_gte = comp->AddInstruction(HloInstruction::CreateReshape( +- ShapeUtil::MakeShape(S32, shape_dim), index_gte)); +- } ++ if (has_batch && sort_dim == 0) { ++ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( ++ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), ++ value_gte, {1, 0})); ++ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( ++ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, ++ {1, 0})); + } + + for (HloInstruction* user : sort->users()) { +diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc +index 36e723737..25ce150e0 100644 +--- a/xla/service/topk_rewriter_test.cc ++++ b/xla/service/topk_rewriter_test.cc +@@ -326,42 +326,6 @@ ENTRY cluster { + EXPECT_THAT(cc->custom_call_target(), "TopK"); + } + +-TEST_F(TopkRewriterTest, RewriteReshape) { +- const std::string hlo_string = R"( +-HloModule module +-)" + getComparator() + R"( +-ENTRY cluster { +- %arg_tuple.1 = f32[3,8,1234567] parameter(0) +- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2 +- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4), +- dimensions={2}, is_stable=true, to_apply=%compare +- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0 +- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]} +- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1 +- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]} +- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31) +-})"; +- TF_ASSERT_OK_AND_ASSIGN(auto module, +- ParseAndReturnVerifiedModule(hlo_string)); +- TopkRewriter rewriter( +- [](const HloSortInstruction*, int64_t) { return true; }); +- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); +- TF_ASSERT_OK(HloDCE().Run(module.get()).status()); +- EXPECT_TRUE(changed); +- EXPECT_THAT(module->entry_computation()->root_instruction(), +- GmockMatch(m::Tuple( +- m::Reshape(m::GetTupleElement( +- m::CustomCall(m::Reshape(m::Parameter(0))), 0)), +- m::Reshape(m::GetTupleElement( +- m::CustomCall(m::Reshape(m::Parameter(0))), 1))))); +- const HloInstruction* cc = module->entry_computation() +- ->root_instruction() +- ->operand(0) +- ->operand(0) +- ->operand(0); +- EXPECT_THAT(cc->custom_call_target(), "TopK"); +-} +- + TEST_F(TopkRewriterTest, RewriteNoIota) { + const std::string hlo_string = R"( + HloModule module diff --git a/openxla_patches/pjrt_api_tsl_logging.diff b/openxla_patches/pjrt_api_tsl_logging.diff deleted file mode 100644 index 296bed91ad6..00000000000 --- a/openxla_patches/pjrt_api_tsl_logging.diff +++ /dev/null @@ -1,21 +0,0 @@ -# Fixes log spam when loading libtpu. We should fix this upstream. -diff --git a/xla/pjrt/pjrt_api.cc b/xla/pjrt/pjrt_api.cc -index 132cfaff0..887e842e0 100644 ---- a/xla/pjrt/pjrt_api.cc -+++ b/xla/pjrt/pjrt_api.cc -@@ -17,7 +17,6 @@ limitations under the License. - - #include - --#include "absl/log/log.h" - #include "absl/status/status.h" - #include "absl/strings/str_cat.h" - #include "xla/pjrt/c/pjrt_c_api.h" -@@ -33,6 +32,7 @@ limitations under the License. - #include "xla/pjrt/c/pjrt_c_api_helpers.h" - #include "xla/status.h" - #include "xla/statusor.h" -+#include "tsl/platform/logging.h" - #include "tsl/platform/errors.h" - - namespace pjrt { diff --git a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff deleted file mode 100644 index ee1ec00eced..00000000000 --- a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff +++ /dev/null @@ -1,76 +0,0 @@ -# Partial backport of 6308dba2903e78961ac4122f361bc91b09f36891. Remove in next -# pin update. -diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc -index ef0b6686c..c0341e81e 100644 ---- a/xla/pjrt/pjrt_c_api_client.cc -+++ b/xla/pjrt/pjrt_c_api_client.cc -@@ -1584,6 +1584,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const { - return args.num_dynamic_dims > 0; - } - -+absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { -+ { -+ absl::MutexLock lock(&mu_); -+ if (!is_dynamic_dimension_.has_value()) { -+ absl::InlinedVector& is_dynamic_dimension_value = -+ is_dynamic_dimension_.emplace(); -+ is_dynamic_dimension_value.assign(dimensions().size(), false); -+ -+ PJRT_Buffer_DynamicDimensionIndices_Args args; -+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; -+ args.priv = nullptr; -+ args.buffer = buffer_.get(); -+ const PJRT_Api* api = pjrt_c_api(); -+ std::unique_ptr error( -+ api->PJRT_Buffer_DynamicDimensionIndices(&args), -+ pjrt::MakeErrorDeleter(api)); -+ if (error && pjrt::GetErrorCode(error.get(), api) == -+ PJRT_Error_Code_UNIMPLEMENTED) { -+ return *is_dynamic_dimension_; -+ } -+ for (int i = 0; i < args.num_dynamic_dims; ++i) { -+ is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true; -+ } -+ } -+ } -+ return *is_dynamic_dimension_; -+} -+ - StatusOr> PjRtCApiBuffer::logical_dimensions() { - PJRT_Buffer_UnpaddedDimensions_Args args; - args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; -diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h -index 9c460f246..279608e60 100644 ---- a/xla/pjrt/pjrt_c_api_client.h -+++ b/xla/pjrt/pjrt_c_api_client.h -@@ -27,6 +27,7 @@ limitations under the License. - #include - - #include "absl/container/flat_hash_map.h" -+#include "absl/container/inlined_vector.h" - #include "absl/log/check.h" - #include "absl/log/log.h" - #include "absl/strings/string_view.h" -@@ -369,11 +370,7 @@ class PjRtCApiBuffer : public PjRtBuffer { - - bool has_dynamic_dimensions() const override; - -- absl::Span is_dynamic_dimension() const override { -- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. " -- << "Considering using has_dynamic_dimensions() or " -- "logical_dimensions() if applicable."; -- } -+ absl::Span is_dynamic_dimension() const override; - - StatusOr> logical_dimensions() override; - -@@ -455,6 +452,9 @@ class PjRtCApiBuffer : public PjRtBuffer { - std::shared_ptr::Promise> readiness_promise_; - // Set and cached the first time layout() is called. - mutable std::optional layout_; -+ // Set and cached the first time is_dynamic_dimension() is called. -+ mutable std::optional> -+ is_dynamic_dimension_; - // Used to synchronize concurrent setting of cached values. - mutable absl::Mutex mu_; - }; diff --git a/scripts/build_torch_wheels.sh b/scripts/build_torch_wheels.sh index 53560af530e..93ec524786d 100755 --- a/scripts/build_torch_wheels.sh +++ b/scripts/build_torch_wheels.sh @@ -267,7 +267,7 @@ function build_and_install_torch() { function build_and_install_torch_xla() { git submodule update --init --recursive if [ "${RELEASE_VERSION}" = "nightly" ]; then - export VERSIONED_XLA_BUILD=1 + export GIT_VERSIONED_XLA_BUILD=true else export TORCH_XLA_VERSION=${RELEASE_VERSION:1} # r0.5 -> 0.5 fi diff --git a/setup.py b/setup.py index 3a98db22c32..a8a04c4c286 100644 --- a/setup.py +++ b/setup.py @@ -10,8 +10,8 @@ # specify the version of PyTorch/XLA, rather than the hard-coded version # in this file; used when we're building binaries for distribution # -# VERSIONED_XLA_BUILD -# creates a versioned build +# GIT_VERSIONED_XLA_BUILD +# creates a git versioned build # # TORCH_XLA_PACKAGE_NAME # change the package name to something other than 'torch_xla' @@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_libtpu_version = '0.1.dev20230825' +_libtpu_version = '0.1.dev20231022' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' @@ -101,9 +101,9 @@ def get_git_head_sha(base_dir): def get_build_version(xla_git_sha): version = os.getenv('TORCH_XLA_VERSION', '2.2.0') - if _check_env_flag('VERSIONED_XLA_BUILD', default='0'): + if _check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'): try: - version += '+' + xla_git_sha[:7] + version += '+git' + xla_git_sha[:7] except Exception: pass return version @@ -244,9 +244,10 @@ def bazel_build(self, ext): bazel_argv = [ 'bazel', 'build', ext.bazel_target, - f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}", - '\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args]) + f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}" ] + for opt in extra_compile_args: + bazel_argv.append("--cxxopt={}".format(opt)) # Debug build. if DEBUG: diff --git a/test/bench.py b/test/bench.py index 8d37aafab53..a8908eb9b5e 100644 --- a/test/bench.py +++ b/test/bench.py @@ -128,7 +128,7 @@ def run_benchmarks(args): args, benchs = parser.parse_known_args() args.benchs = benchs - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) torch_xla._XLAC._xla_set_use_full_mat_mul_precision( use_full_mat_mul_precision=True) run_benchmarks(args) diff --git a/test/cpp/BUILD b/test/cpp/BUILD index 6a343a1b3ac..c8aeb729d78 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -78,9 +78,9 @@ ptxla_cc_test( ":torch_xla_test", "//torch_xla/csrc/runtime:runtime", "//torch_xla/csrc/runtime:debug_macros", - "//torch_xla/csrc/runtime:multi_wait", - "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc:tensor", + "//torch_xla/csrc:thread_pool", + "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", "@xla//xla:shape_util", "@xla//xla/client:xla_builder", @@ -101,15 +101,16 @@ ptxla_cc_test( ], ) -ptxla_cc_test( - name = "test_xla_backend_intf", - srcs = ["test_xla_backend_intf.cpp"], - deps = [ - ":cpp_test_util", - "//torch_xla/csrc:tensor", - "@com_google_googletest//:gtest_main", - ], -) +# Disable this test since it is flaky on upstream +# ptxla_cc_test( +# name = "test_xla_backend_intf", +# srcs = ["test_xla_backend_intf.cpp"], +# deps = [ +# ":cpp_test_util", +# "//torch_xla/csrc:tensor", +# "@com_google_googletest//:gtest_main", +# ], +# ) ptxla_cc_test( name = "test_xla_sharding", diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 28f3d23d5c1..2b9b4198a18 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -307,7 +307,7 @@ std::vector Fetch( std::vector tensors; for (auto& literal : literals) { tensors.push_back(MakeTensorFromXlaLiteral( - literal, TensorTypeFromXlaType(literal.shape().element_type()))); + literal, MaybeUpcastToHostTorchType(literal.shape().element_type()))); } return tensors; } diff --git a/test/cpp/cpp_test_util.h b/test/cpp/cpp_test_util.h index 256c9e4f02e..730ddf9f06f 100644 --- a/test/cpp/cpp_test_util.h +++ b/test/cpp/cpp_test_util.h @@ -12,6 +12,7 @@ #include "absl/types/span.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/tensor.h" diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 38be723950d..16bba61b195 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -98,7 +98,8 @@ elif [[ "$RUN_CPP_TESTS2" == "cpp_tests2" ]]; then "test_lazy" "test_replication" "test_tensor" - "test_xla_backend_intf" + # disable test_xla_backend_intf since it is flaky on upstream + #"test_xla_backend_intf" "test_xla_sharding") fi for name in "${test_names[@]}"; do diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d59b5b32360..e95651604d7 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -23,6 +23,16 @@ class AtenXlaTensorTest : public AtenXlaTensorTestBase {}; } // namespace +TEST_F(AtenXlaTensorTest, TestStorage) { + torch::Tensor a = torch::tensor({0.0}); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + XLATensorPtr xla_tensor_a = bridge::GetXlaTensor(xla_a); + EXPECT_EQ(xla_a.device(), xla_tensor_a->Storage().device()); + AllClose(a, xla_a); + }); +} + TEST_F(AtenXlaTensorTest, TestEmpty) { torch::Tensor a = torch::zeros({2, 2}, torch::TensorOptions(torch::kFloat)); ForEachDevice([&](const torch::Device& device) { @@ -129,8 +139,7 @@ TEST_F(AtenXlaTensorTest, TestFull) { }); ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::full", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestFullLike) { @@ -3369,6 +3378,22 @@ TEST_F(AtenXlaTensorTest, TestPrelu) { ExpectCounterChanged("xla::_prelu_kernel", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestPreluBackward) { + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::prelu(inputs[0], inputs[1]); + }; + torch::Tensor input = torch::rand( + {5, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); + torch::Tensor weight = torch::rand({3}, torch::TensorOptions(torch::kFloat)); + ForEachDevice([&](const torch::Device& device) { + TestBackward({input, weight}, device, testfn); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_prelu_kernel_backward", + cpp_test::GetIgnoredCounters()); +} + TEST_F(AtenXlaTensorTest, TestHardshrink) { torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat)); torch::Tensor output = torch::hardshrink(input); diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp index 9c4135f64c5..b8599c0e7d6 100644 --- a/test/cpp/test_aten_xla_tensor_2.cpp +++ b/test/cpp/test_aten_xla_tensor_2.cpp @@ -1512,7 +1512,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) { /*cudnn_enabled=*/false); }; torch::Tensor undef; - ForEachDevice({XlaDeviceType::GPU, XlaDeviceType::TPU}, + ForEachDevice({XlaDeviceType::CUDA, XlaDeviceType::TPU}, [&](const torch::Device& device) { TestBackward({input, undef_weight ? undef : weight, undef_weight ? undef : bias}, diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index 1e61a6fa05a..0a0d84d463a 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -956,6 +956,18 @@ TEST_F(AtenXlaTensorTest, TestSqueezeMultipleDims) { }); } +TEST_F(AtenXlaTensorTest, TestSqueezeDimWithNegativeOne) { + torch::Tensor input = + torch::rand({2, 1, 3, 1}, torch::TensorOptions(torch::kFloat)); + std::vector dims = {-1}; + torch::Tensor output = torch::squeeze(input, dims); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor xla_output = torch::squeeze(xla_input, dims); + AllClose(output, xla_output); + }); +} + TEST_F(AtenXlaTensorTest, TestSqueezeOneInPlace) { int rank = 4; for (int dim = -rank; dim < rank; ++dim) { diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp index d2e3e284f5b..d7eb32619c4 100644 --- a/test/cpp/test_aten_xla_tensor_6.cpp +++ b/test/cpp/test_aten_xla_tensor_6.cpp @@ -873,7 +873,7 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) { TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) { XlaDeviceType hw_type = static_cast(bridge::GetDefaultDevice()->type()); - if (hw_type != XlaDeviceType::GPU && hw_type != XlaDeviceType::CPU) { + if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) { return; } torch::Tensor growth_tracker = diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 08b039b9e5f..39fbb7201b0 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -3,15 +3,15 @@ #include +#include "absl/synchronization/blocking_counter.h" #include "test/cpp/cpp_test_util.h" #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "xla/client/xla_builder.h" #include "xla/shape_util.h" @@ -57,7 +57,7 @@ void TestSingleReplication( std::vector> results(device_strings.size()); - torch_xla::runtime::util::MultiWait mwait(device_strings.size()); + absl::BlockingCounter counter(device_strings.size()); torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options; for (size_t i = 0; i < device_strings.size(); ++i) { auto executor = [&, i]() { @@ -68,11 +68,11 @@ void TestSingleReplication( torch_xla::runtime::ComputationClient::Data>( tensors_data[i])}, device_strings[i], exec_options); + counter.DecrementCount(); }; - torch_xla::runtime::env::ScheduleIoClosure( - mwait.Completer(std::move(executor))); + torch_xla::thread::Schedule(std::move(executor)); } - mwait.Wait(); + counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { auto literals = @@ -94,7 +94,7 @@ class ReplicationTest : public AtenXlaTensorTestBase {}; TEST_F(ReplicationTest, TestNSingleReplication) { WithAllDevices( - {XlaDeviceType::TPU, XlaDeviceType::GPU}, + {XlaDeviceType::TPU, XlaDeviceType::CUDA}, [&](const std::vector& devices, const std::vector& all_devices) { TestSingleReplication(devices, all_devices); diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py index a15e7f1aca3..84fe53a5cdd 100644 --- a/test/ds/test_dynamic_shape_models.py +++ b/test/ds/test_dynamic_shape_models.py @@ -43,7 +43,9 @@ def forward(self, x): @unittest.skipIf( - not xm.get_xla_supported_devices("GPU") and + # Currently a change break this test on CUDA. Another change is trying to + # roll back it. Will uncomment the line below once it is rolled back. + # not xm.get_xla_supported_devices("CUDA") and not xm.get_xla_supported_devices("TPU"), f"The tests fail on CPU. See https://github.com/pytorch/xla/issues/4298 for more detail." ) diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 778d77591e4..a419f9335db 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -78,6 +78,19 @@ def get_random_inputs(self): return (torch.randn(10), torch.randn(10)) +class UpsampleModule(nn.Module): + + def __init__(self): + super().__init__() + self.upsample = nn.Upsample(scale_factor=2) + + def forward(self, x): + return self.upsample(x) + + def get_random_inputs(self): + return (torch.randn((1, 1, 5)),) + + def allclose(expected, actual): def unwrap(cont): @@ -179,6 +192,7 @@ def test_wrapper(self): model = model.to(device=xla_dev) inputs = tuple(inp.to(device=xla_dev) for inp in inputs) + inputs = tuple(inp.requires_grad_() for inp in inputs) # do baseline baseline_model = copy.deepcopy(model) @@ -206,6 +220,25 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase): test_training_linear = make_training_test(LinearModule) test_training_maxpool = make_training_test(MaxPoolModule) + test_training_upsample = make_training_test(UpsampleModule) + + def test_non_tensor_args_for_partition(self): + + class Emb(torch.nn.Embedding): + + def __init__(self): + super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) + + device = xm.xla_device() + module = Emb() + module.to(device) + + @torch.compile(backend="openxla_eval") + def foo(x): + return module(x) + + x = torch.randint(0, 10, (10,), device=device) + foo(x) if __name__ == "__main__": diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index f84cc30ec9e..7b359311c8f 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -32,7 +32,7 @@ def _ddp_init(index: int = ...): def test_ddp_init(self): pjrt.run_multiprocess(self._ddp_init) - @absltest.skipIf(xr.device_type() == 'GPU', + @absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'), "GPU device is not supported by pjrt.spawn_threads") def test_ddp_init_threaded(self): pjrt.spawn_threads(self._ddp_init) diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index 8e500ea4ef0..8cb930714e0 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -16,7 +16,7 @@ class TestExperimentalPjrt(parameterized.TestCase): def setUp(self): xr.set_device_type('CPU') - @parameterized.parameters(('CPU', 'CPU'), ('GPU', 'GPU'), ('TPU', 'TPU'), + @parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU'), ('TPU_C_API', 'TPU'), ('TPU_LEGACY', 'TPU')) def test_device_type(self, pjrt_device, expected): with mock.patch.dict(os.environ, {'PJRT_DEVICE': pjrt_device}, clear=True): @@ -61,7 +61,7 @@ def test_xla_device_error(self): }, True), ('gpu_num_devives', { 'GPU_NUM_DEVICES': '4' }, True), ('pjrt_gpu', { - 'PJRT_DEVICE': 'GPU', + 'PJRT_DEVICE': 'CUDA', 'GPU_NUM_DEVICES': '4' }, True)) def test_pjrt_default_device(self, env_vars, expect_using_pjrt): @@ -77,7 +77,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt): xr.using_pjrt() if expect_using_pjrt: - self.assertIn(xr.device_type(), ['CPU', 'GPU', 'TPU']) + self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'ROCM', 'GPU']) else: self.assertIsNone(xr.device_type()) diff --git a/test/pjrt/test_runtime_gpu.py b/test/pjrt/test_runtime_gpu.py index d82144b2c1a..29cf6bce467 100644 --- a/test/pjrt/test_runtime_gpu.py +++ b/test/pjrt/test_runtime_gpu.py @@ -17,12 +17,12 @@ from absl.testing import absltest, parameterized -@unittest.skipIf(xr.device_type() != 'GPU', +@unittest.skipIf(xr.device_type() not in ('GPU', 'CUDA', 'ROCM'), f"GPU tests should only run on GPU devices.") class TestExperimentalPjrtGpu(parameterized.TestCase): def setUp(self): - xr.set_device_type('GPU') + xr.set_device_type('CUDA') os.environ.update({ xenv.PJRT_GPU_ASYNC_CLIENT: 'true', @@ -178,7 +178,6 @@ def _reduce_scatter(pin_layout): return out.cpu().numpy() # 2023-08-02 04:16:36.520884: F external/xla/xla/service/layout_assignment.cc:157] Check failed: ShapeUtil::Compatible(shape_layout.shape(), instruction->operand(operand_no)->shape()) f32[1]{0} is not compatible with f32[2]{0} (for operand 0 of instruction %reduce-scatter.10 = f32[1]{0} reduce-scatter(f32[2]{0} %add.5), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.6) - @unittest.skip("Failed with known error.") @parameterized.named_parameters(('pinned', True), ('unpinned', False)) def test_reduce_scatter(self, pin_layout): results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout) diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py index 98c8aa47fef..9a3fce79499 100644 --- a/test/pjrt/test_torchrun.py +++ b/test/pjrt/test_torchrun.py @@ -10,9 +10,13 @@ class TestTorchrun(absltest.TestCase): - def test_all_gather(self): + def setUp(self): dist.init_process_group('xla', init_method='xla://') + def tearDown(self) -> None: + dist.destroy_process_group() + + def test_all_gather(self): dist_world_size = xu.getenv_as('WORLD_SIZE', int) devices_per_thread = xr.addressable_device_count() @@ -29,6 +33,47 @@ def test_all_gather(self): expected = torch.arange(0, expected_world_size, step=1, dtype=torch.float32) torch.testing.assert_close(result.cpu(), expected) + def test_all_reduce(self): + # The test is inspired by https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce + dist_world_size = xu.getenv_as('WORLD_SIZE', int) + devices_per_thread = xr.addressable_device_count() + world_size = dist_world_size * devices_per_thread + + # If world_size=2, then the `tensors` below will be [[1, 2], [3, 4]]. + # The `expected` will be [4, 6]. + tensors = [ + torch.arange(2, dtype=torch.int64) + 1 + 2 * r + for r in range(world_size) + ] + expected = sum(tensors) + + xla_tensor = torch.arange( + 2, dtype=torch.int64, device=xm.xla_device()) + 1 + 2 * dist.get_rank() + dist.all_reduce(xla_tensor, op=dist.ReduceOp.SUM) + xm.mark_step() + + torch.testing.assert_close(xla_tensor.cpu(), expected) + + def test_reduce_scatter(self): + # The test is inspired by https://pytorch.org/docs/stable/distributed.html#torch.distributed.reduce_scatter + dist_world_size = xu.getenv_as('WORLD_SIZE', int) + devices_per_thread = xr.addressable_device_count() + world_size = dist_world_size * devices_per_thread + # If world_size=2, then `tensor` will be tensor([0, 2, 4, 6]) + # `expected` will be [0, 2] on rank 0 and [4, 6] on rank 1. + tensor = world_size * torch.arange( + world_size * world_size, dtype=torch.int64) + expected = torch.split(tensor, world_size)[dist.get_rank()] + + tensor_out = torch.zeros( + world_size, dtype=torch.int64, device=xm.xla_device()) + tensor_in = torch.arange( + world_size * world_size, dtype=torch.int64, device=xm.xla_device()) + dist.reduce_scatter(tensor_out, [tensor_in], op=dist.ReduceOp.SUM) + xm.mark_step() + + torch.testing.assert_close(tensor_out.cpu(), expected) + if __name__ == '__main__': if not dist.is_torchelastic_launched(): diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 1c77e85b9f6..6835ebe7993 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -519,7 +519,7 @@ def union_of_disabled_tests(sets): DISABLED_TORCH_TESTS = { 'TPU': prepare_match_set(DISABLED_TORCH_TESTS_TPU), 'CPU': prepare_match_set(DISABLED_TORCH_TESTS_CPU), - 'GPU': prepare_match_set(DISABLED_TORCH_TESTS_GPU), + 'CUDA': prepare_match_set(DISABLED_TORCH_TESTS_GPU), } diff --git a/test/run_tests.sh b/test/run_tests.sh index 9b13aa4494f..31b30fa95ee 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -56,7 +56,7 @@ function run_coverage { function run_test { echo "Running in PjRt runtime: $@" if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then - PJRT_DEVICE=GPU run_coverage "$@" + PJRT_DEVICE=CUDA run_coverage "$@" else # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue. PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$@" @@ -108,6 +108,11 @@ function run_save_tensor_hlo { XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@" } +function run_pt_xla_debug { + echo "Running in save tensor file mode: $@" + PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" +} + function run_stablehlo_compile { echo "Running in StableHlo Compile mode: $@" XLA_STABLEHLO_COMPILE=1 run_test "$@" @@ -118,6 +123,14 @@ function run_xla_backend_mp { MASTER_ADDR=localhost MASTER_PORT=6000 run_test "$@" } +function run_torchrun { + if [ -x "$(command -v nvidia-smi)" ] && [ "$XLA_CUDA" != "0" ]; then + echo "Running torchrun test for GPU $@" + num_devices=$(nvidia-smi --list-gpus | wc -l) + PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node $num_devices $@ + fi +} + function run_torch_op_tests { run_dynamic "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA run_test_without_functionalization "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA @@ -148,6 +161,7 @@ function run_xla_op_tests1 { run_test "$CDIR/test_grad_checkpoint.py" run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + run_pt_xla_debug "$CDIR/test_pt_xla_debug.py" run_test "$CDIR/test_async_closures.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/pjrt/test_runtime.py" @@ -193,6 +207,7 @@ function run_xla_op_tests3 { run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" run_test "$CDIR/test_torch_distributed_xla_backend.py" + run_torchrun "$CDIR/pjrt/test_torchrun.py" } ####################################################################################### diff --git a/test/spmd/test_dtensor_integration.py b/test/spmd/test_dtensor_integration.py new file mode 100644 index 00000000000..552e698d352 --- /dev/null +++ b/test/spmd/test_dtensor_integration.py @@ -0,0 +1,81 @@ +import os +import sys + +import torch +from torch import nn +import torch.optim as optim +from torch.distributed._tensor import DeviceMesh, Shard +import torch_xla +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd import xla_distribute_tensor + +import unittest + +import test_xla_sharding_base + + +class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + xr.use_spmd() + super().setUpClass() + + def test_xla_distribute_tensor(self): + device_count = xr.global_runtime_device_count() + device_mesh = DeviceMesh("xla", list(range(device_count))) + shard_spec = [Shard(0)] + + for requires_grad in [True, False]: + tensor_to_shard = torch.randn( + 3 * device_count, + 3, + requires_grad=requires_grad, + device=xm.xla_device()) + dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh, + shard_spec) + # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor + assert type(dist_tensor).__name__ == "XLAShardedTensor" + assert len(dist_tensor.sharding_spec) > 0 + + global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined] + self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3])) + local_tensor = dist_tensor.local_shards[0].data + self.assertEqual(local_tensor.size(), torch.Size([3, 3])) + if requires_grad: + self.assertTrue(dist_tensor.global_tensor.requires_grad) + self.assertTrue(dist_tensor.is_leaf) + + def test_optimizer_step_with_sharding(self): + # Use simple linear model to test model parameter sharding + model = self.SimpleLinear().to(xm.xla_device()) + + # Running the same mark_sharding test with xla_distribute_tensor instead + device_count = xr.global_runtime_device_count() + device_mesh = DeviceMesh("xla", list(range(device_count))) + shard_spec = [Shard(0)] + xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec) + sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight) + + model.train() + optimizer = optim.SGD(model.parameters(), lr=0.1) + data = torch.randn(128, 128).to(xm.xla_device()) + target = torch.zeros(128).to(xm.xla_device()) + loss_fn = nn.CrossEntropyLoss() + for i in range(3): + optimizer.zero_grad() + output = model(data) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + xm.mark_step() + # Sharding is persisted across mark_step calls, and test if the sharded computation + # can repeat more than once without crashing. + self.assertEqual(sharding_spec, + torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 22cd2980413..807a518d95b 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -6,7 +6,7 @@ import torch_xla import torch_xla.runtime as xr import torch_xla.core.xla_model as xm -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs import torch_xla.debug.metrics as met import unittest @@ -15,7 +15,7 @@ class SimpleLinear(nn.Module): - def __init__(self): + def __init__(self, mesh=None): super(SimpleLinear, self).__init__() self.fc1 = nn.Linear(128, 128) self.relu = nn.ReLU() @@ -23,8 +23,14 @@ def __init__(self): # Add an additional 1x1 layer at the end to ensure the final layer # is not sharded. self.fc3 = nn.Linear(1, 1) + # If mesh is not none, we'll do a mark sharding inside the forward function + # to ensure dynamo can recognize and trace it in a torch compile. + self.mesh = mesh def forward(self, x): + if self.mesh and 'xla' in str(self.fc2.weight.device): + xs.mark_sharding( + self.fc2.weight, self.mesh, (1, 0), use_dynamo_custom_op=True) y = self.relu(self.fc1(x)) z = self.fc2(y) return self.fc3(z) @@ -171,6 +177,50 @@ def test_dynamo_input_sharding_threashold(self): else: del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] + def test_dynamo_spmd_mark_sharding_outside_of_compile(self): + device = xm.xla_device() + linear = SimpleLinear().to(device) + linear.eval() + xla_x = torch.randn(1, 128, device=device) + xs.mark_sharding( + linear.fc2.weight, + self._get_mesh((1, self.n_devices)), (1, 0), + use_dynamo_custom_op=True) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # Ensure that another run with same input does not trigger additional compilation + compile_count = met.metric_data('CompileTime')[0] + dynamo_res = dynamo_linear(xla_x) + self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + + def test_mark_sharding_inside_compile(self): + met.clear_counters() + device = xm.xla_device() + mesh = self._get_mesh((1, self.n_devices)) + + # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo custom op + # variant of mark_sharding inside the forward function. + linear = SimpleLinear(mesh=mesh).to(device) + linear.eval() + + xla_x = torch.randn(1, 128, device=device) + xla_res = linear(xla_x) + xm.mark_step() + + dynamo_linear = torch.compile(linear, backend="openxla") + dynamo_res = dynamo_linear(xla_x) + torch.allclose(xla_res.cpu(), dynamo_res.cpu()) + + # Ensure that another run with same input does not trigger additional compilation + compile_count = met.metric_data('CompileTime')[0] + dynamo_res = dynamo_linear(xla_x) + self.assertEqual(met.metric_data('CompileTime')[0], compile_count) + if __name__ == '__main__': test = unittest.main() diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index 73323eddcc0..3ea2b2302b0 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -10,7 +10,7 @@ import torch_xla import torch_xla.runtime as xr import torch_xla.core.xla_model as xm -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs import test_xla_sharding_base diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 2461658c801..7d472da83de 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -84,7 +84,7 @@ import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp import torch_xla.test.test_utils as test_utils -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs DEFAULT_KWARGS = dict( batch_size=128, @@ -372,7 +372,7 @@ def test_loop_fn(loader, epoch): if FLAGS.profile: server = xp.start_server(FLAGS.profiler_port) - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index ad5294e5cfe..e08f361c42a 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -7,10 +7,10 @@ import torch_xla.runtime as xr import torch_xla.debug.profiler as xp import torch_xla.distributed.parallel_loader as pl -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs import torch_xla.utils.checkpoint as checkpoint import torch_xla.utils.utils as xu -from torch_xla.experimental.xla_sharding import Mesh +from torch_xla.distributed.spmd import Mesh import torch.optim as optim from torch import nn diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 276571e5979..55fd7d9c155 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -1,25 +1,42 @@ +import functools import os +import signal import sys import tempfile -import unittest import test_xla_sharding_base +import threading +import time +import unittest import torch import torch.distributed as dist import torch.distributed.checkpoint as dist_cp +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs from torch.distributed.checkpoint.default_planner import ( create_default_local_save_plan, create_default_global_save_plan, ) -from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner +from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) +# Wrapper to manage a temporary directory for the wrapped test +def run_with_tmpdir(f): + + @functools.wraps(f) + def run(*args, **kwargs): + with tempfile.TemporaryDirectory() as tmpdir: + kwargs.setdefault('tmpdir', tmpdir) + f(*args, **kwargs) + + return run + + class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): @classmethod @@ -256,6 +273,20 @@ def test_save_state_dict_with_cpu_shards(self): self.assertTrue( isinstance(planner.sharded_state_dict['fc1.weight'], _CpuShards)) + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for sharded test") + def test_cpu_state_dict_flattening(self): + # In the case of a nested state_dict with fully sharded parameters, + # _CpuShards should be treated as terminal nodes. + t = torch.randn(128, 128).to(xm.xla_device()) + mesh = self._get_mesh((self.n_devices, 1)) + xs.mark_sharding(t, mesh, (0, 1)) + state_dict = _sharded_cpu_state_dict({'model': {'weight': t}}) + planner = SPMDSavePlanner() + planner.set_up_planner(state_dict, True) + # model.weight should be flattened and tracked in the sharded state dict. + self.assertCountEqual(planner.sharded_state_dict, ["model.weight"]) + def test_local_save_plan(self): def _write_item_assertions(plan, n_devices, parameter_count): @@ -319,6 +350,218 @@ def test_sharded_cpu_state_dict(self): self.assertTrue(param.device == torch.device("cpu")) +class CheckpointManagerTest(DistributedCheckpointTestBase): + + def setUp(self): + super().setUp() + # Initialize the a minimal process group + dist.init_process_group( + backend='gloo', + init_method='tcp://localhost:8932', + world_size=1, + rank=0) + torch_xla._XLAC._ensure_xla_coordinator_initialized( + global_rank=0, world_size=1, master_addr="localhost") + + def tearDown(self): + super().tearDown() + # Destroy the CPU process group after the test + dist.destroy_process_group() + + @run_with_tmpdir + def test_manager_checkpointing(self, tmpdir): + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) + state_dict = self._get_sharded_model().state_dict() + + # Take a checkpoint on step 0 + self.assertTrue(chkpt_mgr.save(0, state_dict)) + + # Load the checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + + @run_with_tmpdir + def test_manager_step_tracking(self, tmpdir): + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save(step, state_dict)) + saved.add(step) + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + @run_with_tmpdir + def test_manager_max_to_keep(self, tmpdir): + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, max_to_keep=2, chkpt_on_preemption=False) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + self.assertTrue(chkpt_mgr.save(10, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20}) + + # The oldest checkpoint should be erased + self.assertTrue(chkpt_mgr.save(30, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20}) + + # The oldest is selected by creation timestamp, not step + self.assertTrue(chkpt_mgr.save(10, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10}) + + # The deletion order should persist across executions + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, max_to_keep=2, chkpt_on_preemption=False) + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) + + @run_with_tmpdir + def test_manager_async(self, tmpdir): + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) + state_dict = self._get_sharded_model().state_dict() + + # Patch the manager's save method to block until this thread signals. + cond = threading.Condition() + old_save = chkpt_mgr._save + + def patched_save(*args, **kwargs): + with cond: + cond.wait() + old_save(*args, **kwargs) + + with unittest.mock.patch.object(chkpt_mgr, '_save', patched_save): + chkpt_mgr.save_async(10, state_dict) + + # No new steps should be tracked immediately after calling save_async + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Trigger the actual checkpoint in the background thread and wait for + # completion. + with cond: + cond.notify() + chkpt_mgr.join() + + # The manager should track all steps which were asynchronously saved. + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + @run_with_tmpdir + def test_manager_async_step_tracking(self, tmpdir): + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) + state_dict = self._get_sharded_model().state_dict() + + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save_async(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save_async(step, state_dict)) + saved.add(step) + + # Join to allow pending async checkpoints to complete + chkpt_mgr.join() + + # The manager should track all steps which were asynchronously saved. + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + # Load a checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + + @unittest.skipUnless(xr.device_type() == 'TPU', + 'TPU required for worker IP discovery') + @unittest.mock.patch('torch_xla._internal.tpu.get_worker_ips') + def test_master_ip_discovery(self, patched_get_worker_ips): + # A basic test to verify the SPMD codepath returns the correct IP. Two IPs + # are needed to avoid the short-circuit return of localhost. + patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2'] + self.assertTrue(xr.get_master_ip(), '10.0.0.1') + + def test_preemption_sync_manager(self): + try: + torch_xla._XLAC._activate_preemption_sync_manager() + sync_point_reached = torch_xla._XLAC._sync_point_reached + + # No sync point for the first several steps + sigterm_step = 10 + for step in range(sigterm_step): + self.assertFalse(sync_point_reached(step)) + + # Send a SIGTERM to the current process to trigger a sync point + os.kill(os.getpid(), signal.SIGTERM) + + # Allow the signal to be processed. The PreemptionSyncManager must receive + # the SIGTERM, which happens asynchronously, and the state must be + # propagated through the distributed runtime. Eventually, + # sync_point_reached will return True. + success = False + for attempt in range(10): + success = sync_point_reached(sigterm_step + attempt) + if success: + break + time.sleep(1) + self.assertTrue(success, "Sync point was never reached after SIGTERM") + finally: + # Scope the PreemptionSyncManager to the lifespan of the test. + torch_xla._XLAC._deactivate_preemption_sync_manager() + + @unittest.skipUnless(xr.device_type() == 'TPU', + 'TPU required for worker IP discovery') + @run_with_tmpdir + def test_auto_checkpoint(self, tmpdir): + # Create a checkpoint manager with a long save interval + chkpt_mgr = CheckpointManager(tmpdir, save_interval=100) + state_dict = self._get_sharded_model().state_dict() + + preemption_step = 10 + # Skip step 0 so the manager will track no checkpoints before preemption + for step in range(1, preemption_step): + self.assertFalse(chkpt_mgr.save(step, state_dict)) + + with unittest.mock.patch('torch_xla._XLAC._sync_point_reached', + lambda x: True): + self.assertTrue(chkpt_mgr.save(preemption_step, state_dict)) + self.assertTrue(chkpt_mgr.reached_preemption(step)) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index ce2cae18dd6..db303302e09 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -15,8 +15,8 @@ import torch_xla.runtime as xr import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met -import torch_xla.experimental.xla_sharding as xs -from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import XLAShardedTensor import test_xla_sharding_base import torch_xla.core.xla_env_vars as xenv @@ -42,6 +42,20 @@ def test_xla_sharded_tensor(self): # TODO(244003536) add more tests for XLAShardedTensror. self.assertTrue(isinstance(xst1, XLAShardedTensor)) + def test_xla_sharded_tensor_repr(self): + xt = torch.randn(128, 128).to(xm.xla_device()) + model = self.SimpleLinear().to(xm.xla_device()) + + mesh = self._get_mesh((1, self.n_devices)) + partition_spec = (0, 1) + xst = xs.mark_sharding(xt, mesh, partition_spec) + self.assertTrue(isinstance(xst, XLAShardedTensor)) + + xt_output = model(xt) + self.assertTrue('XLAShardedTensor' not in str(xt_output)) + xst_output = model(xst) + self.assertTrue('XLAShardedTensor' in str(xst_output)) + def test_sharded_tensor_debug_info(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], @@ -900,6 +914,139 @@ def test_op_sharding_cache(self): xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) + def test_from_cpu_shards_replicated(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (None,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.arange(4)] * self.n_devices + + # No shape should result in the shape of a single shard. + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # Specify a valid shape for the global tensor + global_tensor = from_cpu_shards(shards, op_sharding, shards[0].shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + # All invalid shapes should raise + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((5,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((3,))) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((2, 2))) + + def test_from_cpu_shards_tiled(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an OpSharding with all devices on a single axis + mesh = self._get_mesh((self.n_devices,)) + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + shards = [torch.LongTensor([i]) for i in range(self.n_devices)] + + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue( + torch.allclose(global_tensor.cpu(), torch.arange(self.n_devices))) + + # Test incorrect number of shards + with self.assertRaises(RuntimeError): + from_cpu_shards(shards[:-1], op_sharding) + + # Test an invalid global shape - too many values. + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((self.n_devices * 2,))) + + # Test an invalid global shape - incorrect rank + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1, self.n_devices))) + + # Test a valid global shape - restrict the number of meaningful values + # to 1, treating the rest as padding. + global_tensor = from_cpu_shards(shards, op_sharding, torch.Size((1,))) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(1))) + + def test_from_cpu_shards_2d(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + # Create an appropriate 2D mesh for the number of devices + if self.n_devices >= 4: + mesh_shape = (self.n_devices // 2, 2) + else: + mesh_shape = (1, self.n_devices) + mesh_2d = self._get_mesh(mesh_shape) + + # Replicated sharding + shards = [torch.LongTensor([self.n_devices])] * self.n_devices + partition_spec = (None, None) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), shards[0])) + + if self.n_devices > 1: + # Tiled sharding + shards = [torch.LongTensor([[i]]) for i in range(self.n_devices)] + partition_spec = (0, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + expected = torch.arange(self.n_devices).reshape(*mesh_shape) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Partially replicated sharding + shards = [torch.LongTensor([[i]]) for i in range(2)] * ( + self.n_devices // 2) + partition_spec = (None, 1) + op_sharding = mesh_2d.get_op_sharding(partition_spec) + global_tensor = from_cpu_shards(shards, op_sharding) + # Partial replication along the 0th axis represents a global tensor + # of torch.Tensor([[0, 1]]). + expected = torch.arange(2).reshape(1, 2) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + def test_from_cpu_shards_global_shape(self): + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + + mesh = self._get_mesh((self.n_devices,)) + numel = self.n_devices**2 + # The global tensor is torch.arange(numel). + shards = [ + torch.arange(self.n_devices) + (i * self.n_devices) + for i in range(self.n_devices) + ] + partition_spec = (0,) + op_sharding = mesh.get_op_sharding(partition_spec) + + # No global shape specified will include all data from the shards + global_tensor = from_cpu_shards(shards, op_sharding) + self.assertTrue(torch.allclose(global_tensor.cpu(), torch.arange(numel))) + + # Too large of a global shape will error out + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((numel + 1,))) + + if self.n_devices > 1: + # When the global tensor has fewer elements than the sum of its shards, + # there are two cases: + + # Case 1: If the global shape is within n_devices of numel, the excess + # data is treated as padding and ignored. + for delta in range(self.n_devices): + size = torch.Size((numel - delta,)) + global_tensor = from_cpu_shards(shards, op_sharding, size) + expected = torch.arange(size[0]) + self.assertTrue(torch.allclose(global_tensor.cpu(), expected)) + + # Case 2: Otherwise, it is not possible to have that much padding in a + # sharded tensor, and the shards are incompatible with the shape. + with self.assertRaises(RuntimeError): + shape = torch.Size((numel - self.n_devices,)) + from_cpu_shards(shards, op_sharding, shape) + with self.assertRaises(RuntimeError): + from_cpu_shards(shards, op_sharding, torch.Size((1,))) + if __name__ == '__main__': test = unittest.main() diff --git a/test/spmd/test_xla_sharding_base.py b/test/spmd/test_xla_sharding_base.py index 4b83368d380..57cbfe2a076 100644 --- a/test/spmd/test_xla_sharding_base.py +++ b/test/spmd/test_xla_sharding_base.py @@ -3,14 +3,14 @@ from torch import nn import torch_xla.core.xla_model as xm -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr import torch_xla.core.xla_env_vars as xenv import torch_xla.utils.utils as xu @unittest.skipIf(not xr.using_pjrt() or - xu.getenv_as(xenv.PJRT_DEVICE, str) == "GPU", + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM'), f"Requires PJRT_DEVICE set to `TPU` or `CPU`.") class XlaShardingTest(unittest.TestCase): diff --git a/test/spmd/test_xla_sharding_hlo.py b/test/spmd/test_xla_sharding_hlo.py index 3a39a906261..723d1c71fd3 100644 --- a/test/spmd/test_xla_sharding_hlo.py +++ b/test/spmd/test_xla_sharding_hlo.py @@ -9,7 +9,7 @@ import torch_xla import torch_xla.runtime as xr import torch_xla.core.xla_model as xm -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs import test_xla_sharding_base diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index 8ea4db3e051..6fb12f916d2 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -3,7 +3,9 @@ import sys import torch +import torch.distributed as dist import torch_xla +import torch_xla.distributed.xla_backend import torch_xla.core.xla_model as xm from torch_xla import runtime as xr from torch_xla.amp import autocast @@ -120,7 +122,7 @@ def setUpClass(cls): xr.use_spmd() super().setUpClass() - @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'], + @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU', 'CUDA', 'ROCM'], f"TPU/GPU autocast test.") def test_xla_autocast_api(self): device = xm.xla_device() @@ -132,6 +134,19 @@ def test_xla_autocast_api(self): self.assertTrue(t3.dtype == expected_dtype) +class BasicDistributedTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + xr.use_spmd() + return super().setUpClass() + + def test_xla_backend(self): + # XLA backend is not supported with SPMD + with self.assertRaises(AssertionError): + dist.init_process_group('xla', init_method='xla://') + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index ac304e7285d..d58797eb5ff 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -9,7 +9,7 @@ import torch_xla.runtime as xr import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs import test_xla_sharding_base diff --git a/test/stablehlo/test_exports.py b/test/stablehlo/test_exports.py new file mode 100644 index 00000000000..ba99319c6b2 --- /dev/null +++ b/test/stablehlo/test_exports.py @@ -0,0 +1,32 @@ +import unittest +import torch +import torch.nn.functional as F +from torch_xla.stablehlo import exported_program_to_stablehlo + + +class Interpolate(torch.nn.Module): + + def forward(self, masks: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(500, 500), + mode="bilinear", + align_corners=False, + ) + return masks + + +class ExportTest(unittest.TestCase): + + def test_interpolate(self): + + arg = (torch.randn(3, 3, 200, 200),) + model = Interpolate() + + ans = model(*arg) + + with torch.no_grad(): + exported = torch._export.export(model, arg) + shlo = exported_program_to_stablehlo(exported) + ans2 = shlo(*arg).cpu().to(torch.float32) + self.assertTrue(torch.allclose(ans, ans2, atol=1e-5)) diff --git a/test/stablehlo/test_saved_model.py b/test/stablehlo/test_saved_model.py index 75c10c299b6..7c510ed3637 100644 --- a/test/stablehlo/test_saved_model.py +++ b/test/stablehlo/test_saved_model.py @@ -39,6 +39,24 @@ def test_resnet18_save_load(self): output2 = torch.tensor(res.numpy()) self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + def test_unused_param(self): + + class M(torch.nn.Module): + + def forward(self, a, b): + return torch.sin(b) + + model = M() + data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100)) + output = model(*data) + + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model(model, data, tempdir) + loaded_m = tf.saved_model.load(tempdir) + res = loaded_m.f(data[0].detach().numpy(), data[1].detach().numpy())[0] + output2 = torch.tensor(res.numpy()) + self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + if __name__ == '__main__': test = unittest.main() diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py new file mode 100644 index 00000000000..a4223b799aa --- /dev/null +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -0,0 +1,58 @@ +import sys +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +from torch_xla.stablehlo import exported_program_to_stablehlo + +# Note: Unbounded dynamism is under development. It works with unmerged +# XLA changes. Experimental XLA branch: https://github.com/lsy323/openxla-xla/tree/lsiyuan/sandeep-dynamism-rebased + +device = xm.xla_device() + + +class UnboundedDynamismExportTest(unittest.TestCase): + + def test_simply_add(self): + a = torch.tensor([[1, 2], [2, 4]], device=device) + torch_xla._XLAC._xla_mark_dynamic(a, 0) + b = torch.tensor([[1, 2], [2, 4]], device=device) + torch_xla._XLAC._xla_mark_dynamic(b, 0) + c = a * b + hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c]) + self.assertTrue( + "(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content) + + def test_export_dynamism(self): + + class M(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x * y + + example_args = (torch.tensor([[1, 2], [2, 4]], device=device), + torch.tensor([[1, 2], [2, 4]], device=device)) + constraints = [ + # First dimension of each input is a dynamic batch size + torch.export.dynamic_dim(example_args[0], 0), + torch.export.dynamic_dim(example_args[1], 0), + # The dynamic batch size between the inputs are equal + torch.export.dynamic_dim(example_args[0], + 0) == torch.export.dynamic_dim( + example_args[1], 0), + ] + ep = torch.export.export(M(), args=example_args, constraints=constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text("forward") + self.assertTrue( + "(%arg0: tensor, %arg1: tensor) -> tensor" in + shlo_text) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_autocast.py b/test/test_autocast.py index 9caa3017ea8..edbd834b61b 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -341,7 +341,8 @@ def compare(first, second): self.assertFalse(self.is_autocast_enabled()) -@unittest.skipIf(not xm.get_xla_supported_devices("GPU"), f"GPU autocast test.") +@unittest.skipIf(not xm.get_xla_supported_devices("CUDA"), + f"CUDA autocast test.") class TestAutocastCuda(TestAutocastBase): def setUp(self): diff --git a/test/test_ddp.py b/test/test_ddp.py index 2389cc51f0d..25e53790cc5 100644 --- a/test/test_ddp.py +++ b/test/test_ddp.py @@ -16,7 +16,7 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool): # We cannot run this guard before XMP, # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing. device = xm.xla_device() - if xm.xla_device_hw(device) not in ('GPU', 'TPU'): + if xm.xla_device_hw(device) not in ('GPU', 'TPU', 'CUDA', 'ROCM'): print( 'Default device {} is not a TPU device'.format(device), file=sys.stderr) diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py index 5bd85bb6b94..b14fb769bc0 100644 --- a/test/test_fsdp_auto_wrap.py +++ b/test/test_fsdp_auto_wrap.py @@ -31,10 +31,10 @@ def forward(self, x): hidden2 = self.fc2(x) return hidden1, hidden2 - @unittest.skipIf( - xr.device_type() == 'GPU', - "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" - ) + @unittest.skipIf(xr.device_type() in ( + 'GPU', 'ROCM', 'CUDA' + ), "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" + ) def test(self): dev = xm.xla_device() input = torch.zeros([16, 16], device=dev) @@ -50,12 +50,12 @@ def test(self): def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) else: print( - 'Default device {} is not a TPU or GPU device'.format(device), + 'Default device {} is not a TPU or CUDA device'.format(device), file=sys.stderr) diff --git a/test/test_metrics.py b/test/test_metrics.py index 8f5b0cfd850..d9b90418983 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -38,6 +38,14 @@ def test_clear_metrics(self): self.assertIn("TensorToData", met.metrics_report()) assert (len(met.metric_names()) > 0) + def test_tracing_time_metrics(self): + xla_device = xm.xla_device() + met.clear_all() + t1 = torch.tensor(156, device=xla_device) + t2 = t1 + 100 + self.assertIn('LazyTracing', met.metric_names()) + self.assertGreater(met.metric_data('LazyTracing')[0], 1) + def test_short_metrics_report_default_list(self): xla_device = xm.xla_device() t1 = torch.tensor(1456, device=xla_device) @@ -164,7 +172,9 @@ def test_metrics_report(self): self.assertIn("CachedCompile", report) @unittest.skipIf( + xm.get_xla_supported_devices("CUDA") or xm.get_xla_supported_devices("GPU") or + xm.get_xla_supported_devices("ROCM") or xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.") def test_execute_time_metric(self): # Initialize the client before starting the timer. diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index b8fee7e29ba..3ffeebc963d 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -13,7 +13,7 @@ def all_gather(tensor, dim): def _mp_fn(index): device = xm.xla_device() world_size = xm.xrt_world_size() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): # Testing with a single replica group ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device) result = xm.all_gather(ordinal_tensor, dim=0) diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py index 3dc45732ac3..eb292f7a53d 100644 --- a/test/test_mp_distributed_mm.py +++ b/test/test_mp_distributed_mm.py @@ -9,7 +9,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): world_size = xm.xrt_world_size() torch_xla._XLAC._xla_set_use_full_mat_mul_precision( use_full_mat_mul_precision=True) diff --git a/test/test_operations.py b/test/test_operations.py index 6a1be16ede3..4e8ebcedee4 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -36,7 +36,7 @@ import torch_xla.debug.metrics as met import torch_xla.debug.model_comparator as mc import torch_xla.distributed.parallel_loader as pl -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs from torch_xla import runtime as xr import torch_xla.test.test_utils as xtu import torch_xla.utils.utils as xu @@ -58,7 +58,13 @@ def _is_on_tpu(): return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU' +def _is_on_eager_debug_mode(): + return xu.getenv_as('XLA_USE_EAGER_DEBUG_MODE', bool, defval=False) + + skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU') +skipOnEagerDebug = unittest.skipIf(_is_on_eager_debug_mode(), + 'skip on eager debug mode') def _gen_tensor(*args, **kwargs): @@ -309,6 +315,19 @@ def test_get_xla_tensor(self): tx = t.select(1, 12) self.assertEqual(tx, sx.data.cpu()) + def test_masked_fill_scalar(self): + + def fn(tensor): + # Build a mask from the first line of tensor. + # Also, make it have the same rank as the original tensor. + mask = tensor[0].ge(0.5).unsqueeze(dim=0) + # Call masked_fill. + return tensor.masked_fill(mask, 10) + + x = _gen_tensor(2, 2, device=xm.xla_device()) + x_cpu = x.cpu() + self.assertEqual(fn(x_cpu), fn(x)) + class TestRandom(test_utils.XlaTestCase): @@ -434,7 +453,8 @@ def test_get_real_xla_devices(self): devices = xm.get_xla_supported_devices() xla_devices = torch_xla._XLAC._xla_real_devices(devices) for device, xdevice in zip(devices, xla_devices): - self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', xdevice) is not None) + self.assertTrue( + re.match(r'(CPU|GPU|TPU|CUDA|ROCM):\d+$', xdevice) is not None) def test_negative_slice(self): t = _gen_tensor(32, 24, 32) @@ -970,7 +990,9 @@ def func(a, b): b = torch.ones([2, 2]) self.runAtenTest((a, b), func) - @unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION, + # TODO - upstream behavior has changed and results in expected DestroyXlaTensor + # counter as of 11/13/2023. Re-enable after reviewing the change. + @unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION, 'Metrics differ when functionalization is disabled.') def test_set(self): met.clear_all() @@ -1635,6 +1657,42 @@ def test_cached_addcdiv(self): xm.mark_step() self.assertEqual(met.metric_data("TransferToServerTime")[0], 4) + @skipOnEagerDebug + def test_print_executation(self): + xla_device = xm.xla_device() + xm.mark_step() + xm.wait_device_ops() + met.clear_all() + + # case 1 mark_step + t1 = torch.randn(1, 4, device=xla_device) + xm.mark_step() + xm.wait_device_ops() + self.assertEqual(met.metric_data('ExecuteTime')[0], 1) + for _ in range(3): + print(t1) + self.assertEqual(met.metric_data('ExecuteTime')[0], 1) + self.assertIn('xla::device_data', + torch_xla._XLAC._get_xla_tensors_text([t1])) + + # case 2 no mark_step, directly print + met.clear_all() + t1 = torch.randn(1, 4, device=xla_device) + for _ in range(3): + print(t1) + self.assertEqual(met.metric_data('ExecuteTime')[0], 1) + self.assertIn('xla::device_data', + torch_xla._XLAC._get_xla_tensors_text([t1])) + + # case 2 no mark_step, print with .cpu + met.clear_all() + t1 = torch.randn(1, 4, device=xla_device) + for _ in range(3): + print(t1.cpu()) + self.assertEqual(met.metric_data('ExecuteTime')[0], 1) + self.assertIn('xla::device_data', + torch_xla._XLAC._get_xla_tensors_text([t1])) + def test_index_types(self): def test_fn(*indices): diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 862272acd57..9fe8b7b4aae 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -60,9 +60,16 @@ def test_div_by_f64(self): [p.grad for p in mod.parameters() if p.requires_grad]) assert 'f64' not in hlo_text + def test_dropout_by_u8_mask(self): + mod = torch.nn.Dropout().to(xm.xla_device()) + a = torch.rand(20, 16, dtype=torch.bfloat16).to(xm.xla_device()) + b = mod(a) + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b]) + assert 'u8' in hlo_text + if __name__ == '__main__': - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) torch.manual_seed(42) torch_xla._XLAC._xla_set_use_full_mat_mul_precision( use_full_mat_mul_precision=True) diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py index 5e092b6c394..f70a380132e 100644 --- a/test/test_profile_mp_mnist.py +++ b/test/test_profile_mp_mnist.py @@ -198,7 +198,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags, dynamic_graph=True, fetch_often=True) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_pt_xla_debug.py b/test/test_pt_xla_debug.py new file mode 100644 index 00000000000..14f0817a4c0 --- /dev/null +++ b/test/test_pt_xla_debug.py @@ -0,0 +1,120 @@ +import os + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.utils.utils as xu +import torch_xla.debug.profiler as xp +import torch_xla.utils.utils as xu +import torch_xla.distributed.parallel_loader as pl +import unittest + + +def check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + + +def extract_execution_cause(lines): + causes = [] + for i in range(len(lines)): + if 'Execution Cause' in lines[i].decode(): + causes.append(lines[i + 1].decode()) + return causes + + +class PtXLADebugTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if not check_env_flag('PT_XLA_DEBUG'): + assert False, "This test should be run with PT_XLA_DEBUG" + cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE') + if not cls.debug_file_name: + assert False, "This test should be run with PT_XLA_DEBUG_FILE" + open(cls.debug_file_name, 'w').close() + + def test_user_mark_step(self): + device = xm.xla_device() + t1 = torch.randn(2, 2, device=device) + xm.mark_step() + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 1) + self.assertIn('user mark_step', causes[0]) + open(self.debug_file_name, 'w').close() + + def test_step_trace(self): + device = xm.xla_device() + with xp.StepTrace('train_pt_xla_debug'): + t1 = torch.randn(2, 2, device=device) + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 1) + self.assertIn('mark_step when exiting a profiler StepTrace region', + causes[0]) + open(self.debug_file_name, 'w').close() + + def test_dynamo(self): + device = xm.xla_device() + t1 = torch.randn(2, 2, device=device) + + def toy_program(t1): + return t1 + t1 + + compiled = torch.compile(toy_program, backend="openxla") + res = compiled(t1) + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 4) + self.assertIn('mark_step when dynamo processing input graphs', causes[0]) + self.assertIn('mark_step when dynamo processing input graphs', causes[1]) + self.assertIn('dynamo is compiling a FX graph to HLO', causes[2]) + self.assertIn('dynamo is executing a compiled program', causes[3]) + open(self.debug_file_name, 'w').close() + + def test_parallel_loader(self): + device = xm.xla_device() + + train_dataset_len = 100 + batch_size = 10 + train_loader = xu.SampleGenerator( + data=(torch.zeros(batch_size, 3, 128, + 128), torch.zeros(batch_size, dtype=torch.int64)), + sample_count=train_dataset_len // 10) + + train_device_loader = pl.MpDeviceLoader( + train_loader, + device, + loader_prefetch_size=8, + device_prefetch_size=4, + host_to_device_transfer_threads=1) + + for step, (data, target) in enumerate(train_device_loader): + pass + + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), batch_size + 2) + for cause in causes: + self.assertIn('mark_step in parallel loader at step end', cause) + open(self.debug_file_name, 'w').close() + + def test_print(self): + device = xm.xla_device() + t1 = torch.randn(2, 2, device=device) + print(t1) + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 1) + self.assertIn('user code trying to access tensor value', causes[0]) + open(self.debug_file_name, 'w').close() + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_torch_distributed_all_gather_xla_backend.py b/test/test_torch_distributed_all_gather_xla_backend.py index 763c15d6f5b..f75a019db86 100644 --- a/test/test_torch_distributed_all_gather_xla_backend.py +++ b/test/test_torch_distributed_all_gather_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): world_size = xm.xrt_world_size() rank = xm.get_ordinal() diff --git a/test/test_torch_distributed_all_reduce_xla_backend.py b/test/test_torch_distributed_all_reduce_xla_backend.py index 3f0bca31b8f..9962c824b7d 100644 --- a/test/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/test_torch_distributed_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): world_size = xm.xrt_world_size() rank = xm.get_ordinal() diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index 79b65a46999..c626faf7447 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -8,9 +8,9 @@ def _mp_fn(index): dev = xm.xla_device() - if xm.xla_device_hw(dev) not in ('TPU', 'GPU'): + if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): print( - 'Default device {} is not a TPU or GPU device'.format(dev), + 'Default device {} is not a TPU or CUDA device'.format(dev), file=sys.stderr) return diff --git a/test/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/test_torch_distributed_multi_all_reduce_xla_backend.py index cf16311ca98..e576c3ffb0f 100644 --- a/test/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): world_size = xm.xrt_world_size() rank = xm.get_ordinal() diff --git a/test/test_torch_distributed_reduce_scatter_xla_backend.py b/test/test_torch_distributed_reduce_scatter_xla_backend.py index f278567379e..fd146d98af7 100644 --- a/test/test_torch_distributed_reduce_scatter_xla_backend.py +++ b/test/test_torch_distributed_reduce_scatter_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = xm.xla_device() - if xm.xla_device_hw(device) in ('TPU', 'GPU'): + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'): world_size = xm.xrt_world_size() rank = xm.get_ordinal() diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index 4584be37e1f..43c4c9603a1 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -31,8 +31,6 @@ '--ddp': { 'action': 'store_true', }, - # Use xla:// init_method instead of env:// for `torch.distributed`. - # Required for DDP on TPU v2/v3 when using PJRT. '--pjrt_distributed': { 'action': 'store_true', }, @@ -85,6 +83,7 @@ import torch_xla.distributed.parallel_loader as pl import torch_xla.debug.profiler as xp import torch_xla.utils.utils as xu +import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils @@ -179,11 +178,8 @@ def _train_update(device, step, loss, tracker, epoch, writer): def train_imagenet(): - if FLAGS.pjrt_distributed: + if FLAGS.ddp or FLAGS.pjrt_distributed: dist.init_process_group('xla', init_method='xla://') - elif FLAGS.ddp: - dist.init_process_group( - 'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal()) print('==> Preparing data..') img_dim = get_model_property('img_dim') @@ -365,7 +361,7 @@ def test_loop_fn(loader, epoch): def _mp_fn(index, flags): global FLAGS FLAGS = flags - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, @@ -374,4 +370,7 @@ def _mp_fn(index, flags): if __name__ == '__main__': - xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) + if dist.is_torchelastic_launched(): + _mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS) + else: + xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index bec112a9378..ffcf6ee1386 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -221,7 +221,7 @@ def train_imagenet(): if FLAGS.amp: if device_hw == 'TPU': scaler = None - elif device_hw == 'GPU': + elif device_hw in ('GPU', 'CUDA', 'ROCM'): scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader, epoch): @@ -298,7 +298,7 @@ def test_loop_fn(loader, epoch): def _mp_fn(index, flags): global FLAGS FLAGS = flags - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index fdfdc8a698c..351e19aad75 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -385,7 +385,7 @@ def test_loop_fn(loader, epoch): def _mp_fn(index, flags): global FLAGS FLAGS = flags - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_imagenet() if accuracy < FLAGS.target_accuracy: print('Accuracy {} is below target {}'.format(accuracy, diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 22253fbea73..3b078d22fab 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -76,11 +76,8 @@ def _train_update(device, step, loss, tracker, epoch, writer): def train_mnist(flags, **kwargs): - if flags.pjrt_distributed: + if flags.ddp or flags.pjrt_distributed: dist.init_process_group('xla', init_method='xla://') - elif flags.ddp: - dist.init_process_group( - 'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal()) torch.manual_seed(1) @@ -209,7 +206,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index ae4db118300..990ea9bc91a 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -142,7 +142,7 @@ def train_mnist(flags, **kwargs): if device_hw == 'TPU': scaler = None - elif device_hw == 'GPU': + elif device_hw == 'CUDA': # GradScaler only used for GPU scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) else: @@ -211,7 +211,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 2bb549e72a4..96d9a9b8bbb 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -313,7 +313,7 @@ def test_loop_fn(model, loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py index 6f8d3964b52..02a6db04a17 100644 --- a/test/test_train_mp_mnist_zero1.py +++ b/test/test_train_mp_mnist_zero1.py @@ -184,7 +184,7 @@ def test_loop_fn(loader): def _mp_fn(index, flags): - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_dtype(torch.float32) accuracy = train_mnist(flags) if flags.tidy and os.path.isdir(flags.datadir): shutil.rmtree(flags.datadir) diff --git a/test/test_zero1.py b/test/test_zero1.py index cb751726577..e9c3a3eeee6 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -13,7 +13,7 @@ class XlaZeRO1Test(TestCase): @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") - @unittest.skipIf(xr.device_type() == 'GPU', + @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'), "TODO(alanwaketan): Fix it for the token change.") def test_zero1(self): device = xm.xla_device() diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index cfde8456617..7191b5d5bb9 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -3,6 +3,8 @@ import re import tempfile +from ._internal import tpu + logging.basicConfig() logger = logging.getLogger(__name__) @@ -30,17 +32,31 @@ def _setup_xla_flags(): os.environ['XLA_FLAGS'] = ' '.join(flags) -def _set_missing_env(name, value): - if name not in os.environ: - os.environ[name] = value +def _setup_libtpu_flags(): + flags = os.environ.get('LIBTPU_INIT_ARGS', '').split(' ') + # This flag will rerun the latency hidding scheduler if the default + # shared memory limit 95% leads to OOM. Each rerun will choose a value + # 0.9x of the previous run, and the number of rerun is set to 1 now. + # Shared memory limit refers to --xla_tpu_scheduler_percent_shared_memory_limit. + # Lower shared memory limit means less communiation and computation overlapping, + # and thus worse performance. + flags = _set_missing_flags(flags, + (('xla_latency_hiding_scheduler_rerun', '1'),)) + os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags) def _setup_default_env(): - _set_missing_env('TF_CPP_MIN_LOG_LEVEL', '1') - _set_missing_env('GRPC_VERBOSITY', 'ERROR') - _set_missing_env('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') - _set_missing_env('TPU_ML_PLATFORM', 'PyTorch/XLA') - _set_missing_env('TPU_MEGACORE', 'megacore_dense') + os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') + os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') + + if tpu.num_available_chips() > 0: + _setup_libtpu_flags() + + os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') + os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA') + + if tpu.version() == 4: + os.environ.setdefault('TPU_MEGACORE', 'megacore_dense') _fd, _tmp_fname = -1, '' @@ -48,7 +64,7 @@ def _setup_default_env(): def _setup_debug_env(): fd, tmp_fname = tempfile.mkstemp('.ptxla', text=True) - _set_missing_env('XLA_FNTRACKER_FILE', tmp_fname) + os.environ.setdefault('XLA_FNTRACKER_FILE', tmp_fname) return fd, tmp_fname @@ -71,13 +87,16 @@ def _aws_ec2_inf_trn_init(): def _setup_tpu_vm_library_path() -> bool: - """Returns true if $TPU_LIBRARY is set or can be inferred. + """Returns true if $TPU_LIBRARY_PATH is set or can be inferred. We load libtpu.so in the following order of precedence: 1. User-set $TPU_LIBRARY_PATH 2. libtpu.so included in torch_xla/lib 3. libtpu-nightly pip package + + Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts + with other frameworks. This env var will be removed in a future version. """ if 'TPU_LIBRARY_PATH' in os.environ: return True @@ -86,12 +105,12 @@ def _setup_tpu_vm_library_path() -> bool: bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so') if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'): logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path) - os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path + os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path return True try: import libtpu - libtpu.configure_library_path() + os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path() return True except ImportError: return False diff --git a/torch_xla/_internal/gpu.py b/torch_xla/_internal/gpu.py index 20d7fce91f4..ad73b32ce3e 100644 --- a/torch_xla/_internal/gpu.py +++ b/torch_xla/_internal/gpu.py @@ -1,9 +1,6 @@ import os -import torch_xla import torch_xla.core.xla_env_vars as xenv -distributed_service = None - def num_local_processes() -> int: """Returns number of processes to create on this host. @@ -14,30 +11,5 @@ def num_local_processes() -> int: """ assert xenv.GPU_NUM_DEVICES in os.environ, \ "Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client" - return int(os.environ[xenv.GPU_NUM_DEVICES]) - - -def initialize_distributed_runtime(global_world_size: int) -> None: - """Configures GPU distributed runtime parameters. - - Must be run before using any XLA devices. - - Args: - global_world_size: number of devices in the cluster. - """ - if global_world_size > 1: - # TODO(jonbolin): For multi-host, this needs to be consistent across hosts - os.environ.setdefault(xenv.PJRT_DIST_SERVICE_ADDR, '127.0.0.1:8547') - global distributed_service - if distributed_service is None: - num_nodes = global_world_size - distributed_service = torch_xla._XLAC._xla_get_distributed_runtime_service( - num_nodes) - - -def shutdown_distributed_runtime() -> None: - """Destroy the distributed runtime after a distributed computation.""" - global distributed_service - if distributed_service: - distributed_service.shutdown() - distributed_service = None + os.environ[xenv.LOCAL_WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES] + return int(os.environ[xenv.LOCAL_WORLD_SIZE]) diff --git a/torch_xla/_internal/neuron.py b/torch_xla/_internal/neuron.py index 2f286681136..7b145918675 100644 --- a/torch_xla/_internal/neuron.py +++ b/torch_xla/_internal/neuron.py @@ -1,4 +1,5 @@ import os +import logging def num_local_processes() -> int: diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 9e7533955e4..b92bba55679 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -7,12 +7,14 @@ from typing import Callable, Dict, List, Tuple, TypeVar import torch +import torch.distributed as dist import torch_xla import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_backend from torch_xla._internal import tpu, gpu, neuron from torch_xla import runtime +import torch_xla.utils.utils as xu R = TypeVar('R') @@ -138,9 +140,8 @@ def run_multiprocess(fn: Callable[..., R], """ if runtime.device_type() == 'TPU': num_processes = tpu.num_local_processes() - elif runtime.device_type() == 'GPU': + elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'): num_processes = gpu.num_local_processes() - gpu.initialize_distributed_runtime(num_processes) elif runtime.device_type() == 'NEURON': num_processes = neuron.num_local_processes() else: @@ -160,9 +161,6 @@ def run_multiprocess(fn: Callable[..., R], itertools.chain.from_iterable( result.items() for result in process_results)) - if runtime.device_type() == 'GPU': - gpu.shutdown_distributed_runtime() - return _merge_replica_results(replica_results) @@ -210,8 +208,8 @@ def _initialize_single_process(local_rank: int, local_world_size: int): def spawn_threads(fn: Callable, args: Tuple = ()) -> None: """Run function in one process with one thread per addressable device.""" - assert runtime.device_type( - ) != 'GPU', "spawn_threads does not support GPU device" + assert runtime.device_type() not in ( + 'GPU', 'ROCM', 'CUDA'), "spawn_threads does not support GPU device" spawn_fn = _SpawnFn(fn, *args) _run_thread_per_device( local_rank=0, diff --git a/torch_xla/_internal/rendezvous.py b/torch_xla/_internal/rendezvous.py index 8a95ce1024c..26bbae300a1 100644 --- a/torch_xla/_internal/rendezvous.py +++ b/torch_xla/_internal/rendezvous.py @@ -28,7 +28,6 @@ def pjrt_rendezvous_handler(url: str, ) == 'TPU' else 'localhost' master_port = xu.getenv_as('MASTER_PORT', int, 12355) - world_size = xr.world_size() with _store_lock: global _store if not _store: @@ -44,4 +43,8 @@ def pjrt_rendezvous_handler(url: str, xr.process_count(), is_master=xr.process_index() == 0) - yield (_store, xr.global_ordinal(), world_size) + # In SPMD, the world size and rank are determined by the process count and + # index, while in multiprocess they are based on the device count and ordinal. + world_size = xr.process_count() if xr.is_spmd() else xr.world_size() + rank = xr.process_index() if xr.is_spmd() else xr.global_ordinal() + yield (_store, rank, world_size) diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 89fbca4dcbc..385566b1d35 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -1,15 +1,18 @@ import functools import glob +from ipaddress import ip_address import operator import os import pathlib import re +import socket from typing import NamedTuple, Optional, List from typing_extensions import TypedDict import requests import yaml import torch +import torch_xla import torch_xla.utils.utils as xu import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm @@ -268,7 +271,10 @@ def configure_topology(local_rank: int, def discover_master_worker_ip(use_localhost: bool = True) -> str: - """Find the IP of the TPU host with TPU:0. + """Find the IP of the master TPU host. + + In multiprocess, this is the host with TPU:0. + In SPMD mode, this is the host running process 0. TPU device IDs are nondeterministic and independent from Cloud TPU worker IDs. @@ -276,15 +282,40 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str: use_localhost: if there is only one TPU host, return 'localhost` instead of that host's internal IP. """ + import torch_xla.runtime as xr worker_ips = get_worker_ips() if len(worker_ips) == 1: return 'localhost' tpu_env = get_tpu_env() current_worker_id = int(tpu_env[xenv.WORKER_ID]) + if xr.is_spmd(): + return _spmd_find_master_ip(worker_ips[current_worker_id]) + t = torch.tensor([current_worker_id], device=xm.xla_device()) xm.collective_broadcast([t]) xm.mark_step() master_worker_id = int(t.cpu()) return worker_ips[master_worker_id] + + +def _spmd_find_master_ip(current_worker_hostname: str) -> str: + import torch_xla.runtime as xr + import torch_xla.distributed.spmd as xs + from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards + # Translate the hostname to an IP address, e.g. for TPUs on GKE. + current_worker_ip = socket.gethostbyname(current_worker_hostname) + ip_int = int(ip_address(current_worker_ip)) + n_dev = xr.global_runtime_device_count() + local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices()) + # Create a global (n_dev x 2) tensor containing all process indices and IPs, + # and find the process 0 IP as the master IP. + shard = torch.LongTensor([[xr.process_index(), ip_int]]) + op_sharding = xs.Mesh(range(n_dev), (n_dev, 1)).get_op_sharding((0, 1)) + global_tensor = from_cpu_shards([shard] * local_ndev, op_sharding).cpu() + # Process 0 may not control device 0, so we must do a linear search. + for proc, ip in global_tensor.tolist(): + if proc == 0: + return str(ip_address(ip)) + raise RuntimeError('Could not find IP of host running process 0') diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index fcdd4a40840..db7dc3d5d9b 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -25,7 +25,7 @@ def __init__(self, self._enabled = enabled self._xla_device = xm.xla_device_hw(device) - if self._xla_device == 'GPU': + if self._xla_device in ('GPU', 'ROCM', 'CUDA'): backend = 'cuda' self._xla_bfloat16 = False # True if xla backend with bfloat16 dtype. if dtype is None: @@ -70,7 +70,7 @@ def __init__(self, def __enter__(self): # This ensures that xla autocast is enabled even for XLA:GPU, which calls # `torch.amp.autocast_mode.autocast` with `cuda` backend. - if self._xla_device == 'GPU': + if self._xla_device in ('GPU', 'ROCM', 'CUDA'): self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined] self.prev_dtype = torch.get_autocast_xla_dtype( ) # type: ignore[attr-defined] @@ -86,7 +86,7 @@ def __enter__(self): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] - if self._xla_device == 'GPU': + if self._xla_device in ('GPU', 'ROCM', 'CUDA'): if self._xla_bfloat16: # autocast_xla flags will be set by `torch.autocast` and we need to # set autocast flags as we call into `torch.autocast` apis. diff --git a/torch_xla/amp/syncfree/adam.py b/torch_xla/amp/syncfree/adam.py index 1edee07238c..4201933ca59 100644 --- a/torch_xla/amp/syncfree/adam.py +++ b/torch_xla/amp/syncfree/adam.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +import torch_xla.core.xla_model as xm from . import _functional as F @@ -86,9 +87,14 @@ def step(self, closure=None, found_inf: Tensor = None): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( p, memory_format=torch.preserve_format) - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + else: + state['max_exp_avg_sq'] = torch.empty( + 0, dtype=torch.float, device=xm.xla_device()) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/amp/syncfree/adamw.py b/torch_xla/amp/syncfree/adamw.py index 60f3745d23a..83e11d46fad 100644 --- a/torch_xla/amp/syncfree/adamw.py +++ b/torch_xla/amp/syncfree/adamw.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +import torch_xla.core.xla_model as xm from . import _functional as F @@ -84,9 +85,14 @@ def step(self, closure=None, found_inf: Tensor = None): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( p, memory_format=torch.preserve_format) - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + else: + state['max_exp_avg_sq'] = torch.empty( + 0, dtype=torch.float, device=xm.xla_device()) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index f30f7d8ef8f..d9c13c6ec69 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -230,16 +230,19 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule): for xla_arg in xla_args ] - args_tensor_ids = [ - torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args - ] + index_and_xla_tensor_args = [(i, xla_arg) + for i, xla_arg in enumerate(xla_args) + if isinstance(xla_arg, torch.Tensor)] + + index_and_tensor_ids = [(index, torch_xla._XLAC._xla_get_tensor_id(xla_arg)) + for index, xla_arg in index_and_xla_tensor_args] if dynamo_debug: print(f"Graph module:\n{xla_model.code}") - print(f"args_tensor_ids {args_tensor_ids}") + print(f"args_tensor_ids {index_and_tensor_ids}") tensor_id_to_arg_idx = { - tensor_id: i for i, tensor_id in enumerate(args_tensor_ids) + tensor_id: index for index, tensor_id in index_and_tensor_ids } if xr.is_spmd(): @@ -258,15 +261,16 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule): # If a arg is being in place updated by model, we need to include arg as part of the graph result. xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization( - xla_args) + [tensor for _, tensor in index_and_xla_tensor_args]) xla_args_need_update = [] arg_index_to_need_update_index = {} for i, need_update in enumerate(xla_args_need_update_bool): # Don't add inplace updated argument to the list if it's already # being returned - if need_update and id(xla_args[i]) not in xla_out_ids: - arg_index_to_need_update_index[i] = len(xla_args_need_update) - xla_args_need_update.append(xla_args[i]) + index, tensor = index_and_xla_tensor_args[i] + if need_update and id(tensor) not in xla_out_ids: + arg_index_to_need_update_index[index] = len(xla_args_need_update) + xla_args_need_update.append(tensor) args_and_out = tuple(xla_args_need_update) + tuple(xla_out) @@ -325,7 +329,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule): def extract_internal(xla_model: torch.fx.GraphModule): if dynamo_debug: for xla_arg in xla_model.xla_args: - print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg)) + if isinstance(xla_arg, torch.Tensor): + print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg)) xm.mark_step() (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, graph_input_matcher, @@ -347,7 +352,9 @@ def optimized_mod(*args): # mark_step needs to be blocking since we want to access args's XLADatas # and they can't be placeholder. - if any(torch_xla._XLAC._check_tensor_need_materialization(args)): + if any( + torch_xla._XLAC._check_tensor_need_materialization( + [a for a in args if isinstance(a, torch.Tensor)])): xm.mark_step(wait=True) # If input sharding has changed from the previous program, dynamo current can diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py index f67ea2d9fb6..eb79ff14310 100644 --- a/torch_xla/core/xla_env_vars.py +++ b/torch_xla/core/xla_env_vars.py @@ -26,3 +26,6 @@ PJRT_GPU_ASYNC_CLIENT = 'PJRT_GPU_ASYNC_CLIENT' PJRT_DIST_SERVICE_ADDR = 'PJRT_DIST_SERVICE_ADDR' LOCAL_RANK = 'LOCAL_RANK' +RANK = 'RANK' +WORLD_SIZE = 'WORLD_SIZE' +LOCAL_WORLD_SIZE = 'LOCAL_WORLD_SIZE' diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 7f9682c1000..e85db1d20a6 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -5,6 +5,7 @@ import re import threading import time +import warnings from typing import List, Optional import torch import torch.distributed._functional_collectives @@ -71,7 +72,7 @@ def is_xla_tensor(tensor): def parse_xla_device(device): - m = re.match(r'(CPU|TPU|GPU|XPU|NEURON):(\d+)$', device) + m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device) if m: return (m.group(1), int(m.group(2))) @@ -88,8 +89,17 @@ def get_xla_supported_devices(devkind=None, max_devices=None): Returns: The list of device strings. """ + # TODO(xiowei replace gpu with cuda): Remove the below if statement after r2.2 release. + if devkind and devkind.casefold() == 'gpu': + warnings.warn( + "GPU as a device name is being deprecate. Please replace it with CUDA such as get_xla_supported_devices(devkind='CUDA'). Similarly, please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA." + ) + devkind = 'CUDA' + xla_devices = _DEVICES.value - devkind = [devkind] if devkind else ['TPU', 'GPU', 'XPU', 'NEURON', 'CPU'] + devkind = [devkind] if devkind else [ + 'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM' + ] for kind in devkind: kind_devices = [] for i, device in enumerate(xla_devices): @@ -181,8 +191,8 @@ def xla_device(n=None, devkind=None): n (int, optional): The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise the first device of `devkind` will be returned. - devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU` - `NEURON` or `CPU`. + devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU` + `NEURON`, `ROCM` or `CPU`. Returns: A `torch.device` with the requested instance. @@ -217,7 +227,7 @@ def xla_device_hw(device): real device. Returns: - A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`) + A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`) of the given device. """ real_device = _xla_real_device(device) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 352db2d34fb..b18014ab2df 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -112,16 +112,15 @@ ptxla_cc_library( deps = [ ":aten_cpu_fallback", ":device", + ":dtype", ":einsum_utilities", ":ir", ":ir_builder", ":layout_manager", ":shape_builder", ":shape_helper", - "//torch_xla/csrc/runtime:async_task", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:stablehlo_helper", - "//torch_xla/csrc/runtime:unique", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/hash", "@com_google_absl//absl/memory", @@ -171,6 +170,21 @@ ptxla_cc_library( ], ) +ptxla_cc_library( + name = "dtype", + srcs = ["dtype.cpp"], + hdrs = ["dtype.h"], + deps = [ + "//torch_xla/csrc:device", + "//torch_xla/csrc/runtime:tf_logging", + "//torch_xla/csrc/runtime:debug_macros", + "//torch_xla/csrc/runtime:sys_util", + "//torch_xla/csrc/runtime:util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + ptxla_cc_library( name = "layout_manager", srcs = ["layout_manager.cpp"], @@ -248,20 +262,21 @@ ptxla_cc_library( srcs = ["init_python_bindings.cpp"], deps = [ ":device", + ":dtype", ":tensor", ":version", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:metrics", "//torch_xla/csrc/runtime:metrics_analysis", "//torch_xla/csrc/runtime:metrics_reader", - "//torch_xla/csrc/runtime:multi_wait", "//torch_xla/csrc/runtime:profiler", "//torch_xla/csrc/runtime:sys_util", - "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc/runtime:util", + "//torch_xla/csrc/runtime:xla_coordinator", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:variant", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/lib:traceme_encode", @@ -304,6 +319,16 @@ cc_library( ], ) +cc_library( + name = "thread_pool", + srcs = ["thread_pool.cc"], + hdrs = ["thread_pool.h"], + deps = [ + "//torch_xla/csrc/runtime:sys_util", + "@tsl//tsl/platform:env" + ], +) + ptxla_cc_library( name = "unwrap_data", srcs = ["unwrap_data.cpp"], diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 08c6b27f92c..81cfdfb4f42 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -253,17 +253,5 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, return grad; } -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward))); -} } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/aten_autograd_ops.h b/torch_xla/csrc/aten_autograd_ops.h index be063b76620..d1cc8a98048 100644 --- a/torch_xla/csrc/aten_autograd_ops.h +++ b/torch_xla/csrc/aten_autograd_ops.h @@ -46,6 +46,17 @@ struct MaxPool3dAutogradFunction torch::autograd::variable_list grad_output); }; +torch::Tensor max_pool2d_forward(torch::Tensor self, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride, + torch::IntArrayRef padding, + torch::IntArrayRef dilation, bool ceil_mode); + +torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride, + torch::IntArrayRef padding, bool ceil_mode); + } // namespace aten_autograd_ops } // namespace torch_xla diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b51fbba98cd..466645af7c0 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -24,6 +24,7 @@ #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/as_strided.h" #include "torch_xla/csrc/ops/as_strided_view_update.h" @@ -128,7 +129,7 @@ bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) { // Return the upper limit for a given type. For floating point typesreturn // 2^mantissa to ensure that every value is representable. int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) { - xla::PrimitiveType xla_type = TensorTypeToRawXlaType(dtype); + xla::PrimitiveType xla_type = XlaTypeFromTorchType(dtype); switch (xla_type) { case xla::PrimitiveType::F16: return static_cast(1) << std::numeric_limits::digits; @@ -150,7 +151,7 @@ void CheckRangeValues(torch::ScalarType dtype, int64_t from, int64_t to) { if (IsTypeWithLargerRangeThanLong(dtype)) { min_max = XlaHelpers::MinMaxValues(xla::PrimitiveType::S64); } else { - min_max = XlaHelpers::MinMaxValues(TensorTypeToRawXlaType(dtype)); + min_max = XlaHelpers::MinMaxValues(XlaTypeFromTorchType(dtype)); } XLA_CHECK_GE(from, min_max.min.toLong()); XLA_CHECK_LE(from, min_max.max.toLong()); @@ -252,7 +253,7 @@ void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other, at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::__ilshift__(self_tensor, other); return self; @@ -260,7 +261,7 @@ at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::__ilshift__(self_tensor, bridge::GetXlaTensor(other)); @@ -269,7 +270,7 @@ at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::__irshift__(self_tensor, other); @@ -278,7 +279,7 @@ at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::__irshift__(self_tensor, bridge::GetXlaTensor(other)); @@ -287,7 +288,7 @@ at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, at::Tensor XLANativeFunctions::__lshift__(const at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const at::Scalar& other, at::ScalarType dtype) { @@ -297,7 +298,7 @@ at::Tensor XLANativeFunctions::__lshift__(const at::Tensor& self, at::Tensor XLANativeFunctions::__lshift__(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother, at::ScalarType dtype) { @@ -307,7 +308,7 @@ at::Tensor XLANativeFunctions::__lshift__(const at::Tensor& self, at::Tensor XLANativeFunctions::__rshift__(const at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const at::Scalar& other, at::ScalarType dtype) { @@ -317,7 +318,7 @@ at::Tensor XLANativeFunctions::__rshift__(const at::Tensor& self, at::Tensor XLANativeFunctions::__rshift__(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother, at::ScalarType dtype) { @@ -327,7 +328,7 @@ at::Tensor XLANativeFunctions::__rshift__(const at::Tensor& self, at::Tensor XLANativeFunctions::_adaptive_avg_pool3d( const at::Tensor& self, at::IntArrayRef output_size) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto output_size_list = XlaHelpers::I64List(output_size); if (!IsSupportedAdaptivePool(XlaHelpers::I64List(self.sizes()), output_size_list, /*pool_dim=*/3)) { @@ -346,7 +347,7 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d( at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward( const at::Tensor& grad_output, const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); int64_t rank = grad_output.dim(); std::vector output_size{grad_output.size(rank - 3), grad_output.size(rank - 2), @@ -369,7 +370,7 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward( at::Tensor XLANativeFunctions::_adaptive_avg_pool2d( const at::Tensor& self, at::IntArrayRef output_size) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto output_size_list = XlaHelpers::I64List(output_size); if (!IsSupportedAdaptivePool(XlaHelpers::I64List(self.sizes()), output_size_list, /*pool_dim=*/2)) { @@ -383,7 +384,7 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool2d( at::Tensor XLANativeFunctions::_adaptive_avg_pool2d_backward( const at::Tensor& grad_output, const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); int64_t rank = grad_output.dim(); std::vector output_size{grad_output.size(rank - 2), grad_output.size(rank - 1)}; @@ -400,7 +401,7 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool2d_backward( std::tuple XLANativeFunctions::adaptive_max_pool2d( const at::Tensor& self, at::IntArrayRef output_size) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto output_size_list = XlaHelpers::I64List(output_size); if (!IsSupportedAdaptivePool(XlaHelpers::I64List(self.sizes()), output_size_list, /*pool_dim=*/2)) { @@ -418,7 +419,7 @@ std::tuple XLANativeFunctions::adaptive_max_pool2d( at::Tensor XLANativeFunctions::adaptive_max_pool2d_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& indices) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); int64_t rank = grad_output.dim(); std::vector output_size{grad_output.size(rank - 2), grad_output.size(rank - 1)}; @@ -435,7 +436,7 @@ at::Tensor XLANativeFunctions::adaptive_max_pool2d_backward( void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf); tensor_methods::_amp_foreach_non_finite_check_and_unscale_( bridge::GetXlaTensors(self), found_inf_tensor, @@ -448,7 +449,7 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker); XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale); tensor_methods::_amp_update_scale_( @@ -461,7 +462,7 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, const at::Tensor& dst, bool /*non_blocking*/) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto dst_tensor = bridge::TryGetXlaTensor(dst); auto self_tensor = bridge::TryGetXlaTensor(self); if (!self_tensor) { @@ -484,7 +485,7 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, at::Tensor XLANativeFunctions::_copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto dst_tensor = bridge::TryGetXlaTensor(dst); auto self_tensor = bridge::TryGetXlaTensor(self); if (!self_tensor) { @@ -506,7 +507,7 @@ at::Tensor XLANativeFunctions::_copy_from_and_resize(const at::Tensor& self, } std::vector XLANativeFunctions::_to_cpu(at::TensorList tensors) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::XlaCreateTensorList(tensors); } @@ -517,7 +518,7 @@ at::Tensor XLANativeFunctions::_to_copy( c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto options = self.options(); // I put each of these setters in a conditional instead of doing @@ -561,14 +562,14 @@ at::Tensor XLANativeFunctions::_to_copy( at::Tensor& XLANativeFunctions::_index_put_impl_( at::Tensor& self, const c10::List>& indices, const at::Tensor& values, bool accumulate, bool /* unsafe */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return torch_xla::XLANativeFunctions::index_put_(self, indices, values, accumulate); } std::tuple XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto outputs = tensor_methods::slogdet(self_tensor); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), @@ -579,7 +580,7 @@ XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) { at::Tensor XLANativeFunctions::_log_softmax(const at::Tensor& self, int64_t dim, bool half_to_float) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto self_meta = to_meta(self); auto out_meta = at::meta::_log_softmax(self_meta, dim, half_to_float); @@ -592,14 +593,14 @@ at::Tensor XLANativeFunctions::_log_softmax(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::_log_softmax_backward_data( const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, at::ScalarType /* input_dtype */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::log_softmax_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim)); } std::tuple XLANativeFunctions::_pack_padded_sequence( const at::Tensor& input, const at::Tensor& lengths, bool batch_first) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::vector xla_tensors = {lengths}; auto cpu_tensors = bridge::XlaCreateTensorList(xla_tensors); return at::native::_pack_padded_sequence(input, cpu_tensors[0], batch_first); @@ -607,7 +608,7 @@ std::tuple XLANativeFunctions::_pack_padded_sequence( at::Tensor XLANativeFunctions::_softmax(const at::Tensor& self, int64_t dim, bool /* half_to_float */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::softmax(bridge::GetXlaTensor(self), dim, c10::nullopt)); } @@ -615,21 +616,21 @@ at::Tensor XLANativeFunctions::_softmax(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::_softmax_backward_data( const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, at::ScalarType input_dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::softmax_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim)); } at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, at::IntArrayRef size) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Currently, we disallow the case when both operands contain dynamic // dimensions. This is consistent with PyTorch's behavior. XLA_CHECK(!(tensor_has_dym_dim(self) && tensor_has_dym_dim(other))) @@ -648,7 +649,7 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self, at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const at::Scalar& other, at::ScalarType dtype) { @@ -661,7 +662,7 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self, const at::Tensor& mat2, const at::Scalar& beta, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (beta.to() != 1 || alpha.to() != 1) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(addmm)>::call(self, mat1, mat2, @@ -674,7 +675,7 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self, } at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::alias(bridge::GetXlaTensor(self))); } @@ -683,7 +684,7 @@ at::Tensor& XLANativeFunctions::arange_out(const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, at::Tensor& out) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr out_tensor = bridge::GetXlaTensor(out); tensor_methods::arange_out(out_tensor, start, end, step, out.scalar_type()); return out; @@ -692,7 +693,7 @@ at::Tensor& XLANativeFunctions::arange_out(const at::Scalar& start, at::Tensor XLANativeFunctions::as_strided_copy( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); @@ -711,7 +712,7 @@ at::Tensor XLANativeFunctions::as_strided_scatter( const at::Tensor& base, const at::Tensor& mutated_view, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto base_ = bridge::GetXlaTensor(base); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); @@ -733,7 +734,7 @@ at::Tensor XLANativeFunctions::as_strided_scatter( at::Tensor XLANativeFunctions::atan2(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // xla::Atan2 doesn't support integer types. if (!self.is_floating_point() || !other.is_floating_point()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -754,7 +755,7 @@ at::Tensor XLANativeFunctions::avg_pool2d( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional divisor_override) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if ((ceil_mode && count_include_pad) || divisor_override) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP(avg_pool2d)>::call(self, kernel_size, stride, @@ -773,7 +774,7 @@ at::Tensor XLANativeFunctions::avg_pool2d_backward( at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional divisor_override) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if ((ceil_mode && count_include_pad) || divisor_override) { return at::native:: call_fallback_fn<&xla_cpu_fallback, ATEN_OP(avg_pool2d_backward)>::call( @@ -791,7 +792,7 @@ at::Tensor XLANativeFunctions::avg_pool3d( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional divisor_override) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if ((ceil_mode && count_include_pad) || divisor_override) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP(avg_pool3d)>::call(self, kernel_size, stride, @@ -810,7 +811,7 @@ at::Tensor XLANativeFunctions::avg_pool3d_backward( at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional divisor_override) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if ((ceil_mode && count_include_pad) || divisor_override) { return at::native:: call_fallback_fn<&xla_cpu_fallback, ATEN_OP(avg_pool3d_backward)>::call( @@ -829,7 +830,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, const at::Tensor& batch2, const at::Scalar& beta, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::baddbmm( bridge::GetXlaTensor(self), bridge::GetXlaTensor(batch1), @@ -838,7 +839,7 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, at::Tensor XLANativeFunctions::bernoulli( const at::Tensor& self, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(bernoulli)>::call(self, @@ -850,7 +851,7 @@ at::Tensor XLANativeFunctions::bernoulli( at::Tensor XLANativeFunctions::bernoulli( const at::Tensor& self, double p, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP2(bernoulli, p)>::call(self, p, generator); @@ -862,7 +863,7 @@ at::Tensor XLANativeFunctions::bernoulli( at::Tensor& XLANativeFunctions::bernoulli_( at::Tensor& self, const at::Tensor& p, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP2(bernoulli_, Tensor)>::call(self, p, @@ -877,7 +878,7 @@ at::Tensor XLANativeFunctions::binary_cross_entropy_with_logits( const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return at::native::binary_cross_entropy_with_logits( self, target, IsDefined(weight) ? *weight : at::Tensor(), IsDefined(pos_weight) ? *pos_weight : at::Tensor(), reduction); @@ -885,7 +886,7 @@ at::Tensor XLANativeFunctions::binary_cross_entropy_with_logits( at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOpWithoutPromo( self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& other) { return tensor_methods::bitwise_and(xself, other); @@ -894,7 +895,7 @@ at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor& self, at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOpWithoutPromo( self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother) { return tensor_methods::bitwise_or(xself, xother); @@ -903,7 +904,7 @@ at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor& self, at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOpWithoutPromo( self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother) { return tensor_methods::bitwise_xor(xself, xother); @@ -912,28 +913,28 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self, at::Tensor XLANativeFunctions::bmm(const at::Tensor& self, const at::Tensor& mat2) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::bmm( bridge::GetXlaTensor(self), bridge::GetXlaTensor(mat2))); } at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::cat( bridge::GetXlaTensors(tensors), dim, at::native::result_type(tensors))); } at::Tensor XLANativeFunctions::celu(const at::Tensor& self, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::celu(bridge::GetXlaTensor(self), alpha)); } at::Tensor& XLANativeFunctions::celu_(at::Tensor& self, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::celu_(self_tensor, alpha); return self; @@ -942,21 +943,21 @@ at::Tensor& XLANativeFunctions::celu_(at::Tensor& self, at::Tensor XLANativeFunctions::clamp(const at::Tensor& self, const c10::optional& min, const c10::optional& max) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::clamp(bridge::GetXlaTensor(self), min, max)); } at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self, const at::Scalar& max) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::clamp(bridge::GetXlaTensor(self), c10::nullopt, max)); } at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, const at::Scalar& min) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::clamp(bridge::GetXlaTensor(self), min, c10::nullopt)); } @@ -964,7 +965,7 @@ at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, at::Tensor XLANativeFunctions::clone( const at::Tensor& self, c10::optional /* memory_format */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::clone(bridge::GetXlaTensor(self))); } @@ -972,7 +973,7 @@ at::Tensor XLANativeFunctions::clone( at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self, at::IntArrayRef pad, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::constant_pad_nd( bridge::GetXlaTensor(self), XlaHelpers::I64List(pad), value)); } @@ -983,7 +984,7 @@ at::Tensor XLANativeFunctions::convolution_overrideable( const c10::optional& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (IsDefined(bias)) { return bridge::AtenFromXlaTensor(tensor_methods::convolution_overrideable( bridge::GetXlaTensor(input), bridge::GetXlaTensor(weight), @@ -1006,7 +1007,7 @@ XLANativeFunctions::convolution_backward_overrideable( const at::Tensor& weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, std::array output_mask) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto gradients = tensor_methods::convolution_backward_overrideable( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(input), bridge::GetXlaTensor(weight), XlaHelpers::I64List(stride), @@ -1023,13 +1024,13 @@ XLANativeFunctions::convolution_backward_overrideable( at::Tensor XLANativeFunctions::copy(const at::Tensor& self, const at::Tensor& src, bool non_blocking) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return _copy_from(src, self, non_blocking); } at::Tensor& XLANativeFunctions::copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); _copy_from(src, self, non_blocking); return self; } @@ -1037,7 +1038,7 @@ at::Tensor& XLANativeFunctions::copy_(at::Tensor& self, const at::Tensor& src, at::Tensor XLANativeFunctions::cross(const at::Tensor& self, const at::Tensor& other, c10::optional dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::cross( bridge::GetXlaTensor(self), bridge::GetXlaTensor(other), XlaHelpers::I64Optional(dim))); @@ -1045,7 +1046,7 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self, at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); c10::optional promoted_dtype = PromoteIntegralType(self_tensor->dtype(), dtype); @@ -1062,7 +1063,7 @@ at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::cumsum(const at::Tensor& self, int64_t dim, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); if (IsOperationOnType(dtype, self_tensor->dtype(), at::ScalarType::Long)) { // XLA reduce-window does not support S64 mode. @@ -1076,12 +1077,12 @@ at::Tensor XLANativeFunctions::cumsum(const at::Tensor& self, int64_t dim, // TODO(alanwaketan): Let's rewrite a without reusing other native functions. at::Tensor XLANativeFunctions::detach_copy(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(bridge::GetXlaTensor(self)); } at::Tensor XLANativeFunctions::diag(const at::Tensor& self, int64_t diagonal) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::diag(bridge::GetXlaTensor(self), diagonal)); } @@ -1089,7 +1090,7 @@ at::Tensor XLANativeFunctions::diag(const at::Tensor& self, int64_t diagonal) { at::Tensor XLANativeFunctions::diagonal_copy(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::diagonal(bridge::GetXlaTensor(self), offset, dim1, dim2)); } @@ -1115,7 +1116,7 @@ at::Tensor XLANativeFunctions::div(const at::Tensor& self, at::Tensor XLANativeFunctions::div( const at::Tensor& self, const at::Tensor& other, c10::optional rounding_mode) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); at::ScalarType dtype = at::result_type(self, other); auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype)); return bridge::AtenFromXlaTensor(tensor_methods::div( @@ -1124,14 +1125,14 @@ at::Tensor XLANativeFunctions::div( at::Tensor XLANativeFunctions::div(const at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::div(bridge::GetXlaTensor(self), other)); } at::Tensor XLANativeFunctions::dot(const at::Tensor& self, const at::Tensor& tensor) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK_EQ(self.dim(), 1) << "dot: Expected 1-D argument self, but got " << self.dim() << "-D"; XLA_CHECK_EQ(tensor.dim(), 1) @@ -1159,7 +1160,7 @@ at::Tensor XLANativeFunctions::einsum(c10::string_view equation, } } - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Einsum operations with more than 2 operands, like bilinear operations, are // not currently supported in XLA if (tensors.size() < 1 || tensors.size() > 2 || !all_xla_tensors_are_valid || @@ -1178,7 +1179,7 @@ at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output, const at::Scalar& input_scale, bool self, const at::Tensor& self_or_result) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(!self || alpha.to() >= 0.0) << "In-place elu backward calculation is triggered with a negative slope " "which is not supported."; @@ -1190,7 +1191,7 @@ at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output, at::Tensor XLANativeFunctions::embedding_dense_backward( const at::Tensor& grad_output, const at::Tensor& indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::embedding_dense_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(indices), num_weights, padding_idx, scale_grad_by_freq)); @@ -1201,7 +1202,7 @@ at::Tensor XLANativeFunctions::empty_symint( c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional /* memory_format */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); c10::optional int_sizes = c10::asIntArrayRefSlowOpt(sym_size); bool all_dims_static = int_sizes.has_value(); @@ -1223,7 +1224,7 @@ at::Tensor XLANativeFunctions::empty_strided_symint( at::SymIntArrayRef sym_size, at::SymIntArrayRef sym_stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); c10::optional size = c10::asIntArrayRefSlowOpt(sym_size); bool is_size_dynamic = !size.has_value(); c10::optional stride = c10::asIntArrayRefSlowOpt(sym_stride); @@ -1239,7 +1240,7 @@ at::Tensor XLANativeFunctions::empty_strided_symint( at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, at::SymIntArrayRef sym_size, bool implicit) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); c10::optional size = c10::asIntArrayRefSlowOpt(sym_size); if (size.has_value()) { return bridge::AtenFromXlaTensor(tensor_methods::expand( @@ -1254,7 +1255,7 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, at::Tensor& XLANativeFunctions::exponential_( at::Tensor& self, double lambd, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(exponential_)>::call(self, @@ -1268,14 +1269,14 @@ at::Tensor& XLANativeFunctions::exponential_( } at::Tensor& XLANativeFunctions::eye_out(int64_t n, at::Tensor& out) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr out_tensor = bridge::GetXlaTensor(out); tensor_methods::eye_out(out_tensor, n, n); return out; } at::Tensor& XLANativeFunctions::eye_out(int64_t n, int64_t m, at::Tensor& out) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr out_tensor = bridge::GetXlaTensor(out); tensor_methods::eye_out(out_tensor, n, m); return out; @@ -1283,7 +1284,7 @@ at::Tensor& XLANativeFunctions::eye_out(int64_t n, int64_t m, at::Tensor& out) { at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::fill_(self_tensor, value); return self; @@ -1291,7 +1292,7 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, const at::Tensor& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK_EQ(value.dim(), 0) << "fill_ only supports a 0-dimensional " << "value tensor, but got tensor " << "with " << value.dim() << " dimension(s)."; @@ -1300,7 +1301,7 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, at::Tensor XLANativeFunctions::flip(const at::Tensor& self, at::IntArrayRef dims) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::flip( bridge::GetXlaTensor(self), XlaHelpers::I64List(dims))); } @@ -1313,7 +1314,7 @@ at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self, at::Tensor XLANativeFunctions::fmod(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother, at::ScalarType dtype) { @@ -1323,7 +1324,7 @@ at::Tensor XLANativeFunctions::fmod(const at::Tensor& self, at::Tensor XLANativeFunctions::fmod(const at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const at::Scalar& other, at::ScalarType dtype) { @@ -1331,17 +1332,35 @@ at::Tensor XLANativeFunctions::fmod(const at::Tensor& self, }); } +at::Tensor XLANativeFunctions::full(at::IntArrayRef size, + const at::Scalar& fill_value, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + TORCH_LAZY_FN_COUNTER("xla::"); + // Fall back to CPU if layout or pin_memory are not default + if (layout.value_or(at::Layout::Strided) != at::Layout::Strided || + pin_memory.value_or(false)) { + return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call( + size, fill_value, dtype, layout, device, pin_memory); + } + return bridge::AtenFromXlaTensor(tensor_methods::full( + absl::Span(size), fill_value, + GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype))); +} + at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, const at::Tensor& index, bool /* sparse_grad */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::gather( bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index))); } at::Tensor XLANativeFunctions::gelu(const at::Tensor& self, c10::string_view approximate) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::gelu(bridge::GetXlaTensor(self), approximate)); } @@ -1349,7 +1368,7 @@ at::Tensor XLANativeFunctions::gelu(const at::Tensor& self, at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad, const at::Tensor& self, c10::string_view approximate) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::gelu_backward( bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate)); } @@ -1357,7 +1376,7 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad, at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::clamp(bridge::GetXlaTensor(self), min_val, max_val)); } @@ -1366,7 +1385,7 @@ at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::hardtanh_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), min_val, max_val)); @@ -1375,7 +1394,7 @@ at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, at::Tensor XLANativeFunctions::index( const at::Tensor& self, const c10::List>& indices) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); bool indices_on_cpu_or_xla = std::all_of(indices.begin(), indices.end(), [=](const c10::optional& opt) { @@ -1405,7 +1424,7 @@ at::Tensor XLANativeFunctions::index_add(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& source, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::index_add( bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index), bridge::GetXlaTensor(source), alpha)); @@ -1414,7 +1433,7 @@ at::Tensor XLANativeFunctions::index_add(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::index_copy(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& source) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor( tensor_methods::index_copy(self_tensor, dim, bridge::GetXlaTensor(index), @@ -1424,7 +1443,7 @@ at::Tensor XLANativeFunctions::index_copy(const at::Tensor& self, int64_t dim, at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::index_fill_(self_tensor, dim, bridge::GetXlaTensor(index), value); @@ -1434,7 +1453,7 @@ at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::index_fill_(self_tensor, dim, bridge::GetXlaTensor(index), bridge::GetXlaTensor(value)); @@ -1444,7 +1463,7 @@ at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, at::Tensor& XLANativeFunctions::index_put_( at::Tensor& self, const c10::List>& indices, const at::Tensor& values, bool accumulate) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); bool indices_on_cpu_or_xla = std::all_of(indices.begin(), indices.end(), [=](const c10::optional& opt) { @@ -1478,7 +1497,7 @@ at::Tensor& XLANativeFunctions::index_put_( at::Tensor XLANativeFunctions::index_select(const at::Tensor& self, int64_t dim, const at::Tensor& index) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::index_select( bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index))); } @@ -1486,13 +1505,13 @@ at::Tensor XLANativeFunctions::index_select(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::kl_div(const at::Tensor& self, const at::Tensor& target, int64_t reduction, bool log_target) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return at::native::kl_div(self, target, reduction, log_target); } std::tuple XLANativeFunctions::kthvalue( const at::Tensor& self, int64_t k, int64_t dim, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::kthvalue(bridge::GetXlaTensor(self), k, dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), @@ -1502,7 +1521,7 @@ std::tuple XLANativeFunctions::kthvalue( at::Tensor XLANativeFunctions::leaky_relu_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& negative_slope, bool self_is_result) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(!self_is_result || negative_slope.to() >= 0.0); auto common_device = torch_xla::bridge::GetXlaDevice(self); XLA_CHECK(common_device); @@ -1520,7 +1539,7 @@ at::Tensor XLANativeFunctions::leaky_relu_backward( at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, const at::Tensor& end, const at::Tensor& weight) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK_EQ(self.dtype(), end.dtype()) << "expected dtype " << self.dtype() << " for `end` but got dtype " << end.dtype(); @@ -1535,7 +1554,7 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, const at::Tensor& end, const at::Scalar& weight) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK_EQ(self.dtype(), end.dtype()) << "expected dtype " << self.dtype() << " for `end` but got dtype " << end.dtype(); @@ -1544,14 +1563,14 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, } at::Tensor XLANativeFunctions::lift(const at::Tensor& tensor) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return MaybeWrapTensorToFunctional(tensor); } at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& tensor) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return MaybeWrapTensorToFunctional(tensor); @@ -1559,7 +1578,7 @@ at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& tensor) { std::tuple XLANativeFunctions::linalg_inv_ex( const at::Tensor& self, bool check_errors) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // The default value for `check_errors` is False. And for now, we don't // do anything differently based on this flag. So when it's set to True, // we'll fallback to CPU. @@ -1584,7 +1603,7 @@ at::Tensor XLANativeFunctions::linspace(const at::Scalar& start, c10::optional layout, c10::optional device, c10::optional pin_memory) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Fall back to CPU if layout or pin_memory are not default if (layout.value_or(at::Layout::Strided) != at::Layout::Strided || pin_memory.value_or(false)) { @@ -1601,32 +1620,32 @@ at::Tensor XLANativeFunctions::linspace(const at::Scalar& start, } at::Tensor XLANativeFunctions::log(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::log(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::log10(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::log_base( bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log10), 10.0)); } at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::log1p(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::log2(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::log_base( bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log2), 2.0)); } at::Tensor XLANativeFunctions::logsumexp(const at::Tensor& self, at::IntArrayRef dim, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::logsumexp( bridge::GetXlaTensor(self), torch::lazy::ToVector(dim), /*keep_reduced_dimensions=*/keepdim)); @@ -1634,7 +1653,7 @@ at::Tensor XLANativeFunctions::logsumexp(const at::Tensor& self, at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::xlogy( bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } @@ -1642,7 +1661,7 @@ at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self, at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self, const at::Tensor& mask, const at::Tensor& source) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::masked_scatter( self_tensor, bridge::GetXlaTensor(mask), bridge::GetXlaTensor(source))); @@ -1650,7 +1669,7 @@ at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self, at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, const at::Tensor& mask) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); // Initially make XLA handled masked_select() handling experimental, and // opt-in. @@ -1664,14 +1683,14 @@ at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, } at::Tensor XLANativeFunctions::max(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::max(bridge::GetXlaTensor(self))); } std::tuple XLANativeFunctions::max( const at::Tensor& self, int64_t dim, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto outputs = tensor_methods::max(bridge::GetXlaTensor(self), dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); @@ -1680,7 +1699,7 @@ std::tuple XLANativeFunctions::max( std::tuple XLANativeFunctions::max_out( const at::Tensor& self, int64_t dim, bool keepdim, at::Tensor& max, at::Tensor& max_values) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr max_tensor = bridge::GetXlaTensor(max); XLATensorPtr max_values_tensor = bridge::GetXlaTensor(max_values); tensor_methods::max_out(max_tensor, max_values_tensor, @@ -1691,7 +1710,7 @@ std::tuple XLANativeFunctions::max_out( at::Tensor XLANativeFunctions::max_pool2d( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return aten_autograd_ops::MaxPool2dAutogradFunction::apply( self, kernel_size, stride, padding, dilation, ceil_mode); } @@ -1699,7 +1718,7 @@ at::Tensor XLANativeFunctions::max_pool2d( std::tuple XLANativeFunctions::max_pool2d_with_indices( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Lowering when ceil_mode or dilation is set not supported yet. if (IsNonTrivialDilation(dilation)) { return at::native::call_fallback_fn< @@ -1723,7 +1742,7 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward( at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor& indices) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Lowering when ceil_mode or dilation is set not supported yet. if (IsNonTrivialDilation(dilation)) { return at::native::call_fallback_fn< @@ -1742,7 +1761,7 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward( at::Tensor XLANativeFunctions::max_pool3d( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return aten_autograd_ops::MaxPool3dAutogradFunction::apply( self, kernel_size, stride, padding, dilation, ceil_mode); } @@ -1752,7 +1771,7 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward( at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor& indices) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Lowering when ceil_mode or dilation is set not supported yet. if (IsNonTrivialDilation(dilation)) { return at::native::call_fallback_fn< @@ -1771,7 +1790,7 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward( std::tuple XLANativeFunctions::max_pool3d_with_indices( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Lowering when ceil_mode or dilation is set not supported yet. if (IsNonTrivialDilation(dilation)) { return at::native::call_fallback_fn< @@ -1793,7 +1812,7 @@ std::tuple XLANativeFunctions::max_pool3d_with_indices( at::Tensor XLANativeFunctions::max_unpool2d(const at::Tensor& self, const at::Tensor& indices, at::IntArrayRef output_size) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::max_unpool( bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices), torch::lazy::ToVector(output_size))); @@ -1804,7 +1823,7 @@ at::Tensor XLANativeFunctions::max_unpool3d(const at::Tensor& self, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::max_unpool( bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices), torch::lazy::ToVector(output_size))); @@ -1812,7 +1831,7 @@ at::Tensor XLANativeFunctions::max_unpool3d(const at::Tensor& self, at::Tensor XLANativeFunctions::mean(const at::Tensor& self, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::mean( self_tensor, @@ -1823,7 +1842,7 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::mean( self_tensor, @@ -1833,21 +1852,21 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self, } at::Tensor XLANativeFunctions::min(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::min(bridge::GetXlaTensor(self))); } std::tuple XLANativeFunctions::min( const at::Tensor& self, int64_t dim, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto outputs = tensor_methods::min(bridge::GetXlaTensor(self), dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } at::Tensor XLANativeFunctions::mish(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::mish(bridge::GetXlaTensor(self))); } @@ -1855,7 +1874,7 @@ at::Tensor XLANativeFunctions::mish(const at::Tensor& self) { std::tuple XLANativeFunctions::min_out( const at::Tensor& self, int64_t dim, bool keepdim, at::Tensor& min, at::Tensor& min_indices) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr min_tensor = bridge::GetXlaTensor(min); XLATensorPtr min_indices_tensor = bridge::GetXlaTensor(min_indices); tensor_methods::min_out(min_tensor, min_indices_tensor, @@ -1865,7 +1884,7 @@ std::tuple XLANativeFunctions::min_out( at::Tensor XLANativeFunctions::mm(const at::Tensor& self, const at::Tensor& mat2) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::mm(/*input=*/bridge::GetXlaTensor(self), /*weight=*/bridge::GetXlaTensor(mat2))); @@ -1874,7 +1893,7 @@ at::Tensor XLANativeFunctions::mm(const at::Tensor& self, at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self, const at::Tensor& target, int64_t reduction) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::mse_loss( bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction)); } @@ -1883,7 +1902,7 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, int64_t reduction) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::mse_loss_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction)); @@ -1891,7 +1910,7 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, at::Tensor XLANativeFunctions::mul(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother, at::ScalarType dtype) { @@ -1901,7 +1920,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, at::Tensor XLANativeFunctions::mul(const at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const at::Scalar& other, at::ScalarType dtype) { @@ -1916,7 +1935,7 @@ at::Tensor XLANativeFunctions::multinomial( << "Multinomial number of samples must be greater than 0"; XLA_CHECK(at::isFloatingType(self.scalar_type())) << "Multinomial input must be a floating type"; - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Fallback when sampling is not replaced because it is challenging to // parallelize. See https://github.com/pytorch/xla/issues/4865 if ((generator.has_value() && generator->defined()) || @@ -1934,14 +1953,14 @@ at::Tensor XLANativeFunctions::multinomial( at::Tensor XLANativeFunctions::mv(const at::Tensor& self, const at::Tensor& vec) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::mv( bridge::GetXlaTensor(self), bridge::GetXlaTensor(vec))); } at::Tensor& XLANativeFunctions::mv_out(const at::Tensor& self, const at::Tensor& vec, at::Tensor& out) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr out_tensor = bridge::GetXlaTensor(out); tensor_methods::mv_out(out_tensor, bridge::GetXlaTensor(self), bridge::GetXlaTensor(vec)); @@ -1952,7 +1971,7 @@ at::Tensor XLANativeFunctions::nan_to_num(const at::Tensor& self, c10::optional nan, c10::optional posinf, c10::optional neginf) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // nan_to_num doesn't apply to integer types. if (!at::native::is_floating_point(self)) { return torch::lazy::CopyTensor(self); @@ -1983,7 +2002,7 @@ XLANativeFunctions::native_batch_norm( const c10::optional& running_mean, const c10::optional& running_var, bool training, double momentum, double eps) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr input_tensor = bridge::GetXlaTensor(input); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); XLATensorPtr running_mean_tensor = @@ -2004,7 +2023,7 @@ XLANativeFunctions::_native_batch_norm_legit( const at::Tensor& input, const c10::optional& weight, const c10::optional& bias, at::Tensor& running_mean, at::Tensor& running_var, bool training, double momentum, double eps) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr input_tensor = bridge::GetXlaTensor(input); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); XLATensorPtr running_mean_tensor = bridge::GetXlaTensor(running_mean); @@ -2023,7 +2042,7 @@ XLANativeFunctions::_native_batch_norm_legit( const at::Tensor& input, const c10::optional& weight, const c10::optional& bias, bool training, double momentum, double eps) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr input_tensor = bridge::GetXlaTensor(input); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); XLATensorPtr null_running_mean_tensor = XLATensorPtr(); @@ -2046,7 +2065,7 @@ XLANativeFunctions::native_batch_norm_backward( const c10::optional& save_mean, const c10::optional& save_invstd, bool train, double eps, std::array output_mask) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr grad_out_tensor = bridge::GetXlaTensor(grad_out); const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice(); auto gradients = tensor_methods::native_batch_norm_backward( @@ -2066,7 +2085,7 @@ XLANativeFunctions::native_batch_norm_backward( std::tuple XLANativeFunctions::native_dropout( const at::Tensor& self, double p, c10::optional train) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto results = tensor_methods::native_dropout(self_tensor, p, train); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), @@ -2074,7 +2093,7 @@ std::tuple XLANativeFunctions::native_dropout( } at::Tensor XLANativeFunctions::neg(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(self.scalar_type() != at::kBool) << "Negation, the `-` operator, on a bool tensor is not supported. If " "you are trying to invert a mask, use the `~` or `logical_not()` " @@ -2087,7 +2106,7 @@ at::Tensor XLANativeFunctions::nll_loss2d_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); XLATensorPtr weight_tensor = bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()); @@ -2106,7 +2125,7 @@ std::tuple XLANativeFunctions::nll_loss2d_forward( const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, int64_t reduction, int64_t ignore_index) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); XLATensorPtr total_weight = tensor_methods::full( {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); @@ -2122,7 +2141,7 @@ at::Tensor XLANativeFunctions::nll_loss_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); XLATensorPtr weight_tensor = bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()); @@ -2141,7 +2160,7 @@ std::tuple XLANativeFunctions::nll_loss_forward( const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, int64_t reduction, int64_t ignore_index) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); XLATensorPtr total_weight = tensor_methods::full( {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); @@ -2154,7 +2173,7 @@ std::tuple XLANativeFunctions::nll_loss_forward( } at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); // Initially make XLA handled nonzero() handling experimental, and opt-in. if (!DebugUtil::ExperimentEnabled("nonzero")) { @@ -2167,7 +2186,7 @@ at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) { at::Tensor XLANativeFunctions::norm(const at::Tensor& self, const c10::optional& p, at::ScalarType dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // If p==0 it is a torch.nonzero(), which is not lowered to XLA due to dynamic // shapes issue. if (p.has_value() && p->toDouble() == 0) { @@ -2181,7 +2200,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, at::Tensor XLANativeFunctions::norm(const at::Tensor& self, const at::Scalar& p) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // If p==0 it is a torch.nonzero(), which is not lowered to XLA due to dynamic // shapes issue. if (p.toDouble() == 0) { @@ -2196,7 +2215,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, const c10::optional& p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // If p==0 it is a torch.nonzero(), which is not lowered to XLA due to dynamic // shapes issue. if (p.has_value() && p->toDouble() == 0) { @@ -2213,7 +2232,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, at::Tensor XLANativeFunctions::norm(const at::Tensor& self, const c10::optional& p, at::IntArrayRef dim, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // If p==0 it is a torch.nonzero(), which is not lowered to XLA due to dynamic // shapes issue. if (p.has_value() && p->toDouble() == 0) { @@ -2227,7 +2246,7 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP2(normal, Tensor_float)>::call(mean, std, @@ -2239,7 +2258,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP2(normal, float_Tensor)>::call(mean, std, @@ -2252,7 +2271,7 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, const at::Tensor& std, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP2(normal, Tensor_Tensor)>::call(mean, std, @@ -2265,7 +2284,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, at::Tensor& XLANativeFunctions::normal_( at::Tensor& self, double mean, double std, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(normal_)>::call(self, mean, std, @@ -2278,14 +2297,14 @@ at::Tensor& XLANativeFunctions::normal_( at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self, at::IntArrayRef dims) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::permute( bridge::GetXlaTensor(self), XlaHelpers::I64List(dims))); } at::Tensor XLANativeFunctions::pow(const at::Tensor& self, const at::Scalar& exponent) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // xla::Pow() doesn't support integer types. if (!at::native::is_floating_point(self)) { return at::native::call_fallback_fn< @@ -2297,7 +2316,7 @@ at::Tensor XLANativeFunctions::pow(const at::Tensor& self, at::Tensor XLANativeFunctions::pow(const at::Tensor& self, const at::Tensor& exponent) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // xla::Pow() doesn't support integer types. if (!at::native::is_floating_point(self)) { return at::native::call_fallback_fn< @@ -2309,7 +2328,7 @@ at::Tensor XLANativeFunctions::pow(const at::Tensor& self, at::Tensor XLANativeFunctions::pow(const at::Scalar& self, const at::Tensor& exponent) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // xla::Pow() doesn't support integer types. if (!self.isFloatingPoint()) { return at::native::call_fallback_fn<&xla_cpu_fallback, @@ -2322,7 +2341,7 @@ at::Tensor XLANativeFunctions::pow(const at::Scalar& self, at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self, const at::Tensor& weight) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // If multiple weights, check channel size == number of weights. int64_t weight_num = weight.numel(); if (weight.numel() > 1) { @@ -2343,9 +2362,24 @@ at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self, tensor_methods::prelu(self_tensor, weight_tensor)); } +std::tuple XLANativeFunctions::_prelu_kernel_backward( + const at::Tensor& grad_output, const at::Tensor& self, + const at::Tensor& weight) { + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + + XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); + XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr weight_tensor = bridge::GetXlaTensor(weight); + + auto outputs = tensor_methods::prelu_backward(grad_output_tensor, self_tensor, + weight_tensor); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), + bridge::AtenFromXlaTensor(std::get<1>(outputs))); +} + at::Tensor XLANativeFunctions::prod(const at::Tensor& self, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::prod( self_tensor, @@ -2357,7 +2391,7 @@ at::Tensor XLANativeFunctions::prod(const at::Tensor& self, at::Tensor XLANativeFunctions::prod(const at::Tensor& self, int64_t dim, bool keepdim, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::prod(bridge::GetXlaTensor(self), {dim}, keepdim, PromoteIntegralType(self.scalar_type(), dtype))); @@ -2365,7 +2399,7 @@ at::Tensor XLANativeFunctions::prod(const at::Tensor& self, int64_t dim, void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, const at::Tensor& output) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // This op is only called when functionalize pass is transforming an in-place // op. Therefore, we can populate some meta data to maintain any optimization // for in-place ops we have in hands. @@ -2393,7 +2427,7 @@ at::Tensor& XLANativeFunctions::put_(at::Tensor& self, const at::Tensor& index, std::tuple XLANativeFunctions::qr( const at::Tensor& self, bool some) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::qr(bridge::GetXlaTensor(self), some); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); @@ -2403,7 +2437,7 @@ std::tuple XLANativeFunctions::qr( at::Tensor& XLANativeFunctions::random_( at::Tensor& self, int64_t from, c10::optional to, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn< &xla_cpu_fallback, ATEN_OP2(random_, from)>::call(self, from, to, @@ -2423,7 +2457,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (0, to]. at::Tensor& XLANativeFunctions::random_( at::Tensor& self, int64_t to, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP2(random_, to)>::call(self, to, @@ -2439,7 +2473,7 @@ at::Tensor& XLANativeFunctions::random_( // The value generated should be in (self_type_min, self_type_max). at::Tensor& XLANativeFunctions::random_( at::Tensor& self, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(random_)>::call(self, @@ -2456,7 +2490,7 @@ at::Tensor& XLANativeFunctions::random_( at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad2d( bridge::GetXlaTensor(self), torch::lazy::ToVector(padding))); } @@ -2464,7 +2498,7 @@ at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self, at::Tensor XLANativeFunctions::reflection_pad2d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad2d_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), torch::lazy::ToVector(padding))); @@ -2472,21 +2506,21 @@ at::Tensor XLANativeFunctions::reflection_pad2d_backward( at::Tensor XLANativeFunctions::remainder(const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::remainder( bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } at::Tensor XLANativeFunctions::remainder(const at::Tensor& self, const at::Scalar& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::remainder(bridge::GetXlaTensor(self), other)); } at::Tensor XLANativeFunctions::replication_pad1d(const at::Tensor& self, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad1d( bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); } @@ -2494,7 +2528,7 @@ at::Tensor XLANativeFunctions::replication_pad1d(const at::Tensor& self, at::Tensor XLANativeFunctions::replication_pad1d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad1d_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); @@ -2502,7 +2536,7 @@ at::Tensor XLANativeFunctions::replication_pad1d_backward( at::Tensor XLANativeFunctions::replication_pad2d(const at::Tensor& self, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad2d( bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); } @@ -2510,7 +2544,7 @@ at::Tensor XLANativeFunctions::replication_pad2d(const at::Tensor& self, at::Tensor XLANativeFunctions::replication_pad2d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad2d_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); @@ -2519,7 +2553,7 @@ at::Tensor XLANativeFunctions::replication_pad2d_backward( const at::Tensor& XLANativeFunctions::resize_( const at::Tensor& self, at::IntArrayRef size, c10::optional /* memory_format */) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::resize_(self_tensor, XlaHelpers::I64List(size)); return self; @@ -2528,7 +2562,7 @@ const at::Tensor& XLANativeFunctions::resize_( at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::roll( bridge::GetXlaTensor(self), XlaHelpers::I64List(shifts), XlaHelpers::I64List(dims))); @@ -2538,7 +2572,7 @@ at::Tensor XLANativeFunctions::rrelu_with_noise( const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { // The fallback path for rrelu_with_noise when training=true is wrong XLA_CHECK_EQ(training, false); @@ -2556,7 +2590,7 @@ at::Tensor XLANativeFunctions::rrelu_with_noise_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, bool self_is_result) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); double negative_slope = (lower.to() + upper.to()) / 2; XLA_CHECK(!self_is_result || negative_slope > 0.0); XLATensorPtr noise_tensor = bridge::GetXlaTensor(noise); @@ -2568,7 +2602,7 @@ at::Tensor XLANativeFunctions::rrelu_with_noise_backward( at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckSubOperandTypes(self.scalar_type(), other.scalar_type()); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother, @@ -2580,7 +2614,7 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckSubOperandTypes(self.scalar_type(), GetScalarType(other)); return bridge::AtenFromXlaTensor( tensor_methods::rsub(bridge::GetXlaTensor(self), other, alpha)); @@ -2610,7 +2644,7 @@ at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Scalar& value, c10::optional reduce) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); if (!reduce.has_value()) { return bridge::AtenFromXlaTensor(tensor_methods::scatter( @@ -2630,14 +2664,14 @@ at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return scatter_reduce_helper(self, dim, index, src, c10::nullopt); } at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return scatter_reduce_helper(self, dim, index, value, c10::nullopt); } @@ -2645,7 +2679,7 @@ at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src, c10::string_view reduce) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return scatter_reduce_helper(self, dim, index, src, reduce); } @@ -2653,14 +2687,14 @@ at::Tensor XLANativeFunctions::scatter(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Scalar& value, c10::string_view reduce) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return scatter_reduce_helper(self, dim, index, value, reduce); } at::Tensor XLANativeFunctions::scatter_add(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return scatter_reduce_helper(self, dim, index, src, "add"); } @@ -2669,7 +2703,7 @@ at::Tensor XLANativeFunctions::scatter_add(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::scatter_reduce( const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src, c10::string_view reduce, bool include_self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if ((reduce == "sum" || reduce == "prod" || reduce == "amin" || reduce == "amax") && include_self) { @@ -2687,7 +2721,7 @@ at::Tensor XLANativeFunctions::scatter_reduce( at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim, int64_t index) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::select(bridge::GetXlaTensor(self), dim, index)); } @@ -2695,7 +2729,7 @@ at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, int64_t index) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto base_tensor = bridge::GetXlaTensor(base); auto base_tensor_shape = base_tensor->shape(); auto mutated_view_tensor = bridge::GetXlaTensor(mutated_view); @@ -2723,7 +2757,7 @@ at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base, // TODO(JackCaoG): Remove after elu being codegened at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::selu_(self_tensor); return self; @@ -2731,21 +2765,21 @@ at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { at::Tensor& XLANativeFunctions::set_(at::Tensor& self, const at::Tensor& source) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr source_tensor = bridge::GetXlaTensor(source); bridge::ReplaceXlaTensor(self, source_tensor); return self; } at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::sigmoid(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::sigmoid_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output))); } @@ -2754,7 +2788,7 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, c10::optional start, c10::optional end, int64_t step) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; return bridge::AtenFromXlaTensor(tensor_methods::slice( @@ -2764,7 +2798,7 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::slice_scatter( const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, c10::optional start, c10::optional end, int64_t step) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto base_ = bridge::GetXlaTensor(base); auto mutated_view_ = bridge::GetXlaTensor(mutated_view); int64_t start_val = start.has_value() ? start.value() : 0; @@ -2793,7 +2827,7 @@ at::Tensor XLANativeFunctions::slice_scatter( at::Tensor XLANativeFunctions::smooth_l1_loss(const at::Tensor& self, const at::Tensor& target, int64_t reduction, double beta) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::smooth_l1_loss( bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction, beta)); @@ -2802,7 +2836,7 @@ at::Tensor XLANativeFunctions::smooth_l1_loss(const at::Tensor& self, at::Tensor XLANativeFunctions::smooth_l1_loss_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, int64_t reduction, double beta) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::smooth_l1_loss_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction, beta)); @@ -2811,7 +2845,7 @@ at::Tensor XLANativeFunctions::smooth_l1_loss_backward( at::Tensor XLANativeFunctions::softplus(const at::Tensor& self, const at::Scalar& beta, const at::Scalar& threshold) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::softplus(bridge::GetXlaTensor(self), beta, threshold)); } @@ -2820,7 +2854,7 @@ at::Tensor XLANativeFunctions::softplus_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& beta, const at::Scalar& threshold) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::softplus_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), beta, threshold)); @@ -2828,7 +2862,7 @@ at::Tensor XLANativeFunctions::softplus_backward(const at::Tensor& grad_output, std::tuple XLANativeFunctions::sort( const at::Tensor& self, int64_t dim, bool descending) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::topk(bridge::GetXlaTensor(self), self.size(dim), dim, descending, /*sorted=*/true, /*stable=*/false); @@ -2839,7 +2873,7 @@ std::tuple XLANativeFunctions::sort( std::tuple XLANativeFunctions::sort( const at::Tensor& self, c10::optional stable, int64_t dim, bool descending) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::topk( bridge::GetXlaTensor(self), self.size(dim), dim, descending, /*sorted=*/false, @@ -2851,7 +2885,7 @@ std::tuple XLANativeFunctions::sort( std::vector XLANativeFunctions::split_copy(const at::Tensor& self, int64_t split_size, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto xla_tensors = tensor_methods::split(bridge::GetXlaTensor(self), split_size, dim); return bridge::AtenFromXlaTensors(xla_tensors); @@ -2859,46 +2893,46 @@ std::vector XLANativeFunctions::split_copy(const at::Tensor& self, std::vector XLANativeFunctions::split_with_sizes_copy( const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto xla_tensors = tensor_methods::split_with_sizes( bridge::GetXlaTensor(self), XlaHelpers::I64List(split_sizes), dim); return bridge::AtenFromXlaTensors(xla_tensors); } at::Tensor XLANativeFunctions::sqrt(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::sqrt(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::squeeze(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::squeeze(bridge::GetXlaTensor(self), dim)); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, at::IntArrayRef dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::squeeze( bridge::GetXlaTensor(self), torch::lazy::ToVector(dim))); } at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::stack(bridge::GetXlaTensors(tensors), dim)); } at::Tensor XLANativeFunctions::std(const at::Tensor& self, bool unbiased) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::std( self_tensor, @@ -2909,7 +2943,7 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, bool unbiased) { at::Tensor XLANativeFunctions::std(const at::Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::std( self_tensor, @@ -2922,7 +2956,7 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, at::OptionalIntArrayRef dim, const c10::optional& correction, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::std( self_tensor, @@ -2934,7 +2968,7 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, std::tuple XLANativeFunctions::std_mean( const at::Tensor& self, at::OptionalIntArrayRef dim, const c10::optional& correction, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto results = tensor_methods::std_mean( self_tensor, @@ -2948,7 +2982,7 @@ std::tuple XLANativeFunctions::std_mean( at::Tensor XLANativeFunctions::sub(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Currently, we disallow the case when both operands contain dynamic // dimensions. This is consistent with PyTorch's behavior. XLA_CHECK(!(tensor_has_dym_dim(self) && tensor_has_dym_dim(other))) @@ -2968,7 +3002,7 @@ at::Tensor XLANativeFunctions::sub(const at::Tensor& self, at::Tensor XLANativeFunctions::sub(const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckSubOperandTypes(self.scalar_type(), GetScalarType(other)); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const at::Scalar& other, @@ -2979,7 +3013,7 @@ at::Tensor XLANativeFunctions::sub(const at::Tensor& self, at::Tensor XLANativeFunctions::sum(const at::Tensor& self, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::sum( self_tensor, @@ -2990,7 +3024,7 @@ at::Tensor XLANativeFunctions::sum(const at::Tensor& self, at::Tensor XLANativeFunctions::sum(const at::Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::sum( self_tensor, @@ -3001,7 +3035,7 @@ at::Tensor XLANativeFunctions::sum(const at::Tensor& self, std::tuple XLANativeFunctions::svd( const at::Tensor& self, bool some, bool compute_uv) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::svd(bridge::GetXlaTensor(self), some, compute_uv); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), @@ -3010,14 +3044,14 @@ std::tuple XLANativeFunctions::svd( } at::Tensor XLANativeFunctions::t_copy(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::transpose(bridge::GetXlaTensor(self), 0, 1)); } at::Tensor XLANativeFunctions::tanh_backward(const at::Tensor& grad_output, const at::Tensor& output) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::tanh_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output))); } @@ -3025,7 +3059,7 @@ at::Tensor XLANativeFunctions::tanh_backward(const at::Tensor& grad_output, at::Tensor XLANativeFunctions::threshold(const at::Tensor& self, const at::Scalar& threshold, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::threshold( bridge::GetXlaTensor(self), threshold.to(), value.to())); } @@ -3033,7 +3067,7 @@ at::Tensor XLANativeFunctions::threshold(const at::Tensor& self, at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& threshold) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::threshold_backward( bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), threshold.to())); @@ -3041,7 +3075,7 @@ at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output, std::tuple XLANativeFunctions::topk( const at::Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::topk(bridge::GetXlaTensor(self), k, dim, largest, sorted, /*stable=*/false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), @@ -3049,14 +3083,14 @@ std::tuple XLANativeFunctions::topk( } at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::trace(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, int64_t dim0, int64_t dim1) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::transpose(bridge::GetXlaTensor(self), dim0, dim1)); } @@ -3064,7 +3098,7 @@ at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, std::tuple XLANativeFunctions::triangular_solve( const at::Tensor& b, const at::Tensor& A, bool upper, bool transpose, bool unitriangular) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Currently, ATen doesn't have a left_side option. Once this // is added, this API will have to be changed. auto results = tensor_methods::triangular_solve( @@ -3076,7 +3110,7 @@ std::tuple XLANativeFunctions::triangular_solve( std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensors( tensor_methods::unbind(bridge::GetXlaTensor(self), dim)); } @@ -3084,7 +3118,7 @@ std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, at::Tensor& XLANativeFunctions::uniform_( at::Tensor& self, double from, double to, c10::optional generator) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (generator.has_value() && generator->defined()) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(uniform_)>::call(self, from, to, @@ -3097,7 +3131,7 @@ at::Tensor& XLANativeFunctions::uniform_( at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, int64_t dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::unsqueeze(bridge::GetXlaTensor(self), dim)); } @@ -3105,7 +3139,7 @@ at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, at::Tensor XLANativeFunctions::upsample_bilinear2d( const at::Tensor& self, at::IntArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); absl::Span input_dims = self_tensor->shape().get().dimensions(); @@ -3128,7 +3162,7 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d_backward( const at::Tensor& grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); // Only the XLA TPU backend for now implements the CustomCall required by // our XLA lowering. @@ -3160,7 +3194,7 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d_backward( at::Tensor XLANativeFunctions::upsample_nearest2d( const at::Tensor& self, at::IntArrayRef output_size, c10::optional scales_h, c10::optional scales_w) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); absl::Span input_dims = self_tensor->shape().get().dimensions(); @@ -3183,7 +3217,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward( const at::Tensor& grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional scales_h, c10::optional scales_w) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); // Only the XLA TPU backend for now implements the CustomCall required by // our XLA lowering. @@ -3216,7 +3250,7 @@ at::Tensor XLANativeFunctions::var(const at::Tensor& self, at::OptionalIntArrayRef dim, const c10::optional& correction, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(tensor_methods::var( self_tensor, @@ -3229,7 +3263,7 @@ at::Tensor XLANativeFunctions::var(const at::Tensor& self, std::tuple XLANativeFunctions::var_mean( const at::Tensor& self, at::OptionalIntArrayRef dim, const c10::optional& correction, bool keepdim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto results = tensor_methods::var_mean( self_tensor, @@ -3241,7 +3275,7 @@ std::tuple XLANativeFunctions::var_mean( } at::Tensor XLANativeFunctions::view_as_complex_copy(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(self.scalar_type() == at::kFloat || self.scalar_type() == at::kDouble || @@ -3256,7 +3290,7 @@ at::Tensor XLANativeFunctions::view_as_complex_copy(const at::Tensor& self) { } at::Tensor XLANativeFunctions::view_as_real_copy(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(self.is_complex()) << "view_as_real is only supported for complex " "tensors, but got a tensor of scalar type: " @@ -3269,7 +3303,7 @@ at::Tensor XLANativeFunctions::view_as_real_copy(const at::Tensor& self) { at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, at::SymIntArrayRef shape) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); c10::optional int_shape = c10::asIntArrayRefSlowOpt(shape); bool input_shape_static = int_shape.has_value(); XLATensorPtr xla_input = bridge::GetXlaTensor(self); @@ -3286,7 +3320,7 @@ at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, at::Tensor XLANativeFunctions::where(const at::Tensor& condition, const at::Tensor& self, const at::Tensor& other) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); c10::MaybeOwned b_condition, b_self, b_other; std::tie(b_condition, b_self, b_other) = xla_expand_outplace(condition, self, other, "where"); @@ -3296,7 +3330,7 @@ at::Tensor XLANativeFunctions::where(const at::Tensor& condition, } at::Tensor& XLANativeFunctions::zero_(at::Tensor& self) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); tensor_methods::zero_(self_tensor); return self; @@ -3306,7 +3340,7 @@ std::tuple XLANativeFunctions::_linalg_svd( const at::Tensor& self, bool full_matrices, bool compute_uv, c10::optional /* driver */) { // The optional driver string is only for CUDA with a cuSOLVER backend. - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // As per https://pytorch.org/docs/stable/generated/torch.svd.html, // The second boolean argument is exactly opposite between // torch::svd and torch::_linalg_svd, hence the negation of full_matrices. @@ -3371,7 +3405,7 @@ at::Tensor XLANativeFunctions::_cdist_forward( // compute_mode is ignored because the use_mm_for_euclid_dist lowering // (compute_mode is 0 or 1) is achieved through composite ops from // native pytorch. - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0"; return bridge::AtenFromXlaTensor(tensor_methods::cdist_forward( bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p)); @@ -3457,7 +3491,7 @@ XLANativeFunctions::convolution_backward( at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, c10::optional dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr xla_tensor = bridge::GetXlaTensor(self); std::vector dims; if (dim) { @@ -3470,7 +3504,7 @@ at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, at::IntArrayRef dim) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr xla_tensor = bridge::GetXlaTensor(self); std::vector canonical_dims = @@ -3642,7 +3676,7 @@ at::Tensor XLANativeFunctions::mvlgamma(const at::Tensor& self, int64_t p) { at::Tensor XLANativeFunctions::linalg_vector_norm( const at::Tensor& self, const at::Scalar& ord, at::OptionalIntArrayRef dim, bool keepdim, c10::optional dtype) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(at::isFloatingType(self.scalar_type())) << "Input must be a floating type"; XLATensorPtr self_tensor = bridge::GetXlaTensor(self); @@ -3689,7 +3723,7 @@ at::Tensor XLANativeFunctions::permute(const at::Tensor& self, at::Tensor XLANativeFunctions::as_strided( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); @@ -3707,7 +3741,7 @@ at::Tensor XLANativeFunctions::as_strided( const at::Tensor& XLANativeFunctions::as_strided_( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = bridge::GetXlaTensor(self); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); @@ -3724,7 +3758,7 @@ const at::Tensor& XLANativeFunctions::as_strided_( at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( tensor_methods::diagonal(bridge::GetXlaTensor(self), offset, dim1, dim2)); } @@ -3732,7 +3766,7 @@ at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset, at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, at::SymIntArrayRef sym_size, bool implicit) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); c10::optional size = c10::asIntArrayRefSlowOpt(sym_size); if (size.has_value()) { return bridge::AtenFromXlaTensor(tensor_methods::expand( @@ -3751,7 +3785,7 @@ at::Tensor XLANativeFunctions::view_symint(const at::Tensor& self, // So only the functionalization version of this function view_copy_symint // support dynamic shape. auto size = C10_AS_INTARRAYREF_SLOW(sym_size); - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::view( bridge::GetXlaTensor(self), XlaHelpers::I64List(size))); } diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index ebd5fef1d06..a920bdb69e9 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -3,6 +3,7 @@ #include #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/tensor_util.h" @@ -102,9 +103,10 @@ xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) { if (from == xla::PrimitiveType::PRED) { torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); - op = ConvertTo(op, from, - GetDevicePrimitiveType(xla::PrimitiveType::U8, &xla_device), - &xla_device); + op = ConvertTo( + op, from, + MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device), + &xla_device); } return op; } diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp index a4d8840a7c4..3213e4ec5d3 100644 --- a/torch_xla/csrc/convolution.cpp +++ b/torch_xla/csrc/convolution.cpp @@ -348,7 +348,8 @@ xla::XlaOp BuildConvolutionOverrideableBias( xla::XlaOp bias_broadcast = xla::Transpose(xla::Broadcast(bias, broadcast_sizes), BiasTransposePermutation(broadcast_sizes.size() + 1)); - return conv + bias_broadcast; + auto promoted = XlaHelpers::Promote(conv, bias_broadcast); + return promoted.first + promoted.second; } ConvGrads BuildConvolutionBackwardOverrideable( diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 200a750f856..c5d367d963c 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -117,7 +117,7 @@ std::shared_ptr CreateToken( at::Tensor all_reduce(const at::Tensor& self, c10::string_view reduceOp, c10::string_view /*tag*/, at::IntArrayRef /*ranks*/, int64_t /*group_size*/) { - TORCH_LAZY_FN_COUNTER("xla::"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto self_tensor = bridge::GetXlaTensor(self); // TODO(alanwaketan): Use ranks and group_size to generate groups. Currently // we just suse {} as a workaround. Scale is always 1.0 here, and we always diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index b9c14427033..bb02ca7da2e 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -10,6 +10,7 @@ #include "absl/strings/str_join.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/convert_ops.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/reduction.h" #include "torch_xla/csrc/runtime/debug_macros.h" @@ -147,30 +148,51 @@ xla::XlaOp BuildExpand(xla::XlaOp input, xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, xla::XlaOp scalar) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - int64_t input_rank = input_shape.rank(); const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask); - int64_t mask_rank = mask_shape.rank(); - if (input_rank <= mask_rank) { - input = BuildExpand(input, mask_shape.dimensions()); - } else { - mask = BuildExpand(mask, input_shape.dimensions()); + + if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) { + xla::Shape shape = XlaHelpers::GetPromotedShape(input_shape, mask_shape); + input = BuildExpand(input, shape.dimensions()); + mask = BuildExpand(mask, shape.dimensions()); } + xla::XlaOp zero = xla::Zero(mask.builder(), XlaHelpers::TypeOfXlaOp(mask)); xla::XlaOp mask_pred = xla::Ne(mask, zero); xla::XlaOp update_scalar = ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(), - input_shape.element_type(), nullptr); + ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr); return xla::Select(mask_pred, update_scalar, input); } std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim) { + std::vector squeeze_dims({squeeze_dim}); + return BuildSqueezedDimensions(dimensions, squeeze_dims); +} + +std::vector BuildSqueezedDimensions( + absl::Span dimensions, std::vector& squeeze_dims) { + std::sort(squeeze_dims.begin(), squeeze_dims.end()); std::vector output_dimensions; - for (int64_t i = 0; i < dimensions.size(); ++i) { - int64_t dim = dimensions[i]; - if (dim != 1 || (i != squeeze_dim && squeeze_dim >= 0)) { + size_t i = 0; + for (size_t j = 0; j < dimensions.size(); j++) { + auto dim = dimensions[j]; + if (squeeze_dims.size() == 1 && squeeze_dims[0] == -1) { + // Special case where squeeze_dims = {-1}. + if (dim != 1) { + output_dimensions.push_back(dim); + } + continue; + } + if (i == squeeze_dims.size() || j < squeeze_dims[i]) { + output_dimensions.push_back(dim); + continue; + } + // Checks to see if we need to squeeze the dim or not. + if (dim != 1) { output_dimensions.push_back(dim); } + i++; } return output_dimensions; } @@ -347,7 +369,7 @@ xla::XlaOp BuildUnselect(xla::XlaOp target, xla::XlaOp source, int64_t dim, } xla::PrimitiveType pred_type = - GetDevicePrimitiveType(xla::PrimitiveType::PRED, /*device=*/nullptr); + GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::PRED); xla::XlaOp source_true = XlaHelpers::ScalarBroadcast( 1, pred_type, source_shape.dimensions(), source.builder()); xla::XlaOp pred_zero = xla::Zero(target.builder(), pred_type); diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index e22821a7bb0..067a77abc23 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -49,6 +49,9 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim); +std::vector BuildSqueezedDimensions( + absl::Span dimensions, std::vector& squeeze_dim); + std::vector BuildUnsqueezeDimensions( absl::Span dimensions, int64_t dim); diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 6b6c301b2a1..9959d46f8a2 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -1,9 +1,11 @@ #include "torch_xla/csrc/debug_util.h" #include +#include #include #include +#include #include #include #include @@ -16,7 +18,6 @@ #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { @@ -60,7 +61,7 @@ std::string DebugUtil::GetTensorsGraphHlo( absl::Span tensors, const std::vector* indices, bool dump_stablehlo) { std::vector root_values; - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; @@ -90,7 +91,7 @@ std::string DebugUtil::GetTensorsGraphInfo( std::vector root_nodes; std::vector root_values; std::vector root_hashes; - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { const XLATensorPtr& tensor = tensors[index]; @@ -209,4 +210,93 @@ bool DebugUtil::ExperimentEnabled(const std::string& name) { return xset->find(name) != xset->end(); } +// helper function until we move to C++ 20 +static bool endsWith(const std::string& str, const std::string& suffix) { + return str.size() >= suffix.size() && + 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +void DebugUtil::analyze_graph_execution_python_frame( + bool from_dynamo_executation) { + static bool is_master_process = + (runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0); + static std::string debug_file_name = + runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", ""); + static std::string debug_output_prefix = "Execution Analysis: "; + // TODO: Make this configurable. + if (!is_master_process) { + return; + } + std::vector frames = + torch::lazy::GetPythonFrames(); + // python frame must be > 1 + XLA_CHECK_GE(frames.size(), 1); + std::stringstream ss; + ss << "\n" + << debug_output_prefix + << "======================================================================" + "==========" + << "\n"; + ss << debug_output_prefix << "Execution Cause\n"; + if (from_dynamo_executation) { + // when executation is from dynamo compiled graph, the python stack will not + // show any dynamo related python file since frame is already replaced. We + // can either analyze the C++ call stack or rely on caller to pass a boolean + // variable. + ss << debug_output_prefix << " dynamo is executing a compiled program\n"; + } else if (frames[0].function == "mark_step") { + if (frames[1].function == "next" && + endsWith(frames[1].file, "parallel_loader.py")) { + ss << debug_output_prefix + << " mark_step in parallel loader at step end\n"; + } else if (frames[1].function == "__exit__" && + endsWith(frames[1].file, "profiler.py")) { + ss << debug_output_prefix + << " mark_step when exiting a profiler StepTrace region\n"; + } else if ((frames[1].function == "extract_compiled_graph" || + frames[1].function == "extract_internal") && + endsWith(frames[1].file, "dynamo_bridge.py")) { + ss << debug_output_prefix + << " mark_step when dynamo processing input graphs\n"; + } else { + ss << debug_output_prefix << " user mark_step\n"; + } + } else if (frames[0].function == "extract_graph_helper" && + endsWith(frames[0].file, "dynamo_bridge.py")) { + ss << debug_output_prefix << " dynamo is compiling a FX graph to HLO\n"; + } else { + // TODO(JackCaoG): be more specific about exeuction caused by printing + // tensor or fallback or some weird indexing. + ss << debug_output_prefix + << " most likely user code trying to access tensor value before " + "mark_step\n"; + } + + // TODO(JackCaoG): make number of frames printed configurable + ss << debug_output_prefix << "Python Frame Triggered Execution: \n"; + for (auto& location : frames) { + ss << debug_output_prefix << " " << location.function << " (" + << location.file << ":" << location.line << ")\n"; + } + ss << debug_output_prefix + << "----------------------------------------------------------------------" + "----------" + << "\n"; + ss << debug_output_prefix + << "======================================================================" + "==========" + << "\n"; + + // TODO(JackCaoG): print more information about the graph that is about to get + // executed. + if (debug_file_name == "") { + // print to stderr by default + std::cerr << ss.str(); + } else { + std::ofstream outFile; + outFile.open(debug_file_name, std::ios_base::app); + outFile << ss.rdbuf(); + } +} + } // namespace torch_xla diff --git a/torch_xla/csrc/debug_util.h b/torch_xla/csrc/debug_util.h index 2a687207b28..530a45fc83a 100644 --- a/torch_xla/csrc/debug_util.h +++ b/torch_xla/csrc/debug_util.h @@ -46,6 +46,11 @@ class DebugUtil { absl::Span indices); static bool ExperimentEnabled(const std::string& name); + + // warning, this function should only be called when a graph execution is + // about to happen. + static void analyze_graph_execution_python_frame( + bool from_dynamo_executation = false); }; } // namespace torch_xla diff --git a/torch_xla/csrc/device.cpp b/torch_xla/csrc/device.cpp index 791b0271aca..08489c30b8b 100644 --- a/torch_xla/csrc/device.cpp +++ b/torch_xla/csrc/device.cpp @@ -17,6 +17,10 @@ std::string XlaDeviceTypeToString(XlaDeviceType hw_type) { return "CPU"; case XlaDeviceType::GPU: return "GPU"; + case XlaDeviceType::CUDA: + return "CUDA"; + case XlaDeviceType::ROCM: + return "ROCM"; case XlaDeviceType::TPU: return "TPU"; case XlaDeviceType::XPU: @@ -59,6 +63,12 @@ torch::lazy::BackendDevice ParseDeviceString(const std::string& device_spec) { } else if (device_spec_parts[0] == "CPU") { device_type->type = static_cast>(XlaDeviceType::CPU); + } else if (device_spec_parts[0] == "ROCM") { + device_type->type = + static_cast>(XlaDeviceType::ROCM); + } else if (device_spec_parts[0] == "CUDA") { + device_type->type = + static_cast>(XlaDeviceType::CUDA); } else if (device_spec_parts[0] == "GPU") { device_type->type = static_cast>(XlaDeviceType::GPU); diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index c17cbe85540..1dc939bb17a 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -15,7 +15,7 @@ namespace torch_xla { // TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToServer` // until after the paritioning pass. This avoids transfering the full input // tensor to the device. -enum class XlaDeviceType { CPU, GPU, TPU, XPU, NEURON, SPMD }; +enum class XlaDeviceType { CPU, CUDA, ROCM, GPU, TPU, XPU, NEURON, SPMD }; struct DeviceType : public torch::lazy::BackendDeviceType { DeviceType() { type = static_cast(XlaDeviceType::CPU); } diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp new file mode 100644 index 00000000000..918c0ba6515 --- /dev/null +++ b/torch_xla/csrc/dtype.cpp @@ -0,0 +1,219 @@ +#include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/shape.h" + +namespace torch_xla { + +namespace { + +bool ShouldUseBF16() { + bool use_bf16 = runtime::sys_util::GetEnvBool("XLA_USE_BF16", false); + if (use_bf16) { + TF_LOG(INFO) << "Using BF16 data type for floating point values"; + } + return use_bf16; +} + +bool ShouldUseF16() { + bool use_fp16 = runtime::sys_util::GetEnvBool("XLA_USE_FP16", false); + if (use_fp16) { + TF_LOG(INFO) << "Using F16 data type for floating point values"; + } + return use_fp16; +} + +bool ShouldDowncastToBF16() { + bool downcast_bf16 = + runtime::sys_util::GetEnvBool("XLA_DOWNCAST_BF16", false); + if (downcast_bf16) { + TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->BF16"; + } + return downcast_bf16; +} + +bool ShouldDowncastToF16() { + bool downcast_fp16 = + runtime::sys_util::GetEnvBool("XLA_DOWNCAST_FP16", false); + if (downcast_fp16) { + TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->FP16"; + } + return downcast_fp16; +} + +bool ShouldUse32BitLong() { + bool use_32bit_long = + runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false); + if (use_32bit_long) { + TF_LOG(INFO) << "Using 32bit integers for kLong values"; + } + return use_32bit_long; +} + +bool UseBF16() { + static bool use_bf16 = ShouldUseBF16(); + return use_bf16; +} + +bool UseF16() { + static bool use_fp16 = ShouldUseF16(); + return use_fp16; +} + +bool DowncastBF16() { + static bool downcast_bf16 = ShouldDowncastToBF16(); + return downcast_bf16; +} + +bool DowncastF16() { + static bool downcast_fp16 = ShouldDowncastToF16(); + return downcast_fp16; +} + +bool Use32BitLong() { + static bool use_32bit_long = ShouldUse32BitLong(); + return use_32bit_long; +} + +bool IsTpuDevice(XlaDeviceType hw_type) { + static bool spmd_device_is_tpu = + (hw_type == XlaDeviceType::SPMD) && + // HACK: find a better way to decide if SPMD is actually a TPU without + // accessing the runtime. + runtime::sys_util::GetEnvString("PJRT_DEVICE", "") == "TPU"; + return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; +} + +} // namespace + +at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) { + switch (xla_type) { + case xla::PrimitiveType::BF16: + return at::ScalarType::BFloat16; + case xla::PrimitiveType::F16: + return at::ScalarType::Half; + case xla::PrimitiveType::F32: + return at::ScalarType::Float; + case xla::PrimitiveType::F64: + return at::ScalarType::Double; + case xla::PrimitiveType::PRED: + return at::ScalarType::Bool; + case xla::PrimitiveType::U8: + return at::ScalarType::Byte; + case xla::PrimitiveType::S8: + return at::ScalarType::Char; + case xla::PrimitiveType::S16: + case xla::PrimitiveType::U16: + return at::ScalarType::Short; + case xla::PrimitiveType::S32: + case xla::PrimitiveType::U32: + return at::ScalarType::Int; + case xla::PrimitiveType::S64: + case xla::PrimitiveType::U64: + return at::ScalarType::Long; + case xla::PrimitiveType::C64: + return at::ScalarType::ComplexFloat; + case xla::PrimitiveType::C128: + return at::ScalarType::ComplexDouble; + default: + XLA_ERROR() << "XLA type not supported: " << xla_type; + } +} + +xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) { + switch (scalar_type) { + case at::ScalarType::Double: + return xla::PrimitiveType::F64; + case at::ScalarType::Float: + return xla::PrimitiveType::F32; + case at::ScalarType::BFloat16: + return xla::PrimitiveType::BF16; + case at::ScalarType::Half: + return xla::PrimitiveType::F16; + case at::ScalarType::Bool: + return xla::PrimitiveType::PRED; + case at::ScalarType::Byte: + return xla::PrimitiveType::U8; + case at::ScalarType::Char: + return xla::PrimitiveType::S8; + case at::ScalarType::Short: + return xla::PrimitiveType::S16; + case at::ScalarType::Int: + return xla::PrimitiveType::S32; + case at::ScalarType::Long: + return xla::PrimitiveType::S64; + case at::ScalarType::ComplexFloat: + return xla::PrimitiveType::C64; + case at::ScalarType::ComplexDouble: + return xla::PrimitiveType::C128; + default: + XLA_ERROR() << "Type not supported: " << scalar_type; + } +} + +xla::PrimitiveType MaybeDowncastToXlaDeviceType( + xla::PrimitiveType type, const torch::lazy::BackendDevice& device) { + XlaDeviceType hw_type = static_cast(device.type()); + switch (type) { + case xla::PrimitiveType::F64: + if (UseF16()) { + return xla::PrimitiveType::F16; + } + if (UseBF16()) { + return xla::PrimitiveType::BF16; + } + if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) || + hw_type == XlaDeviceType::NEURON) { + return xla::PrimitiveType::F32; + } + return xla::PrimitiveType::F64; + case xla::PrimitiveType::F32: + if (UseF16() || DowncastF16()) { + return xla::PrimitiveType::F16; + } + return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 + : xla::PrimitiveType::F32; + case xla::PrimitiveType::U16: + return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON + ? xla::PrimitiveType::U16 + : xla::PrimitiveType::U32; + case xla::PrimitiveType::S16: + return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON + ? xla::PrimitiveType::S16 + : xla::PrimitiveType::S32; + case xla::PrimitiveType::S64: + return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64; + case xla::PrimitiveType::U64: + return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64; + case xla::PrimitiveType::C128: + return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128 + : xla::PrimitiveType::C64; + default: + return type; + } +} + +xla::PrimitiveType MaybeDowncastToXlaDeviceType( + at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) { + xla::PrimitiveType xla_type = XlaTypeFromTorchType(scalar_type); + return MaybeDowncastToXlaDeviceType(xla_type, device); +} + +at::ScalarType MaybeUpcastToHostTorchType(xla::PrimitiveType xla_type) { + at::ScalarType scalar_type = TorchTypeFromXlaType(xla_type); + switch (scalar_type) { + case at::ScalarType::BFloat16: + return UseBF16() || DowncastBF16() ? at::ScalarType::Float + : at::ScalarType::BFloat16; + case at::ScalarType::Half: + return UseF16() || DowncastF16() ? at::ScalarType::Float + : at::ScalarType::Half; + case at::ScalarType::Float: + return DowncastBF16() || DowncastF16() ? at::ScalarType::Double + : at::ScalarType::Float; + default: + return scalar_type; + } +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/dtype.h b/torch_xla/csrc/dtype.h new file mode 100644 index 00000000000..399c67fcbde --- /dev/null +++ b/torch_xla/csrc/dtype.h @@ -0,0 +1,25 @@ +#ifndef XLA_TORCH_XLA_CSRC_DTYPE_H_ +#define XLA_TORCH_XLA_CSRC_DTYPE_H_ + +#include "torch_xla/csrc/device.h" +#include "xla/shape.h" + +namespace torch_xla { + +at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type); + +xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type); + +// Downcast type to be compatible with device if necessary. +xla::PrimitiveType MaybeDowncastToXlaDeviceType( + xla::PrimitiveType type, const torch::lazy::BackendDevice& device); + +xla::PrimitiveType MaybeDowncastToXlaDeviceType( + at::ScalarType scalar_type, const torch::lazy::BackendDevice& device); + +// Upcast type to original PyTorch type. +at::ScalarType MaybeUpcastToHostTorchType(xla::PrimitiveType xla_type); + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_DTYPE_H_ diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 132bba18f95..62a906d98ea 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -5,6 +5,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/random.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" @@ -66,8 +67,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - return xla::Max(input, XlaHelpers::ScalarValue( - 0, input_shape.element_type(), input.builder())); + xla::XlaOp scalar = XlaHelpers::ScalarValue( + 0, input_shape.element_type(), input.builder()); + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + // xla::Max doesn't do implicit broadcasting for unbounded dynamism now. + // TODO(lsy323): Remove this branch once the support is added in XLA. + auto promoted = XlaHelpers::Promote(input, scalar); + return xla::Max(promoted.first, promoted.second); + } else { + return xla::Max(input, scalar); + } } xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) { @@ -235,6 +244,19 @@ xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight) { return xla::Select(xla::Gt(input, zero), input, product); } +std::vector BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input, + xla::XlaOp weight) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + const xla::Shape& weight_shape = ShapeHelper::ShapeOfXlaOp(weight); + + xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type()); + xla::XlaOp grad_input = xla::Mul(weight, grad); + xla::XlaOp grad_weight = xla::Mul(input, grad); + + return {xla::Select(xla::Gt(input, zero), grad, grad_input), + xla::Select(xla::Gt(input, zero), zero, grad_weight)}; +} + xla::XlaOp BuildSigmoid(xla::XlaOp input) { return xla::Logistic(input); } xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) { diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h index 0753925e82f..1b327e1bfc1 100644 --- a/torch_xla/csrc/elementwise.h +++ b/torch_xla/csrc/elementwise.h @@ -21,6 +21,9 @@ xla::XlaOp BuildRelu(xla::XlaOp input); xla::XlaOp BuildPrelu(xla::XlaOp input, xla::XlaOp weight); +std::vector BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input, + xla::XlaOp weight); + std::vector BuildRrelu(xla::XlaOp input, const at::Scalar& lower, const at::Scalar& upper, bool training, xla::XlaOp rng_seed); diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 3f25ca1bb1d..995e43078b6 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -7,8 +7,8 @@ #include "absl/strings/str_join.h" #include "torch_xla/csrc/convert_ops.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/shape_helper.h" @@ -20,6 +20,9 @@ namespace torch_xla { namespace { +// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. +static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, xla::XlaOp result) { xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1); @@ -62,6 +65,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input, std::vector bcast_sizes = SizesOfXlaOp(input); for (size_t i = 0; i < dimensions.size(); ++i) { bcast_sizes.at(dimensions[i]) = sizes[i]; + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + XLA_CHECK(sizes[i] != kUnboundedSize); + } } return xla::BroadcastInDim(input, bcast_sizes, GetAllDimensions(bcast_sizes.size())); @@ -321,6 +327,59 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, : xla::Reshape(input, shape.dimensions()); } +bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; + const absl::Span dims = shape.dimensions(); + return std::any_of(dims.begin(), dims.end(), + [](int64_t size) { return size == kUnboundedSize; }); +} + +xla::XlaOp XlaHelpers::DynamicUnboundedReshape( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes) { + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; + const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); + XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) + << "XlaHelpers::DynamicUnboundedReshape constrainled failed!"; + std::vector get_dim_ops; + std::vector reshaped_ops; + bool all_static = true; + std::vector output_dynamic(output_sizes.size(), false); + + for (int i = 0; i < output_sizes.size(); i++) { + if (output_sizes[i] == kUnboundedSize) { + output_dynamic[i] = true; + get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i)); + all_static = false; + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + output_sizes[i], aux_input.builder())); + } + } + + if (all_static) { + return xla::Reshape(input, output_sizes); + } + + // Create the reshape from scalar to 1-D vector + for (auto get_dim_op : get_dim_ops) { + reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); + } + + // Create Concatenate op + auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0}); + return xla::CustomCall( + aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op}, + xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes, + output_dynamic)); + + return input; +} + bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1, const xla::Shape& shape2) { return shape1.is_static() && shape2.is_static() && @@ -484,6 +543,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape1.dimensions()), runtime::util::ToVector(shape2.dimensions()))); } + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && + !XlaHelpers::IsUnboundedDynamic(shape2)) + << "Unreachable for unbounded dynamic code\n"; + } return GetPromotedDynamicShape(shape1, shape2); } @@ -681,7 +745,7 @@ xla::StatusOr XlaHelpers::WrapXlaComputation( } torch::lazy::Shape XlaHelpers::ConvertXlaShapeToLazy(const xla::Shape& shape) { - at::ScalarType scalar_type = TensorTypeFromXlaType(shape.element_type()); + at::ScalarType scalar_type = MaybeUpcastToHostTorchType(shape.element_type()); c10::optional> is_symbolic = c10::nullopt; if (shape.is_dynamic()) { std::vector xla_dynamic_dimensions = diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 817566159ed..66c01588b57 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -13,6 +13,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/util.h" #include "tsl/platform/bfloat16.h" #include "xla/client/xla_builder.h" @@ -158,6 +159,17 @@ class XlaHelpers { static xla::XlaOp DynamicReshape(xla::XlaOp input, absl::Span output_sizes); + static bool IsUnboundedDynamic(const xla::Shape& shape); + + static bool IsUnboundedDynamismEnabled() { + return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", + false); + } + + static xla::XlaOp DynamicUnboundedReshape( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes); + static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape); static bool SameStaticDimensions(const xla::Shape& shape1, diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b3957d7a68f..4758579bbb6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -20,6 +20,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/variant.h" #include "pybind11/attr.h" #include "pybind11/cast.h" @@ -29,23 +30,25 @@ #include "pybind11/pytypes.h" #include "pybind11/stl_bind.h" #include "torch_xla/csrc/XLANativeFunctions.h" +#include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" +#include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/metrics_analysis.h" #include "torch_xla/csrc/runtime/metrics_reader.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_impl.h" @@ -94,7 +97,6 @@ void PrepareToExit() { runtime::GetComputationClientIfInitialized(); if (client != nullptr) { XLAGraphExecutor::Get()->WaitDeviceOps({}); - client->PrepareToExit(); } } @@ -800,7 +802,7 @@ class PyLoweringContext { xla::Literal& literal = literals[i]; xla::XlaOp op = lowering_ctx.GetParameter(device_data[i]); at::ScalarType dtype = - TensorTypeFromXlaType(literal.shape().element_type()); + MaybeUpcastToHostTorchType(literal.shape().element_type()); at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype); results[param_ids[i]] = input; } @@ -1559,72 +1561,39 @@ void InitXlaModuleBindings(py::module m) { tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)); })); - m.def("_xla_mark_sharding", [](const at::Tensor& input, - xla::OpSharding sharding) { - TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } - } + m.def("_xla_mark_sharding", + [](const at::Tensor& input, xla::OpSharding sharding) { + ShardingUtil::xla_mark_sharding(input, sharding); + }); + m.def("_xla_mark_sharding_dynamo_custom_op", + [](const at::Tensor& input, const py::list& tile_assignment, + const py::list& group_assignment, const py::list& replication_groups, + int sharding_type) { + c10::List tile_assignment_list = + c10::List(); + for (auto t : tile_assignment) { + tile_assignment_list.push_back( + at::IntArrayRef(t.cast>())); + } - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } + c10::List group_assignment_list = + c10::List(); + for (auto t : group_assignment) { + group_assignment_list.push_back( + at::IntArrayRef(t.cast>())); + } - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; - } - auto xla_data = CreateTensorsData( - std::vector{cpu_tensor}, - std::vector{new_sharding_spec}, - std::vector{GetVirtualDevice().toString()})[0]; - xtensor->SetXlaData(xla_data); - xtensor->SetShardingSpec(*new_sharding_spec); + c10::List replication_groups_list = + c10::List(); + for (auto t : replication_groups) { + replication_groups_list.push_back( + at::IntArrayRef(t.cast>())); + } - // Register sharded tensor data. - XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); - }); + xla_mark_sharding_dynamo_custom_op( + input, tile_assignment_list, group_assignment_list, + replication_groups_list, sharding_type); + }); m.def("_xla_clear_sharding", [](const at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); xtensor->ClearShardingSpec(); @@ -1663,6 +1632,72 @@ void InitXlaModuleBindings(py::module m) { } return std::nullopt; }); + // Reassemble the CPU shards into a global tensor. A new sharded tensor is + // created from the local shards with the provided sharding annotation + // attached. The order of the shards should coincide with the order of + // devices returned by `torch_xla.runtime.local_runtime_devices()`. + m.def( + "_global_tensor_from_cpu_shards", + [](const std::vector& shards, const xla::OpSharding& sharding, + std::optional>& global_shape) -> at::Tensor { + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + auto local_devices = runtime::GetComputationClient()->GetLocalDevices(); + XLA_CHECK(local_devices.size() == shards.size()) + << "Must specify a shard for each local device"; + XLA_CHECK(!global_shape.has_value() || + global_shape.value().size() == shards[0].sizes().size()) + << "Global shape rank must agree with shard rank: expected rank " + << shards[0].sizes().size() << ", got " + << global_shape.value().size(); + + if (!global_shape.has_value()) { + // Set a default value for the global shape based on the sharding + // type. + if (sharding.type() == xla::OpSharding::OTHER) { + // Infer the global shape to be the shard shape scaled by the tiling + // dimensionality. + auto tile_shape = sharding.tile_assignment_dimensions(); + global_shape = std::vector(); + for (int dim = 0; dim < shards[0].sizes().size(); ++dim) { + auto global_dim = tile_shape[dim] * shards[0].sizes()[dim]; + global_shape->push_back(global_dim); + } + } else if (sharding.type() == xla::OpSharding::REPLICATED) { + global_shape = shards[0].sizes().vec(); + } else { + XLA_ERROR() << "Unsupported OpSharding type: " << sharding.type(); + } + } + + auto device = GetVirtualDevice(); + auto primitive_type = + MakeXlaPrimitiveType(shards[0].type().scalarType(), &device); + xla::Shape tensor_shape = MakeArrayShapeFromDimensions( + global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type, + static_cast(device.type())); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + + // Verify that the shard shape is correct for the global shape and + // sharding spec. + auto expected_shard_shape = ShardingUtil::GetShardShape(sharding_spec); + for (auto shard : shards) { + XLA_CHECK(shard.sizes() == expected_shard_shape) + << "Input shard shape must include padding: " << shard.sizes() + << " vs " << expected_shard_shape; + } + + auto data_handle = ShardingUtil::CreateShardedData( + shards, local_devices, sharding_spec); + XLATensorPtr xla_tensor = XLATensor::Create(std::move(data_handle)); + xla_tensor->SetShardingSpec(*sharding_spec); + auto tensor = bridge::AtenFromXlaTensor(std::move(xla_tensor)); + return torch::autograd::make_variable(tensor, + shards[0].requires_grad()); + }, + py::arg("shards"), py::arg("sharding"), + py::arg("global_shape") = py::none()); // Returns the local shards of the tensor, with values taken from the // underlying ComputationClient::GetDataShards. As such, the shards will // contain any padding that was applied to ensure they all have the same @@ -1691,9 +1726,9 @@ void InitXlaModuleBindings(py::module m) { for (const runtime::ComputationClient::DataPtr shard_handle : shard_handles) { shards.push_back( - XlaDataToTensors( - {shard_handle}, - TensorTypeFromXlaType(shard_handle->shape().element_type())) + XlaDataToTensors({shard_handle}, + MaybeUpcastToHostTorchType( + shard_handle->shape().element_type())) .front()); str_devices.push_back(shard_handle->device()); } @@ -1810,6 +1845,45 @@ void InitXlaModuleBindings(py::module m) { xla::HloModule::CreateFromProto(module_proto, config).value()); return module->ToString(); }); + // Initialize the XlaCoordinator in the runtime if not already initialized. + m.def("_ensure_xla_coordinator_initialized", + [](int global_rank, int world_size, std::string master_addr, + std::string master_port) { + auto comp_client = runtime::GetComputationClient(); + if (!comp_client->CoordinatorInitialized()) { + runtime::GetComputationClient()->InitializeCoordinator( + global_rank, world_size, master_addr, master_port); + } + }, + py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"), + py::arg("master_port") = + runtime::XlaCoordinator::kDefaultCoordinatorPort); + // Create a PreemptionSyncManager for the XlaCoordinator. The + // PreemptionSyncManager will register a SIGTERM handler as a side effect. + m.def("_activate_preemption_sync_manager", []() { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + coordinator.ActivatePreemptionSyncManager(); + }); + // Deactivate the PreemptionSyncManager in the XlaCoordinator if one is active + m.def("_deactivate_preemption_sync_manager", []() { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + coordinator.DeactivatePreemptionSyncManager(); + }); + // Check whether a sync point has been reached. This method requires that the + // distributed runtime be initialized and a PreemptionSyncManager activated. + m.def("_sync_point_reached", [](int step) { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + return coordinator.ReachedSyncPoint(step); + }); m.def("_is_placecholder", [](at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); return xtensor->CurrentDataHandle() && @@ -1848,30 +1922,6 @@ void InitXlaModuleBindings(py::module m) { SetAllReduceToken(device, token); }); - /* The distributed runtime service is used by the PjRt GPU client. */ - py::class_> - distributed_runtime_service(m, "DistributedRuntimeService"); - distributed_runtime_service.def("shutdown", - &xla::DistributedRuntimeService::Shutdown, - py::call_guard()); - m.def("_xla_get_distributed_runtime_service", - [](int num_nodes) -> std::unique_ptr { - std::string dist_service_addr = - runtime::sys_util::GetEnvString("PJRT_DIST_SERVICE_ADDR", ""); - XLA_CHECK(!dist_service_addr.empty()) - << "Must set PJRT_DIST_SERVICE_ADDR environment variable to use " - "distributed runtime"; - XLA_CHECK(num_nodes > 0) - << "num_nodes must be positive: " << num_nodes; - - xla::CoordinationServiceImpl::Options options; - options.num_nodes = num_nodes; - return std::move( - xla::GetDistributedRuntimeService(dist_service_addr, options) - .value()); - }); - BuildProfilerSubmodule(&m); BuildLoweringContextSubmodule(&m); @@ -1885,6 +1935,12 @@ void InitXlaModuleBindings(py::module m) { return handles; }); + m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) { + TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->MarkDynamicDimension(dim); + }); + // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 0fc2d77a47f..82b746ab181 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -174,6 +174,8 @@ xla::Shape XlaNode::GetOpShape( std::string XlaNode::ToString() const { std::stringstream ss; ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_; + ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ") + << ')'; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index c63fe289b9d..d0619ef5c98 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -9,9 +9,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -138,6 +138,17 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; + void MarkDynamicDimension(uint32_t dim) { + unbounded_dynamic_dims_.insert(dim); + } + + const std::unordered_set& dynamic_dims() const { + return unbounded_dynamic_dims_; + } + + protected: + std::unordered_set unbounded_dynamic_dims_; + private: xla::Shape GetOpShape(const std::function& shape_fn) const; diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp index 3c54b33e911..b488acbaefc 100644 --- a/torch_xla/csrc/layout_manager.cpp +++ b/torch_xla/csrc/layout_manager.cpp @@ -41,7 +41,8 @@ class LayoutManager { struct DimensionsHasher { size_t operator()(const absl::Span& dimensions) const { - return runtime::util::HashReduce(runtime::util::MHash(dimensions)); + return torch::lazy::HashReduce(torch::lazy::MHash( + std::vector({dimensions.begin(), dimensions.end()}))); } }; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index dbb1fd69cb3..404fa82ea7b 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -93,19 +93,31 @@ LoweringContext::LoweringContext( } } +// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. +static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + xla::XlaOp LoweringContext::GetParameter( - const std::shared_ptr& data) { + const std::shared_ptr& data, + const std::unordered_set& unbounded_dynamic_dims) { torch::lazy::BackendData::Handle handle = data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - xla::XlaOp param = xla::Parameter( - builder(), parameters_.size(), + xla::Shape shape = std::dynamic_pointer_cast(data) - ->shape(), - absl::StrCat("p", parameters_.size())); + ->shape(); + for (const int dim : unbounded_dynamic_dims) { + shape.set_dynamic_dimension(dim, true); + shape.set_dimensions(dim, kUnboundedSize); + } + xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape, + absl::StrCat("p", parameters_.size())); it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) .first; parameters_.push_back(data); + } else { + XLA_CHECK(unbounded_dynamic_dims.empty()) + << "The unbounded dynamic dims can only be set when Parameter is " + "created."; } parameter_sequence_.push_back(it->second.index); return it->second.param; @@ -170,6 +182,22 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { const XlaNode* casted = dynamic_cast(node); result_ops = casted->Lower(this); + if (!casted->dynamic_dims().empty()) { + xla::internal::XlaBuilderFriend builder_friend; + auto* inst = builder_friend.GetInstruction(result_ops[0]); + auto* mutable_dynamic = + inst->mutable_shape()->mutable_is_dynamic_dimension(); + if (mutable_dynamic->empty()) { + for (int i = 0; i < inst->dimensions_size(); i++) { + mutable_dynamic->Add(false); + } + } + auto* mutable_dims = inst->mutable_shape()->mutable_dimensions(); + for (const auto dim : casted->dynamic_dims()) { + mutable_dynamic->Set(dim, true); + mutable_dims->Set(dim, kUnboundedSize); + } + } } catch (const std::exception& ex) { ReportBuilderError(node, ex.what()); } diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 76684326326..b46d91874b0 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -37,7 +37,8 @@ class LoweringContext : public torch::lazy::LoweringContext { // returned. Otherwise a new one will be created, associated with the tensor // held in data. xla::XlaOp GetParameter( - const std::shared_ptr& data); + const std::shared_ptr& data, + const std::unordered_set& dynamic_dims = {}); // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. diff --git a/torch_xla/csrc/ops/adam_optimizer_step.cpp b/torch_xla/csrc/ops/adam_optimizer_step.cpp index d5b6d03b910..dd12df326a0 100644 --- a/torch_xla/csrc/ops/adam_optimizer_step.cpp +++ b/torch_xla/csrc/ops/adam_optimizer_step.cpp @@ -29,7 +29,7 @@ AdamOptimizerStep::AdamOptimizerStep( {found_inf, step, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1, beta2, lr, weight_decay, eps}, NodeOutputShape(step, param), - /*num_outputs=*/5, + /*num_outputs=*/(use_amsgrad ? 5 : 4), torch::lazy::MHash(use_weight_decay, use_amsgrad, use_adamw)), use_weight_decay_(use_weight_decay), use_amsgrad_(use_amsgrad), diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index 80b2418776b..95068640a27 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -3,6 +3,7 @@ #include #include "torch_xla/csrc/convert_ops.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" @@ -51,8 +52,8 @@ XlaOpVector Cast::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); xla::PrimitiveType raw_from = - stype_ ? TensorTypeToRawXlaType(*stype_) : input_shape.element_type(); - xla::PrimitiveType raw_to = dtype_ ? TensorTypeToRawXlaType(*dtype_) : type_; + stype_ ? XlaTypeFromTorchType(*stype_) : input_shape.element_type(); + xla::PrimitiveType raw_to = dtype_ ? XlaTypeFromTorchType(*dtype_) : type_; xla::XlaOp output = ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to, /*device=*/nullptr); diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index 07956843a7d..e07fe3c4e76 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const { } XlaOpVector DeviceData::Lower(LoweringContext* loctx) const { - return ReturnOp(loctx->GetParameter(data_), loctx); + return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx); } DeviceData* DeviceData::Cast(const torch::lazy::Node* node) { diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 933149040be..ae3fb83d54b 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -132,6 +132,25 @@ torch::lazy::NodePtr Prelu(const torch::lazy::Value& input, GetXlaShape(input), std::move(lower_fn)); } +torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad, + const torch::lazy::Value& input, + const torch::lazy::Value& weight) { + auto lower_fn = [](const XlaNode& node, + LoweringContext* loctx) -> XlaOpVector { + xla::XlaOp xla_grad = loctx->GetOutputOp(node.operand(0)); + xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1)); + xla::XlaOp xla_weight = loctx->GetOutputOp(node.operand(2)); + return node.ReturnOps(BuildPreluBackward(xla_grad, xla_input, xla_weight), + loctx); + }; + + return GenericOp( + torch::lazy::OpKind(at::aten::_prelu_kernel_backward), + {grad, input, weight}, + xla::ShapeUtil::MakeTupleShape({GetXlaShape(grad), GetXlaShape(input)}), + std::move(lower_fn), /*num_outputs=*/2); +} + torch::lazy::NodePtr LogSigmoid(const torch::lazy::Value& input) { auto lower_fn = [](const XlaNode& node, LoweringContext* loctx) -> XlaOpVector { diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index d8015ed5f63..c110fd3a32c 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -101,6 +101,10 @@ torch::lazy::NodePtr Sqrt(const torch::lazy::Value& input); torch::lazy::NodePtr Prelu(const torch::lazy::Value& input, const torch::lazy::Value& weight); +torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad, + const torch::lazy::Value& input, + const torch::lazy::Value& weight); + torch::lazy::NodePtr Pow(const torch::lazy::Value& input, const torch::lazy::Value& exponent); diff --git a/torch_xla/csrc/random.cpp b/torch_xla/csrc/random.cpp index 44564c1d24f..9624c160ad0 100644 --- a/torch_xla/csrc/random.cpp +++ b/torch_xla/csrc/random.cpp @@ -21,6 +21,8 @@ std::string GetDefaultGitGeneratorName() { static_cast(bridge::GetCurrentDevice().type()); switch (hw_type) { case XlaDeviceType::GPU: + case XlaDeviceType::CUDA: + case XlaDeviceType::ROCM: return "three_fry"; default: return "default"; @@ -62,7 +64,8 @@ xla::XlaOp MakeUniformBoundaryValue(xla::XlaOp val, bool downcast = false) { xla::PrimitiveType element_type = XlaHelpers::TypeOfXlaOp(val); if (element_type == xla::PrimitiveType::BF16 || element_type == xla::PrimitiveType::F16) { - auto dtype = downcast ? xla::PrimitiveType::F16 : xla::PrimitiveType::F32; + // Use BF16 if `downcast` is set. + auto dtype = downcast ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32; return xla::ConvertElementType(val, dtype); } else if (xla::primitive_util::IsComplexType(element_type)) { return xla::Real(val); @@ -75,7 +78,9 @@ xla::Shape MakeRngShape(const xla::Shape& shape, bool downcast = false) { xla::Shape rng_shape(shape); if (element_type == xla::PrimitiveType::BF16 || element_type == xla::PrimitiveType::F16) { - auto dtype = downcast ? xla::PrimitiveType::F16 : xla::PrimitiveType::F32; + // This controls the bit width and we use 8-bit if `downcast` is set. + auto dtype = + downcast ? xla::PrimitiveType::F8E5M2 : xla::PrimitiveType::F32; rng_shape.set_element_type(dtype); } else if (xla::primitive_util::IsComplexType(element_type)) { rng_shape.set_element_type( diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 8a7ab98d5d5..aff46743410 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -81,7 +81,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, xla::XlaOp scale = xla::Select(xla::Ne(count, zero), one / xla::ConvertElementType(count, type), xla::NanValue(input.builder(), type)); - return input * scale; + + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now. + // TODO(lsy323): Remove this branch once the support is added in XLA. + auto promoted = XlaHelpers::Promote(input, scale); + return promoted.first * promoted.second; + } else { + return input * scale; + } } xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) { @@ -109,8 +117,15 @@ SummationResult CreateSummation(xla::XlaOp input, result.result, result.rinfo.element_count.size, shape.element_type()); } if (keep_reduced_dimensions) { - result.result = - XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions); + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is + // added. + result.result = XlaHelpers::DynamicUnboundedReshape( + result.result, input, result.rinfo.new_dimensions); + } else { + result.result = XlaHelpers::DynamicReshape(result.result, + result.rinfo.new_dimensions); + } } return result; } @@ -399,8 +414,8 @@ xla::XlaOp BuildArgMax(xla::XlaOp input, int64_t dim, bool keepdim) { shape = &ShapeHelper::ShapeOfXlaOp(operand); } xla::XlaOp result = xla::ArgMax( - operand, - GetDevicePrimitiveType(xla::PrimitiveType::S64, /*device=*/nullptr), dim); + operand, GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::S64), + dim); if (keepdim) { auto dimensions = torch::lazy::ToVector(shape->dimensions()); dimensions[dim] = 1; @@ -419,8 +434,8 @@ xla::XlaOp BuildArgMin(xla::XlaOp input, int64_t dim, bool keepdim) { shape = &ShapeHelper::ShapeOfXlaOp(operand); } xla::XlaOp result = xla::ArgMin( - operand, - GetDevicePrimitiveType(xla::PrimitiveType::S64, /*device=*/nullptr), dim); + operand, GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::S64), + dim); if (keepdim) { auto dimensions = torch::lazy::ToVector(shape->dimensions()); dimensions[dim] = 1; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index cf9b2b6d1ab..d705ea0bdc5 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -1,32 +1,22 @@ +load( + "//bazel:rules_def.bzl", + "ptxla_cc_test", +) + load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load( + "//bazel:rules_def.bzl", + "ptxla_cc_test", +) + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -cc_library( - name = "async_task", - hdrs = ["async_task.h"], - deps = [ - ":debug_macros", - ":thread_pool", - "@com_google_absl//absl/types:optional", - ], -) - -cc_test( - name = "async_task_test", - size = "small", - srcs = ["async_task_test.cc"], - deps = [ - ":async_task", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "runtime", srcs = [ @@ -61,9 +51,12 @@ cc_library( ":metrics_reader", ":metrics", ":sys_util", + ":tensor_source", ":types", ":util", + ":xla_coordinator", "//torch_xla/csrc:device", + "//torch_xla/csrc:dtype", "@tsl//tsl/platform:stacktrace_handler", "@xla//xla:literal_util", "@xla//xla/client:xla_computation", @@ -89,10 +82,12 @@ cc_library( ":computation_client", ":debug_macros", ":env_vars", - ":multi_wait", + ":profiler", ":stablehlo_helper", + ":tensor_source", ":tf_logging", - ":thread_pool", + ":xla_coordinator", + "//torch_xla/csrc:thread_pool", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla/client:xla_computation", @@ -102,9 +97,11 @@ cc_library( "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt:tfrt_cpu_pjrt_client", "@xla//xla/pjrt:pjrt_c_api_client", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/platform/cloud:gcs_file_system", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) @@ -163,6 +160,18 @@ cc_library( ], ) +cc_library( + name = "xla_coordinator", + srcs = ["xla_coordinator.cc"], + hdrs = ["xla_coordinator.h"], + deps = [ + ":debug_macros", + ":sys_util", + "@xla//xla/pjrt/distributed", + "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", + ], +) + cc_library( name = "metrics", srcs = ["metrics.cc"], @@ -178,27 +187,17 @@ cc_library( ], ) +# Profiler silently fails unless we link these backends cc_library( - name = "multi_wait", - srcs = ["multi_wait.cc"], - hdrs = ["multi_wait.h"], - deps = [ - "@xla//xla:types", - ], -) - -cc_library( - name = "nccl_distributed", - srcs = ["nccl_distributed.cc"], - hdrs = ["nccl_distributed.h"], - deps = [ - ":debug_macros", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@xla//xla:types", - ] + if_cuda_is_configured([ - "@local_config_nccl//:nccl", - ]), + name = "profiler_backends", + visibility = ["//visibility:private"], + deps = [ + "@xla//xla/backends/profiler/cpu:host_tracer", + "@xla//xla/backends/profiler/cpu:metadata_collector", + ] + if_cuda_is_configured([ + "@xla//xla/backends/profiler/gpu:device_tracer", + ]), + alwayslink = True, ) cc_library( @@ -206,14 +205,17 @@ cc_library( srcs = ["profiler.cc"], hdrs = ["profiler.h"], deps = [ + ":debug_macros", + ":profiler_backends", + "@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs", + "@xla//xla/backends/profiler/plugin:plugin_tracer", + "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "@tsl//tsl/platform:status", + "@tsl//tsl/profiler/lib:profiler_factory", "@tsl//tsl/profiler/rpc:profiler_server_impl", "@tsl//tsl/profiler/rpc/client:capture_profile", "@com_google_absl//absl/container:flat_hash_map", - # Profiler silently fails unless we include this - "@xla//xla/backends/profiler:profiler_backends", - # TODO: We get missing symbol errors without these deps. Why aren't they # included transitively from TensorFlow/TSL? "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl", @@ -269,13 +271,14 @@ cc_library( ) cc_library( - name = "thread_pool", - srcs = ["thread_pool.cc"], - hdrs = ["thread_pool.h"], - deps = [ - ":metrics", - ":tf_logging", - ], + name = "tensor_source", + hdrs = ["tensor_source.h"], + deps = [ + ":debug_macros", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@torch//:headers", + ] ) cc_library( @@ -288,18 +291,8 @@ cc_library( ], ) -cc_library( - name = "unique", - hdrs = ["unique.h"], - deps = [ - ":debug_macros", - "@com_google_absl//absl/types:optional", - ], -) - cc_library( name = "util", - srcs = ["util.cc"], hdrs = ["util.h"], deps = [ ":types", @@ -342,10 +335,11 @@ cc_library( "@xla//xla/service:platform_util", "@xla//xla/service/spmd:spmd_partitioner", "@tsl//tsl/platform:errors", + "@torch//:headers", ], ) -cc_test( +ptxla_cc_test( name = "xla_util_test", size = "small", srcs = ["xla_util_test.cc"], @@ -362,25 +356,27 @@ cc_test( ], ) -# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream. -# xla_cc_test( -# name = "pjrt_computation_client_test", -# srcs = ["pjrt_computation_client_test.cc"], -# deps = [ -# ":computation_client", -# "@xla//xla:literal", -# "@xla//xla:literal_util", -# "@xla//xla:shape_util", -# "@xla//xla:status", -# "@xla//xla:statusor", -# "@xla//xla/client:xla_builder", -# "@xla//xla/client:xla_computation", -# "@xla//xla/tests:literal_test_util", -# "@xla//xla/tools:hlo_module_loader", -# "@org_tensorflow//tensorflow/core/platform:logging", -# "@tsl//tsl/lib/core:status_test_util", -# "@tsl//tsl/platform:env", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_main", -# ], -# ) +ptxla_cc_test( + name = "pjrt_computation_client_test", + srcs = ["pjrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":pjrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/torch_xla/csrc/runtime/async_task.h b/torch_xla/csrc/runtime/async_task.h deleted file mode 100644 index 73d923e0eb2..00000000000 --- a/torch_xla/csrc/runtime/async_task.h +++ /dev/null @@ -1,93 +0,0 @@ -#ifndef XLA_CLIENT_ASYNC_TASK_H_ -#define XLA_CLIENT_ASYNC_TASK_H_ - -#include -#include -#include -#include -#include - -#include "absl/types/optional.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/thread_pool.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -template -class AsyncTask { - struct Data { - Data(std::function taskfn) : taskfn(std::move(taskfn)) {} - - std::function taskfn; - std::mutex mutex; - std::condition_variable cv; - bool scheduled = false; - bool completed = false; - absl::optional result; - std::exception_ptr exptr; - }; - - public: - explicit AsyncTask(std::function taskfn) - : data_(std::make_shared(std::move(taskfn))) {} - - AsyncTask& Wait() { - std::unique_lock lock(data_->mutex); - XLA_CHECK(data_->scheduled); - data_->cv.wait(lock, [this] { return data_->completed; }); - if (data_->exptr != nullptr) { - std::rethrow_exception(data_->exptr); - } - return *this; - } - - AsyncTask& Schedule() { - auto completer = [data = data_]() { - absl::optional result; - std::exception_ptr exptr; - try { - result = data->taskfn(); - } catch (...) { - exptr = std::current_exception(); - } - - std::lock_guard lock(data->mutex); - if (result) { - data->result = std::move(*result); - } else { - data->exptr = std::move(exptr); - } - data->completed = true; - data->cv.notify_all(); - }; - - { - std::lock_guard lock(data_->mutex); - XLA_CHECK(!data_->scheduled); - data_->scheduled = true; - } - torch_xla::runtime::env::ScheduleIoClosure(std::move(completer)); - return *this; - } - - const T& GetValue() const { - std::lock_guard lock(data_->mutex); - return *data_->result; - } - - T ConsumeValue() { - std::lock_guard lock(data_->mutex); - return std::move(*data_->result); - } - - private: - std::shared_ptr data_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_ASYNC_TASK_H_ diff --git a/torch_xla/csrc/runtime/async_task_test.cc b/torch_xla/csrc/runtime/async_task_test.cc deleted file mode 100644 index 9b7a98c5a1f..00000000000 --- a/torch_xla/csrc/runtime/async_task_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "torch_xla/csrc/runtime/async_task.h" - -#include - -#include - -namespace torch_xla { -namespace runtime { - -TEST(AsyncTaskTest, BaseTest) { - auto taskfn = []() -> int { return 17; }; - - torch_xla::runtime::util::AsyncTask async(std::move(taskfn)); - async.Schedule(); - async.Wait(); - EXPECT_EQ(async.GetValue(), 17); -} - -TEST(AsyncTaskTest, ExceptionTest) { - auto taskfn = []() -> int { throw std::runtime_error("Task Exception"); }; - - torch_xla::runtime::util::AsyncTask async(std::move(taskfn)); - async.Schedule(); - bool got_exception = false; - try { - async.Wait(); - } catch (const std::exception&) { - got_exception = true; - } - EXPECT_TRUE(got_exception); -} - -TEST(AsyncTaskTest, NoResultCopyTest) { - struct Result { - Result(int* counter) : counter(counter) {} - Result(const Result& ref) : counter(ref.counter) { ++(*counter); } - - Result& operator=(const Result& ref) { - if (this != &ref) { - counter = ref.counter; - ++(*counter); - } - return *this; - } - - Result(Result&&) = default; - Result& operator=(Result&&) = default; - - int* counter = nullptr; - }; - - int copy_counter = 0; - auto taskfn = [&]() -> Result { return Result(©_counter); }; - - torch_xla::runtime::util::AsyncTask async(std::move(taskfn)); - async.Schedule(); - async.Wait(); - - Result result = async.ConsumeValue(); - EXPECT_EQ(copy_counter, 0); - EXPECT_EQ(result.counter, ©_counter); -} - -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index a09768c6a9e..9af461bfd56 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -1,6 +1,7 @@ #ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_ #define XLA_CLIENT_COMPUTATION_CLIENT_H_ +#include #include #include #include @@ -20,6 +21,7 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/metrics.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" @@ -30,6 +32,12 @@ namespace torch_xla { namespace runtime { +// Forward declare XlaCoordinator to avoid logging macro redefinition from the +// transitively included PJRT header. +// TODO(jonbolin): We need a way to ensure the right macros are included +// regardless of the import order. +class XlaCoordinator; + // Somehow the compiler doesn't allow type that has default member being // used as a default parameter in a method defined in the same scope. // Therefore, ClientExecuteOptions is defined here instead of within @@ -186,25 +194,6 @@ class ComputationClient { using ComputationPtr = std::shared_ptr; - // The TensorSource provides a way for a client to populate a buffer allocated - // by the core computation client code. - struct TensorSource { - // The PopulateFn accepts a dense buffer is standard array layout - // (dim0-major) and deposits the source tensor data directly over the - // provided buffer. - using PopulateFn = std::function; - - TensorSource() = default; - TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn) - : shape(std::move(shape)), - device(std::move(device)), - populate_fn(std::move(populate_fn)) {} - - xla::Shape shape; - std::string device; - PopulateFn populate_fn; - }; - // TODO(wcromar): Should CompileInstance still exist? Should it be a subclass // of torch::lazy::Computation? struct CompileInstance { @@ -269,19 +258,22 @@ class ComputationClient { // Transfers local tensor values to the TPU devices and fetches the handles. virtual std::vector TransferToServer( - absl::Span tensors) = 0; + absl::Span> tensors) = 0; // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. virtual DataPtr TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) = 0; + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) = 0; // Copies `data->buffer` to `dst` device buffer. virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0; // Reads the tensor literal values stored at TPU server sites, behind the // supplied handles. + // Note: `TransferFromServer` call will block until the `DataPtrs` are ready + // if they were created by `TransferToServer` or `Execute*`. Calling this from + // python while holding the GIL can cause deadlocks! virtual std::vector TransferFromServer( absl::Span handles) = 0; @@ -329,7 +321,7 @@ class ComputationClient { virtual int GetNumProcesses() const = 0; using DeviceAttribute = - std::variant, float>; + std::variant, float, bool>; virtual const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& @@ -344,12 +336,21 @@ class ComputationClient { virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0; - virtual void PrepareToExit() = 0; - // Block until pass in devices' async operation are finished. If empty, all // the local devices will be waited for. virtual void WaitDeviceOps(const std::vector& devices) = 0; + // Check whether the XlaCoordinator has been initialized. + virtual bool CoordinatorInitialized() const = 0; + + // Initialize the XlaCoordinator for the runtime. + virtual void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) = 0; + + // Return the XlaCoordinator for the runtime. + virtual XlaCoordinator& GetCoordinator() = 0; + // Utility API around the vector based Compile() API to compile a single // computation. ComputationPtr Compile(xla::XlaComputation computation, diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index 42040a9cca5..733574a4818 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -14,11 +14,15 @@ const char* const kEnvPjRtTpuMaxInflightComputations = const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT"; const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH"; +const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH"; const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH"; const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH"; const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR"; const char* const kEnvPjRtLocalProcessCount = "PJRT_LOCAL_PROCESS_COUNT"; const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK"; +const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC"; +const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE"; +const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION"; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index e54ba8f72cd..e7e1ef81964 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -24,11 +24,15 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations; extern const char* const kEnvPjrtAsyncCpuClient; extern const char* const kEnvPjrtAsyncGpuClient; extern const char* const kEnvTpuLibraryPath; +extern const char* const kEnvInferredTpuLibraryPath; extern const char* const kEnvXpuLibraryPath; extern const char* const kEnvNeuronLibraryPath; extern const char* const kEnvPjrtDistServiceAddr; extern const char* const kEnvPjRtLocalProcessCount; extern const char* const kEnvPjRtLocalRank; +extern const char* const kEnvPjrtAllocatorCudaAsync; +extern const char* const kEnvPjrtAllocatorPreallocate; +extern const char* const kEnvPjrtAllocatorFraction; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/multi_wait.cc b/torch_xla/csrc/runtime/multi_wait.cc deleted file mode 100644 index c4d0def062b..00000000000 --- a/torch_xla/csrc/runtime/multi_wait.cc +++ /dev/null @@ -1,73 +0,0 @@ -#include "torch_xla/csrc/runtime/multi_wait.h" - -#include -#include - -namespace torch_xla { -namespace runtime { -namespace util { - -void MultiWait::Done() { - bool notify = false; - { - std::lock_guard lock(mutex_); - completed_count_ += 1; - notify = completed_count_ == count_; - } - if (notify) { - cv_.notify_all(); - } -} - -void MultiWait::Wait() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return completed_count_ >= count_; }); - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } -} - -void MultiWait::Wait(double wait_seconds) { - std::unique_lock lock(mutex_); - if (!cv_.wait_for(lock, std::chrono::duration(wait_seconds), - [this] { return completed_count_ >= count_; })) { - throw std::runtime_error("Timeout"); - } - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } -} - -void MultiWait::Reset(size_t count) { - std::lock_guard lock(mutex_); - count_ = count; - completed_count_ = 0; - exptr_ = nullptr; -} - -std::function MultiWait::Completer(std::function func) { - auto completer = [this, func = std::move(func)]() { Complete(func); }; - return completer; -} - -std::function MultiWait::Completer(std::shared_ptr mwait, - std::function func) { - auto completer = [mwait = std::move(mwait), func = std::move(func)]() { - mwait->Complete(func); - }; - return completer; -} - -void MultiWait::Complete(const std::function& func) { - try { - func(); - } catch (...) { - std::lock_guard lock(mutex_); - exptr_ = std::current_exception(); - } - Done(); -} - -} // namespace util -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/multi_wait.h b/torch_xla/csrc/runtime/multi_wait.h deleted file mode 100644 index 9637850d555..00000000000 --- a/torch_xla/csrc/runtime/multi_wait.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef XLA_CLIENT_MULTI_WAIT_H_ -#define XLA_CLIENT_MULTI_WAIT_H_ - -#include -#include -#include -#include - -#include "xla/types.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -// Support waiting for a number of tasks to complete. -class MultiWait { - public: - explicit MultiWait(size_t count) : count_(count) {} - - // Signal the completion of a single task. - void Done(); - - // Waits until at least count (passed as constructor value) completions - // happened. - void Wait(); - - // Same as above, but waits up to wait_seconds. - void Wait(double wait_seconds); - - // Resets the threshold counter for the MultiWait object. The completed count - // is also reset to zero. - void Reset(size_t count); - - // Creates a completer functor which signals the mult wait object once func - // has completed. Handles exceptions by signaling the multi wait with the - // proper status value. This API returns a function which captures a MultiWait - // reference, so care must be taken such that the reference remains valid for - // the whole lifetime of the returned function. - std::function Completer(std::function func); - - // Similar as the above API, but with explicit capture of the MultiWait shared - // pointer. - static std::function Completer(std::shared_ptr mwait, - std::function func); - - private: - void Complete(const std::function& func); - - std::mutex mutex_; - std::condition_variable cv_; - size_t count_ = 0; - size_t completed_count_ = 0; - std::exception_ptr exptr_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_MULTI_WAIT_H_ diff --git a/torch_xla/csrc/runtime/nccl_distributed.cc b/torch_xla/csrc/runtime/nccl_distributed.cc deleted file mode 100644 index 51088913b88..00000000000 --- a/torch_xla/csrc/runtime/nccl_distributed.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "torch_xla/csrc/runtime/nccl_distributed.h" - -#include -#include - -#include "absl/strings/str_join.h" -#include "torch_xla/csrc/runtime/debug_macros.h" -#if XLA_CUDA -#include "third_party/nccl/nccl.h" -#endif - -namespace torch_xla { -namespace runtime { -namespace nccl_detail { - -#if XLA_CUDA - -namespace { - -class NcclUidManager { - public: - static NcclUidManager* Get(); - - std::string GetNcclUniqueUid(absl::Span replicas); - - private: - std::mutex mutex_; - std::map replicas_uid_map_; -}; - -NcclUidManager* NcclUidManager::Get() { - static NcclUidManager* nccl_mgr = new NcclUidManager(); - return nccl_mgr; -} - -std::string NcclUidManager::GetNcclUniqueUid( - absl::Span replicas) { - std::string replicas_str = absl::StrJoin(replicas, ","); - std::lock_guard lock(mutex_); - auto it = replicas_uid_map_.find(replicas_str); - if (it == replicas_uid_map_.end()) { - ncclUniqueId id; - ncclResult_t r = ncclGetUniqueId(&id); - XLA_CHECK_EQ(r, ncclSuccess) - << "NCCL UID generation failed: replicas=(" << replicas_str - << "), error: " << ncclGetErrorString(r); - it = replicas_uid_map_ - .emplace(std::move(replicas_str), - std::string(id.internal, NCCL_UNIQUE_ID_BYTES)) - .first; - } - return it->second; -} - -} // namespace - -std::string GetNcclUniqueUid(absl::Span replicas) { - return NcclUidManager::Get()->GetNcclUniqueUid(replicas); -} - -#else // XLA_CUDA - -std::string GetNcclUniqueUid(absl::Span replicas) { - XLA_ERROR() << "Calling GetNcclUniqueUid() without NCCL configuration"; -} - -#endif // XLA_CUDA - -} // namespace nccl_detail -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/nccl_distributed.h b/torch_xla/csrc/runtime/nccl_distributed.h deleted file mode 100644 index de5e0b0887d..00000000000 --- a/torch_xla/csrc/runtime/nccl_distributed.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef XLA_CLIENT_NCCL_DISTRIBUTED_H_ -#define XLA_CLIENT_NCCL_DISTRIBUTED_H_ - -#include - -#include "absl/types/span.h" -#include "xla/types.h" - -namespace torch_xla { -namespace runtime { -namespace nccl_detail { - -std::string GetNcclUniqueUid(absl::Span replicas); - -} // namespace nccl_detail -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_NCCL_DISTRIBUTED_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 4f175be7d71..9ad731eba82 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -5,20 +5,24 @@ #include #include "absl/strings/ascii.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "pjrt_computation_client.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" -#include "torch_xla/csrc/runtime/multi_wait.h" +#include "torch_xla/csrc/runtime/profiler.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" +#include "torch_xla/csrc/thread_pool.h" #include "tsl/profiler/lib/traceme.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" @@ -37,22 +41,6 @@ namespace { static std::string spmd_device_str = "SPMD:0"; -// Initializes a distributed runtime client if dist_service_addr is specified -std::shared_ptr -MaybeInitializeDistributedRuntimeClient(int local_rank, - std::string dist_service_addr) { - std::shared_ptr client; - if (!dist_service_addr.empty()) { - xla::DistributedRuntimeClient::Options options; - /* TODO(jonbolin): Use global rank for multi-host setup */ - options.node_id = local_rank; - client = xla::GetDistributedRuntimeClient(dist_service_addr, options); - XLA_CHECK(client->Connect().ok()) - << "Failed to initialize distributed runtime client"; - } - return std::move(client); -} - // Builds a map from the device's global ordinal to its index in the `devices` // array. std::unordered_map build_index_map( @@ -77,6 +65,23 @@ xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { return xla::ShapeUtil::DeviceShapeToHostShape(shape); } +xla::GpuAllocatorConfig GetGpuAllocatorConfig() { + auto allocator_config = xla::GpuAllocatorConfig{}; + if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && + sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { + return allocator_config; + } + if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { + allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; + } + allocator_config.preallocate = + sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); + allocator_config.memory_fraction = + sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); + return allocator_config; +} + } // namespace std::string PjRtComputationClient::PjRtDeviceToString( @@ -109,23 +114,39 @@ PjRtComputationClient::PjRtComputationClient() { client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); + // Prefer $TPU_LIBRARY_PATH if set + auto tpu_library_path = sys_util::GetEnvString( + env::kEnvTpuLibraryPath, + sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); - XLA_CHECK(tpu_status.ok()); + XLA_CHECK_OK(tpu_status); client_ = std::move(xla::GetCApiClient("TPU").value()); + const PJRT_Api* c_api = + static_cast(client_.get())->pjrt_c_api(); + profiler::RegisterProfilerForPlugin(c_api); } else if (device_type == "TPU_LEGACY") { XLA_ERROR() << "TPU_LEGACY client is no longer available."; - } else if (device_type == "GPU") { + } else if (device_type == "GPU" || device_type == "CUDA" || + device_type == "ROCM") { TF_VLOG(1) << "Initializing PjRt GPU client..."; bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); - int local_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); - std::string dist_service_addr = - sys_util::GetEnvString(env::kEnvPjrtDistServiceAddr, ""); - auto distributed_client = - MaybeInitializeDistributedRuntimeClient(local_rank, dist_service_addr); + int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); + int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); + int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); + int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); + std::string master_addr = + runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); + std::string port = runtime::sys_util::GetEnvString( + "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); + + // Use the XlaCoordinator as the distributed key-value store. + coordinator_ = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator_->GetClient(); auto allowed_devices = - std::make_optional>(std::set{local_rank}); + std::make_optional>(std::set{local_process_rank}); xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; if (distributed_client != nullptr) { @@ -140,29 +161,32 @@ PjRtComputationClient::PjRtComputationClient() { return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); }; } - client_ = - std::move(xla::GetStreamExecutorGpuClient( - /*asynchronous=*/async, xla::GpuAllocatorConfig{}, - /*node_id=*/local_rank, - /*num_nodes=*/ - sys_util::GetEnvInt(env::kEnvPjRtLocalProcessCount, 1), - /*allowed_devices=*/allowed_devices, - /*platform_name*/ "gpu", - /*should_stage_host_to_device_transfers*/ true, - /*kv_get*/ kv_get, - /*kv_put*/ kv_put) - .value()); + TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" + << global_process_rank << ", num_nodes=" << global_world_size; + client_ = std::move(xla::GetStreamExecutorGpuClient( + /*asynchronous=*/async, + /*allocator_config=*/GetGpuAllocatorConfig(), + /*node_id=*/global_process_rank, + /*num_nodes=*/global_world_size, + /*allowed_devices=*/allowed_devices, + /*platform_name=*/"gpu", + /*should_stage_host_to_device_transfers=*/true, + /*kv_get=*/kv_get, + /*kv_put=*/kv_put) + .value()); } else if (device_type == "XPU") { TF_VLOG(1) << "Initializing PjRt XPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so"))); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) + .status()); client_ = std::move(xla::GetCApiClient("XPU").value()); - } else if (device_type == "NEURON") { TF_VLOG(1) << "Initializing PjRt NEURON client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "NEURON", sys_util::GetEnvString(env::kEnvNeuronLibraryPath, - "libneuronpjrt.so"))); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( + env::kEnvNeuronLibraryPath, + "libneuronpjrt.so")) + .status()); client_ = std::move(xla::GetCApiClient("NEURON").value()); } else { XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice, @@ -188,6 +212,33 @@ PjRtComputationClient::PjRtComputationClient() { device_locks_.emplace(spmd_device_str, std::make_unique()); } +PjRtComputationClient::~PjRtComputationClient() { + // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient + // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. + client_ = nullptr; + coordinator_ = nullptr; +} + +bool PjRtComputationClient::CoordinatorInitialized() const { + return coordinator_ != nullptr; +} + +void PjRtComputationClient::InitializeCoordinator(int global_rank, + int world_size, + std::string master_addr, + std::string port) { + XLA_CHECK(coordinator_ == nullptr) + << "Can only initialize the XlaCoordinator once."; + coordinator_ = std::make_unique(global_rank, world_size, + master_addr, port); +} + +XlaCoordinator& PjRtComputationClient::GetCoordinator() { + XLA_CHECK(coordinator_ != nullptr) + << "XlaCoordinator has not been initialized"; + return *coordinator_; +} + void PjRtComputationClient::PjRtData::Assign( const torch::lazy::BackendData& data) { const PjRtData& pjrt_data = dynamic_cast(data); @@ -259,7 +310,7 @@ std::optional PjRtComputationClient::GetDataSharding( } std::vector PjRtComputationClient::TransferToServer( - absl::Span tensors) { + absl::Span> tensors) { metrics::TimedSection timed(TransferToServerMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -267,31 +318,22 @@ std::vector PjRtComputationClient::TransferToServer( datas.reserve(tensors.size()); int64_t total_size = 0; for (auto& tensor : tensors) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device); - - auto literal = std::make_shared(tensor.shape); - tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes()); - std::vector byte_strides(literal->shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(), - absl::MakeSpan(byte_strides))); - total_size += literal->size_bytes(); - - // Avoid use-after-free on `literal` due to unsequenced move and use. - xla::Literal* literal_pointer = literal.get(); - std::shared_ptr buffer = std::move( - client_ - ->BufferFromHostBuffer( - literal_pointer->untyped_data(), - literal_pointer->shape().element_type(), - literal_pointer->shape().dimensions(), byte_strides, - xla::PjRtClient::HostBufferSemantics:: - kImmutableUntilTransferCompletes, - [literal{std::move(literal)}]() { /* frees literal */ }, - pjrt_device) - .value()); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); + + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); + + std::shared_ptr buffer = + std::move(client_ + ->BufferFromHostBuffer( + tensor->data(), tensor->primitive_type(), + tensor->dimensions(), tensor->byte_strides(), + xla::PjRtClient::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [tensor]() { /* frees tensor */ }, pjrt_device) + .value()); ComputationClient::DataPtr data = - std::make_shared(tensor.device, tensor.shape, buffer); + std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); } OutboundDataMetric()->AddSample(total_size); @@ -301,8 +343,8 @@ std::vector PjRtComputationClient::TransferToServer( } ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) { + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "PjRtComputationClient::TransferShardsToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -578,9 +620,9 @@ PjRtComputationClient::ExecuteComputation( } CreateDataHandlesCounter()->AddValue(datas.size()); - auto mwait = std::make_shared(1); - auto lockfn = [&, this, device, returned_future = std::move(*returned_future), - timed]() mutable { + thread::Schedule(std::move([&, this, device, + returned_future = std::move(*returned_future), + timed]() mutable { TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device; // Grab the shared lock and block the `WaitDeviceOps` until buffer is @@ -601,9 +643,7 @@ PjRtComputationClient::ExecuteComputation( timed.reset(); TF_VLOG(3) << "ExecuteComputation returned_future->OnReady finished"; }); - }; - - env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); + })); TF_VLOG(1) << "Returning " << datas.size() << " results"; return datas; @@ -627,7 +667,7 @@ PjRtComputationClient::ExecuteReplicated( XLA_CHECK(devices.size() == arguments.size()) << "ExecuteReplicated over " << devices.size() << " devices, but " << arguments.size() << " arguments devices."; - auto mwait_argument = std::make_shared(devices.size()); + absl::BlockingCounter counter(devices.size()); std::vector> argument_handles(devices.size()); { tsl::profiler::TraceMe activity( @@ -648,11 +688,11 @@ PjRtComputationClient::ExecuteReplicated( buffers.push_back(pjrt_data->buffer.get()); } argument_handles[i] = std::move(buffers); + counter.DecrementCount(); }; - env::ScheduleIoClosure(util::MultiWait::Completer( - mwait_argument, std::move(buffer_converter))); + thread::Schedule(std::move(buffer_converter)); } - mwait_argument->Wait(); + counter.Wait(); } xla::ExecuteOptions execute_options; @@ -707,9 +747,9 @@ PjRtComputationClient::ExecuteReplicated( } } - auto mwait = std::make_shared(1); - auto lockfn = [&, this, returned_futures = std::move(*returned_futures), - timed]() mutable { + thread::Schedule(std::move([&, this, + returned_futures = std::move(*returned_futures), + timed]() mutable { // Grab the shared lock and block the `WaitDeviceOps` until buffer is // ready. Since this is the SPMD code path. There is no points to grab // devices lock for every individual device. @@ -730,8 +770,7 @@ PjRtComputationClient::ExecuteReplicated( timed.reset(); TF_VLOG(3) << "ExecuteReplicated returned_future->OnReady finished"; }); - }; - env::ScheduleIoClosure(util::MultiWait::Completer(mwait, std::move(lockfn))); + })); TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results " << "with dimensions [" << absl::StrJoin(dims, ",") << "]."; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index d7a11611a03..b66e4ff5097 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -23,6 +23,7 @@ namespace runtime { class PjRtComputationClient : public ComputationClient { public: PjRtComputationClient(); + ~PjRtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; @@ -36,7 +37,7 @@ class PjRtComputationClient : public ComputationClient { std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToServer( - absl::Span tensors) override; + absl::Span> tensors) override; // Use XLA replication to re-assemble the sharded data. DataPtr ReplicateShardedData(const DataPtr& handle); @@ -44,9 +45,9 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferFromServer( absl::Span handles) override; - DataPtr TransferShardsToServer(absl::Span tensor_shards, - std::string device, xla::Shape shape, - xla::OpSharding sharding) override; + DataPtr TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; @@ -85,12 +86,18 @@ class PjRtComputationClient : public ComputationClient { std::shared_ptr> GetReplicationDevices() override; - void PrepareToExit() override { return; }; - void WaitDeviceOps(const std::vector& devices) override; std::map GetMetrics() const override; + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override; + + XlaCoordinator& GetCoordinator() override; + + bool CoordinatorInitialized() const override; + // NOT IMPLEMENTED MemoryInfo GetMemoryInfo(const std::string& device) override { @@ -99,6 +106,7 @@ class PjRtComputationClient : public ComputationClient { private: std::shared_ptr client_; + std::unique_ptr coordinator_; // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. std::unordered_map global_ordinals_; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc index 24cbc4636a6..d6240f08e98 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc @@ -7,6 +7,8 @@ #include #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" @@ -32,17 +34,6 @@ tsl::StatusOr MakeComputation() { return builder.Build(); } -ComputationClient::TensorSource TensorSourceFromLiteral( - const std::string& device, const xla::Literal& literal) { - auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - std::memcpy(dest_buffer, literal.data().data(), - dest_buffer_size * sizeof(literal.data().data())); - }; - return ComputationClient::TensorSource(literal.shape(), device, - std::move(populate_fn)); -} - TEST(PjRtComputationClientTest, Init) { // Get a CPU client. tsl::setenv("PJRT_DEVICE", "CPU", true); @@ -69,9 +60,9 @@ TEST(PjRtComputationClientTest, Init) { // Copy inputs to device. ComputationClient::ExecuteComputationOptions options{}; - std::vector args = { - TensorSourceFromLiteral(device, literal_x), - TensorSourceFromLiteral(device, literal_y)}; + std::vector> args = { + std::make_shared(std::move(literal_x), device), + std::make_shared(std::move(literal_y), device)}; // Execute the graph. std::vector results = client->ExecuteComputation( diff --git a/torch_xla/csrc/runtime/profiler.cc b/torch_xla/csrc/runtime/profiler.cc index 41de76ebd5e..a2ea89be16d 100644 --- a/torch_xla/csrc/runtime/profiler.cc +++ b/torch_xla/csrc/runtime/profiler.cc @@ -1,14 +1,36 @@ #include "torch_xla/csrc/runtime/profiler.h" #include "absl/container/flat_hash_map.h" +#include "torch_xla/csrc/runtime/debug_macros.h" #include "tsl/platform/status.h" +#include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/rpc/client/capture_profile.h" #include "tsl/profiler/rpc/profiler_server.h" +#include "xla/backends/profiler/plugin/plugin_tracer.h" +#include "xla/backends/profiler/plugin/profiler_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" namespace torch_xla { namespace runtime { namespace profiler { +namespace { + +const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) { + const PJRT_Structure_Base* next = + reinterpret_cast(pjrt_api->extension_start); + while (next != nullptr && + next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) { + next = next->next; + } + if (next == nullptr) { + return nullptr; + } + return reinterpret_cast(next)->profiler_api; +} + +} // namespace + struct ProfilerServer::Impl { Impl() : server(new tsl::profiler::ProfilerServer()) {} @@ -33,6 +55,19 @@ tsl::Status Trace( /*include_dataset_ops=*/false, duration_ms, num_tracing_attempts, options); } + +void RegisterProfilerForPlugin(const PJRT_Api* c_api) { + const PLUGIN_Profiler_Api* profiler_api = FindProfilerApi(c_api); + XLA_CHECK(profiler_api); + + tsl::profiler::ProfilerFactory create_func = + [profiler_api](const tensorflow::ProfileOptions& options) { + return std::make_unique(profiler_api, + options); + }; + tsl::profiler::RegisterProfilerFactory(std::move(create_func)); +} + } // namespace profiler } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/profiler.h b/torch_xla/csrc/runtime/profiler.h index 639e6b2a6d1..d5d49540c24 100644 --- a/torch_xla/csrc/runtime/profiler.h +++ b/torch_xla/csrc/runtime/profiler.h @@ -5,6 +5,7 @@ #include "absl/container/flat_hash_map.h" #include "tsl/platform/status.h" +#include "xla/pjrt/c/pjrt_c_api.h" namespace torch_xla { namespace runtime { @@ -28,6 +29,8 @@ tsl::Status Trace( const absl::flat_hash_map>& options); +void RegisterProfilerForPlugin(const PJRT_Api* c_api); + } // namespace profiler } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 8cfd0695184..69e5bb74319 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -10,10 +10,11 @@ namespace torch_xla { namespace runtime { namespace { -std::atomic g_computation_client(nullptr); -std::once_flag g_computation_client_once; +std::atomic g_computation_client_initialized(false); ComputationClient* CreateClient() { + bool was_initialized = g_computation_client_initialized.exchange(true); + XLA_CHECK(!was_initialized) << "ComputationClient already initialized"; if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) { tsl::testing::InstallStacktraceHandler(); } @@ -23,6 +24,7 @@ ComputationClient* CreateClient() { if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { client = new PjRtComputationClient(); } else { + g_computation_client_initialized = false; XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; } @@ -34,13 +36,12 @@ ComputationClient* CreateClient() { } // namespace ComputationClient* GetComputationClient() { - std::call_once(g_computation_client_once, - [&]() { g_computation_client = std::move(CreateClient()); }); - return g_computation_client.load(); + static auto client = std::unique_ptr(CreateClient()); + return client.get(); } ComputationClient* GetComputationClientIfInitialized() { - return g_computation_client.load(); + return g_computation_client_initialized ? GetComputationClient() : nullptr; } } // namespace runtime diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h new file mode 100644 index 00000000000..11d4b2f71a5 --- /dev/null +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -0,0 +1,100 @@ +#ifndef XLA_CLIENT_TENSOR_SOURCE_H_ +#define XLA_CLIENT_TENSOR_SOURCE_H_ + +#include +#include + +#include + +#include "torch_xla/csrc/dtype.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace torch_xla { +namespace runtime { + +// Owns a contiguous block of data with the shape and layout matching `shape()`. +class TensorSource { + public: + TensorSource(std::string device) : device_(std::move(device)){}; + + virtual const void* data() const = 0; + + virtual const xla::Shape& shape() const = 0; + + const std::string& device() const { return device_; } + + virtual std::vector byte_strides() const { + std::vector byte_strides(shape().dimensions_size()); + XLA_CHECK_OK( + xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); + return byte_strides; + } + + virtual std::vector dimensions() const { + auto dimensions = shape().dimensions(); + return {dimensions.begin(), dimensions.end()}; + } + + virtual xla::PrimitiveType primitive_type() const { + return shape().element_type(); + } + + private: + std::string device_; +}; + +class AtenSource : public TensorSource { + public: + AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device) + : TensorSource(std::move(device)), shape_(std::move(shape)) { + at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type()); + if (target_torch_type != tensor.type().scalarType()) { + TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); + tensor_ = std::move(tensor.to(target_torch_type).contiguous()); + } else { + tensor_ = std::move(tensor.contiguous()); + } + } + + const void* data() const override { return tensor_.const_data_ptr(); } + + const xla::Shape& shape() const override { return shape_; } + + std::vector byte_strides() const override { + std::vector strides; + for (auto& stride : tensor_.strides()) { + strides.push_back(stride * tensor_.itemsize()); + } + return strides; + } + + std::vector dimensions() const override { + auto sizes = tensor_.sizes(); + return {sizes.begin(), sizes.end()}; + } + + private: + at::Tensor tensor_; + xla::Shape shape_; +}; + +class LiteralSource : public TensorSource { + public: + LiteralSource(xla::Literal literal, std::string device) + : TensorSource(std::move(device)), literal_(std::move(literal)) {} + + const void* data() const override { return literal_.untyped_data(); } + + const xla::Shape& shape() const override { return literal_.shape(); } + + private: + xla::Literal literal_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_CLIENT_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/runtime/thread_pool.cc b/torch_xla/csrc/runtime/thread_pool.cc deleted file mode 100644 index fa0212e3a26..00000000000 --- a/torch_xla/csrc/runtime/thread_pool.cc +++ /dev/null @@ -1,183 +0,0 @@ -#include "torch_xla/csrc/runtime/thread_pool.h" - -#include -#include -#include -#include - -#include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/tf_logging.h" - -namespace torch_xla { -namespace runtime { -namespace env { -namespace { - -class ThreadPool { - public: - explicit ThreadPool(size_t num_threads) { - threads_.reserve(num_threads); - for (size_t i = 0; i < num_threads; ++i) { - threads_.emplace_back([this]() { Worker(); }); - } - } - - ~ThreadPool() { - { - std::lock_guard lock(mutex_); - exiting_ = true; - cv_.notify_all(); - } - for (auto& thread : threads_) { - thread.join(); - } - } - - void Schedule(std::function closure) { - // If we have more work scheduled than waiting worker threads, just schedule - // it on a separate thread. This prevents tricky thread-pool-size-deadlocks - // caused by an undersized thread pool and closures that end up doing sync - // waits on the pool threads. - bool scheduled = false; - { - std::lock_guard lock(mutex_); - if (work_.size() < waiting_) { - work_.emplace_back(std::move(closure)); - scheduled = true; - } - } - if (scheduled) { - cv_.notify_one(); - } else { - ScheduleOnThread(std::move(closure)); - } - } - - private: - void Worker() { - while (true) { - std::function closure = GetWork(); - if (closure == nullptr) { - break; - } - try { - closure(); - } catch (const std::exception& ex) { - XLA_COUNTER("ThreadPoolException", 1); - TF_LOG(ERROR) << "Exception from running thread pool closure: " - << ex.what(); - } - } - } - - void ScheduleOnThread(std::function closure) { - std::thread thread(std::move(closure)); - thread.detach(); - } - - std::function GetWork() { - std::unique_lock lock(mutex_); - ++waiting_; - cv_.wait(lock, [this] { return exiting_ || !work_.empty(); }); - --waiting_; - if (work_.empty()) { - return nullptr; - } - std::function closure(std::move(work_.front())); - work_.pop_front(); - return closure; - } - - std::vector threads_; - std::mutex mutex_; - std::condition_variable cv_; - bool exiting_ = false; - std::deque> work_; - size_t waiting_ = 0; -}; - -ThreadPool* GetThreadPool() { - static size_t num_threads = sys_util::GetEnvInt( - "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); - static ThreadPool* pool = new ThreadPool(num_threads); - return pool; -} - -ThreadPool* GetIoThreadPool() { - static size_t num_threads = sys_util::GetEnvInt( - "XLA_IO_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); - static ThreadPool* pool = new ThreadPool(num_threads); - return pool; -} - -} // namespace - -class Completion::Data { - public: - void Wait() { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return completed_; }); - if (exptr_ != nullptr) { - std::rethrow_exception(exptr_); - } - } - - static std::function GetCompleter(std::shared_ptr data, - std::function closure) { - auto closure_wrapper = [closure = std::move(closure), data]() { - std::exception_ptr exptr; - try { - closure(); - } catch (...) { - exptr = std::current_exception(); - } - data->Complete(exptr); - }; - return closure_wrapper; - } - - private: - void Complete(std::exception_ptr exptr) { - std::lock_guard lock(mutex_); - exptr_ = std::move(exptr); - completed_ = true; - cv_.notify_all(); - } - - std::mutex mutex_; - std::condition_variable cv_; - bool completed_ = false; - std::exception_ptr exptr_; -}; - -Completion::Completion(std::shared_ptr data) : data_(std::move(data)) {} - -Completion::~Completion() {} - -void Completion::Wait() { data_->Wait(); } - -void ScheduleClosure(std::function closure) { - GetThreadPool()->Schedule(std::move(closure)); -} - -void ScheduleIoClosure(std::function closure) { - GetIoThreadPool()->Schedule(std::move(closure)); -} - -Completion ScheduleClosureWithCompletion(std::function closure) { - auto data = std::make_shared(); - GetThreadPool()->Schedule( - Completion::Data::GetCompleter(data, std::move(closure))); - return Completion(std::move(data)); -} - -Completion ScheduleIoClosureWithCompletion(std::function closure) { - auto data = std::make_shared(); - GetIoThreadPool()->Schedule( - Completion::Data::GetCompleter(data, std::move(closure))); - return Completion(std::move(data)); -} - -} // namespace env -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/thread_pool.h b/torch_xla/csrc/runtime/thread_pool.h deleted file mode 100644 index 072e28594cc..00000000000 --- a/torch_xla/csrc/runtime/thread_pool.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef XLA_CLIENT_THREAD_POOL_H_ -#define XLA_CLIENT_THREAD_POOL_H_ - -#include -#include -#include - -namespace torch_xla { -namespace runtime { -namespace env { - -class Completion { - public: - class Data; - - explicit Completion(std::shared_ptr data); - - ~Completion(); - - void Wait(); - - private: - std::shared_ptr data_; -}; - -// Schedules a closure to be run. The closure should not block waiting for other -// events. -void ScheduleClosure(std::function closure); -Completion ScheduleClosureWithCompletion(std::function closure); - -// Schedules a closure which might wait for IO or other events/conditions. -void ScheduleIoClosure(std::function closure); -Completion ScheduleIoClosureWithCompletion(std::function closure); - -} // namespace env -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_THREAD_POOL_H_ diff --git a/torch_xla/csrc/runtime/types.h b/torch_xla/csrc/runtime/types.h index de71cd6c2ef..a27f1a0e1c2 100644 --- a/torch_xla/csrc/runtime/types.h +++ b/torch_xla/csrc/runtime/types.h @@ -11,8 +11,6 @@ namespace torch_xla { namespace runtime { -using hash_t = absl::uint128; - struct Percentile { enum class UnitOfMeaure { kNumber, diff --git a/torch_xla/csrc/runtime/unique.h b/torch_xla/csrc/runtime/unique.h deleted file mode 100644 index f50e24320d9..00000000000 --- a/torch_xla/csrc/runtime/unique.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef XLA_CLIENT_UNIQUE_H_ -#define XLA_CLIENT_UNIQUE_H_ - -#include -#include - -#include "absl/types/optional.h" -#include "torch_xla/csrc/runtime/debug_macros.h" - -namespace torch_xla { -namespace runtime { -namespace util { - -// Helper class to allow tracking zero or more things, which should be forcibly -// be one only thing. -template > -class Unique { - public: - std::pair set(const T& value) { - if (value_) { - XLA_CHECK(C()(*value_, value)) - << "'" << *value_ << "' vs '" << value << "'"; - return std::pair(false, *value_); - } - value_ = value; - return std::pair(true, *value_); - } - - operator bool() const { return value_.has_value(); } - operator const T&() const { return *value_; } - const T& operator*() const { return *value_; } - const T* operator->() const { return value_.operator->(); } - - std::set AsSet() const { - std::set vset; - if (value_.has_value()) { - vset.insert(*value_); - } - return vset; - } - - private: - absl::optional value_; -}; - -} // namespace util -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_UNIQUE_H_ diff --git a/torch_xla/csrc/runtime/util.cc b/torch_xla/csrc/runtime/util.cc deleted file mode 100644 index caeeb149492..00000000000 --- a/torch_xla/csrc/runtime/util.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "torch_xla/csrc/runtime/util.h" - -#include - -namespace torch_xla { -namespace runtime { -namespace util { -namespace { - -hash_t LoadHash(const uint8_t** data, const uint8_t* top) { - std::ptrdiff_t size = top - (*data); - if (size >= sizeof(hash_t)) { - hash_t v; - std::memcpy(&v, *data, sizeof(v)); - *data += sizeof(hash_t); - return v; - } - - union { - hash_t h; - uint8_t b[sizeof(hash_t)]; - } uval; - uval.h = 0; - std::memcpy(uval.b, *data, size); - *data += size; - return uval.h; -} - -} // namespace - -hash_t HashBlock(const void* data, size_t n, const hash_t& seed) { - const hash_t m = 0xc6a4a7935bd1e995; - const int r = 47; - - const uint8_t* u8_data = reinterpret_cast(data); - const uint8_t* top = u8_data + n; - hash_t h = seed ^ (n * m); - while (u8_data < top) { - hash_t k = LoadHash(&u8_data, top); - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - h ^= h >> r; - h *= m; - h ^= h >> r; - return h; -} - -hash_t DataHash(const void* data, size_t size) { - return HashBlock(data, size, 0xc2b2ae3d27d4eb4f); -} - -size_t StdDataHash(const void* data, size_t size) { - return HashReduce(DataHash(data, size)); -} - -size_t StdHashCombine(uintmax_t a, uintmax_t b) { - return a ^ - (b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)); -} - -hash_t HashCombine(const hash_t& a, const hash_t& b) { - static const hash_t kb = absl::MakeUint128(101, 0x27d4eb2f165667c5); - return a ^ (b * kb + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)); -} - -size_t HashReduce(const hash_t& a) { - return StdHashCombine(absl::Uint128Low64(a), absl::Uint128High64(a)); -} - -std::string HexHash(const hash_t& a) { - std::stringstream ss; - ss << std::hex << a; - return ss.str(); -} - -} // namespace util -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/util.h b/torch_xla/csrc/runtime/util.h index c4d05593d84..722a6591f78 100644 --- a/torch_xla/csrc/runtime/util.h +++ b/torch_xla/csrc/runtime/util.h @@ -24,75 +24,6 @@ namespace torch_xla { namespace runtime { namespace util { -hash_t HashBlock(const void* data, size_t n, const hash_t& seed); - -hash_t DataHash(const void* data, size_t size); - -size_t StdDataHash(const void* data, size_t size); - -size_t StdHashCombine(uintmax_t a, uintmax_t b); - -hash_t HashCombine(const hash_t& a, const hash_t& b); - -size_t HashReduce(const hash_t& a); - -std::string HexHash(const hash_t& a); - -struct HashReducer { - size_t operator()(const hash_t& value) const { return HashReduce(value); } -}; - -template -xla::Status CheckedCall(const F& fn) { - try { - fn(); - } catch (const std::exception& ex) { - return tsl::errors::Internal(ex.what()); - } - return xla::Status(); -} - -template -class Cleanup { - public: - using StatusType = T; - - explicit Cleanup(std::function func) - : func_(std::move(func)) {} - Cleanup(Cleanup&& ref) - : func_(std::move(ref.func_)), status_(std::move(ref.status_)) {} - Cleanup(const Cleanup&) = delete; - - ~Cleanup() { - if (func_ != nullptr) { - func_(std::move(status_)); - } - } - - Cleanup& operator=(const Cleanup&) = delete; - - Cleanup& operator=(Cleanup&& ref) { - if (this != &ref) { - func_ = std::move(ref.func_); - status_ = std::move(ref.status_); - } - return *this; - } - - void Release() { func_ = nullptr; } - - void SetStatus(StatusType status) { status_ = std::move(status); } - - const StatusType& GetStatus() const { return status_; } - - private: - std::function func_; - StatusType status_; -}; - -using ExceptionCleanup = Cleanup; -using StatusCleanup = Cleanup; - // Allows APIs which might return const references and values, to not be forced // to return values in the signature. template @@ -114,10 +45,6 @@ class MaybeRef { const T& ref_; }; -struct MidPolicy { - size_t operator()(size_t size) const { return size / 2; } -}; - template class MaybePtr { public: @@ -139,70 +66,6 @@ class MaybePtr { absl::optional storage_; }; -// Hasher for string-like objects which hashes only a partial window of the data -// of size N. The P (policy) type is a functor which returns the position of the -// window. -template -struct PartialHasher { - size_t operator()(const T& data) const { - size_t pos = policy(data.size()); - size_t end = pos + N; - if (end > data.size()) { - end = data.size(); - if (N > data.size()) { - pos = 0; - } else { - pos = end - N; - } - } - return tsl::Hash64(data.data() + pos, end - pos, 17); - } - - P policy; -}; - -template -std::vector GetConstSharedPointers( - const C& shared_pointers) { - std::vector pointers; - pointers.reserve(shared_pointers.size()); - for (auto& shared_pointer : shared_pointers) { - pointers.push_back(shared_pointer.get()); - } - return pointers; -} - -template -std::vector GetSharedPointers( - const C& shared_pointers) { - std::vector pointers; - pointers.reserve(shared_pointers.size()); - for (auto& shared_pointer : shared_pointers) { - pointers.push_back(shared_pointer.get()); - } - return pointers; -} - -template -void InsertCombined(C* map, const K& key, const T& value, const F& combiner) { - auto it = map->find(key); - if (it == map->end()) { - map->emplace(key, value); - } else { - it->second = combiner(it->second, value); - } -} - -template -std::vector Iota(size_t size, T init = 0, T incr = 1) { - std::vector result(size); - T value = init; - for (size_t i = 0; i < size; ++i, value += incr) { - result[i] = value; - } - return result; -} - template std::vector Range(T start, T end, T step = 1) { std::vector result; @@ -260,76 +123,12 @@ const typename T::mapped_type& MapInsert(T* cont, return it->second; } -template -typename std::underlying_type::type GetEnumValue(T value) { - return static_cast::type>(value); -} - template T Multiply(const S& input) { return std::accumulate(input.begin(), input.end(), T(1), std::multiplies()); } -static inline hash_t StringHash(const char* data) { - return DataHash(data, std::strlen(data)); -} - -template ::value>::type* = nullptr> -hash_t Hash(const T& value) { - return DataHash(&value, sizeof(value)); -} - -static inline hash_t Hash(const std::string& value) { - return DataHash(value.data(), value.size()); -} - -// Forward declare to allow hashes of vectors of vectors to work. -template -hash_t ContainerHash(const T& values); - -template -hash_t Hash(absl::Span values) { - return ContainerHash(values); -} - -template -hash_t Hash(const std::vector& values) { - return ContainerHash(values); -} - -template -hash_t Hash(const std::set& values) { - return ContainerHash(values); -} - -template -hash_t Hash(const std::pair& values) { - return HashCombine(Hash(values.first), Hash(values.second)); -} - -static inline hash_t Hash(const hash_t& value) { return value; } - -template -hash_t ContainerHash(const T& values) { - hash_t h = 0x85ebca77c2b2ae63; - for (auto& value : values) { - h = HashCombine(h, Hash(value)); - } - return h; -} - -template -hash_t MHash() { - return 0x165667b19e3779f9; -} - -template -hash_t MHash(T value, Targs... Fargs) { - return HashCombine(Hash(value), MHash(Fargs...)); -} - } // namespace util } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/util_test.cc b/torch_xla/csrc/runtime/util_test.cc index 125e8ba8bef..f65eea30ca7 100644 --- a/torch_xla/csrc/runtime/util_test.cc +++ b/torch_xla/csrc/runtime/util_test.cc @@ -15,36 +15,6 @@ namespace util { using ::testing::ElementsAre; -TEST(UtilTest, Cleanup) { - bool notify = false; - - // Set to true. - { - Cleanup c([¬ify](bool b) { notify = b; }); - c.SetStatus(true); - } - EXPECT_TRUE(notify); - - // Set to false. - { - Cleanup c([¬ify](bool b) { notify = b; }); - c.SetStatus(false); - } - EXPECT_FALSE(notify); - - // Releasing the cleanup will not change the `notify` to true. - { - Cleanup c([¬ify](bool b) { notify = b; }); - c.SetStatus(true); - c.Release(); - } - EXPECT_FALSE(notify); -} - -TEST(UtilTest, Iota) { - EXPECT_THAT(Iota(5, 0, 2), ElementsAre(0, 2, 4, 6, 8)); -} - TEST(UtilTest, Range) { EXPECT_THAT(Range(0, 10, 2), ElementsAre(0, 2, 4, 6, 8)); EXPECT_THAT(Range(10, 0, -2), ElementsAre(10, 8, 6, 4, 2)); @@ -75,14 +45,6 @@ TEST(UtilTest, MapInsert) { EXPECT_EQ(MapInsert(&v, 1, [] { return 12; }), 1); } -TEST(UtilTest, GetEnumValue) { - enum E { A = 0, B, C, D }; - EXPECT_EQ(GetEnumValue(E::A), 0); - EXPECT_EQ(GetEnumValue(E::B), 1); - EXPECT_EQ(GetEnumValue(E::C), 2); - EXPECT_EQ(GetEnumValue(E::D), 3); -} - TEST(UtilTest, Multiply) { std::vector t = {1, 2, 3, 4, 5}; EXPECT_EQ(Multiply(t), 120); @@ -90,21 +52,6 @@ TEST(UtilTest, Multiply) { EXPECT_EQ(Multiply(t), 720); } -TEST(UtilTest, Hash) { - std::pair temp = {"hello", 3}; - EXPECT_EQ(Hash(std::pair{"hello", 3}), Hash(temp)); - EXPECT_EQ(HexHash(Hash(std::pair{"hello", 3})), - HexHash(Hash(temp))); - - std::vector t = {1, 2, 3, 4, 5}; - EXPECT_EQ(Hash({1, 2, 3, 4, 5}), Hash({1, 2, 3, 4, 5})); - EXPECT_EQ(Hash(std::set{1, 2, 3}), Hash(std::set{1, 2, 3})); - EXPECT_EQ(Hash(t), Hash(std::vector{1, 2, 3, 4, 5})); - - EXPECT_EQ(StdDataHash(t.data(), t.size()), - StdDataHash(std::vector{1, 2, 3, 4, 5}.data(), t.size())); -} - TEST(UtilTest, MaybeRef) { using StringRef = torch_xla::runtime::util::MaybeRef; std::string storage("String storage"); diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc b/torch_xla/csrc/runtime/xla_coordinator.cc new file mode 100644 index 00000000000..72855d8681e --- /dev/null +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -0,0 +1,73 @@ +#include "torch_xla/csrc/runtime/xla_coordinator.h" + +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/pjrt/distributed/distributed.h" + +namespace torch_xla { +namespace runtime { + +XlaCoordinator::XlaCoordinator(int global_rank, int world_size, + std::string master_addr, std::string port) { + std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":"); + if (global_rank == 0) { + xla::CoordinationServiceImpl::Options service_options; + service_options.num_nodes = world_size; + xla::StatusOr> + dist_runtime_service = xla::GetDistributedRuntimeService( + dist_service_addr, service_options); + XLA_CHECK(dist_runtime_service.ok()) + << "Failed to initialize distributed runtime service."; + dist_runtime_service_ = std::move(dist_runtime_service.value()); + } + + xla::DistributedRuntimeClient::Options client_options; + client_options.node_id = global_rank; + dist_runtime_client_ = + xla::GetDistributedRuntimeClient(dist_service_addr, client_options); + XLA_CHECK(dist_runtime_client_->Connect().ok()) + << "Failed to initialize distributed runtime client"; +} + +XlaCoordinator::~XlaCoordinator() { + preemption_sync_manager_ = nullptr; + if (dist_runtime_client_ != nullptr) { + XLA_CHECK(dist_runtime_client_->Shutdown().ok()) + << "Failed to shut down the distributed runtime client."; + dist_runtime_client_ = nullptr; + } + if (dist_runtime_service_ != nullptr) { + dist_runtime_service_->Shutdown(); + dist_runtime_service_ = nullptr; + } +} + +std::shared_ptr XlaCoordinator::GetClient() { + XLA_CHECK(dist_runtime_client_ != nullptr) + << "distributed runtime client is null."; + return dist_runtime_client_; +} + +void XlaCoordinator::ActivatePreemptionSyncManager() { + if (preemption_sync_manager_ == nullptr) { + preemption_sync_manager_ = std::move(tsl::CreatePreemptionSyncManager()); + auto client = dist_runtime_client_->GetCoordinationServiceAgent(); + XLA_CHECK(client.ok()) << "Failed to retrieve the CoodinationServiceAgent"; + auto status = preemption_sync_manager_->Initialize(client.value()); + XLA_CHECK(status.ok()) << "Failed to initialize the PreemptionSyncManager"; + } +} + +void XlaCoordinator::DeactivatePreemptionSyncManager() { + preemption_sync_manager_ = nullptr; +} + +bool XlaCoordinator::ReachedSyncPoint(int step) { + XLA_CHECK(preemption_sync_manager_ != nullptr) + << "A PreemptionSyncManager has not been registered with the " + "XlaCoordinator."; + return preemption_sync_manager_->ReachedSyncPoint(step); +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h new file mode 100644 index 00000000000..ae85c79a941 --- /dev/null +++ b/torch_xla/csrc/runtime/xla_coordinator.h @@ -0,0 +1,53 @@ +#ifndef PTXLA_RUNTIME_COORDINATOR_H_ +#define PTXLA_RUNTIME_COORDINATOR_H_ + +#include + +#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/pjrt/distributed/distributed.h" + +namespace torch_xla { +namespace runtime { + +// XlaCoordinator serves as the point of entry for all operations which +// required the XLA distributed runtime, such as preemption coordination. +class XlaCoordinator { + public: + static inline const std::string kDefaultCoordinatorPort = "8547"; + + XlaCoordinator(int global_rank, int world_size, std::string master_addr, + std::string port); + + ~XlaCoordinator(); + + // Retrieve the DistributedRuntimeClient. + std::shared_ptr GetClient(); + + // Register a PreemptionSyncManager for the distributed runtime if none is + // active. The PreemptionSyncManager will register a SIGTERM handler, and + // when any host has received a preemption notice, all hosts are made aware + // through the ReachedSyncPoint API. See the documentation of + // tsl::PreemptionSyncManager for the full semantics: + // https://github.com/google/tsl/blob/3bbe663/tsl/distributed_runtime/preemption/preemption_sync_manager.h#L34 + void ActivatePreemptionSyncManager(); + + // If the PreemptionSyncManager is active, this will deactivate it and + // destroy the current instance. + void DeactivatePreemptionSyncManager(); + + // A pass-through API to PreemptionSyncManager::ReachedSyncPoint. + // The PreemptionSyncManager must be activated within the XlaCoordinator. + // Returns true when the input step has been identified as a sync point, and + // false otherwise. + bool ReachedSyncPoint(int step); + + private: + std::unique_ptr dist_runtime_service_; + std::shared_ptr dist_runtime_client_; + std::unique_ptr preemption_sync_manager_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif // PTXLA_RUNTIME_COORDINATOR_H_ diff --git a/torch_xla/csrc/runtime/xla_util.cc b/torch_xla/csrc/runtime/xla_util.cc index e591198bf7e..5eb9e009128 100644 --- a/torch_xla/csrc/runtime/xla_util.cc +++ b/torch_xla/csrc/runtime/xla_util.cc @@ -1,5 +1,7 @@ #include "torch_xla/csrc/runtime/xla_util.h" +#include + #include #include #include @@ -19,16 +21,17 @@ namespace runtime { namespace util { namespace { -hash_t SingleShapeHash(const xla::Shape& shape, hash_t seed) { +torch::lazy::hash_t SingleShapeHash(const xla::Shape& shape, + torch::lazy::hash_t seed) { if (shape.has_layout()) { for (auto dim : shape.layout().minor_to_major()) { - seed = HashCombine(seed, dim); + seed = torch::lazy::HashCombine(seed, dim); } } for (auto dim : shape.dimensions()) { - seed = HashCombine(seed, dim); + seed = torch::lazy::HashCombine(seed, dim); } - return HashCombine(seed, static_cast(shape.element_type())); + return torch::lazy::HashCombine(seed, static_cast(shape.element_type())); } void MaybeSaveHloGraph(const std::string& hlo_text, size_t index) { @@ -103,8 +106,8 @@ void CheckComputationStatus( } } -hash_t ShapeHash(const xla::Shape& shape) { - hash_t hash = 0xa5d2d6916; +torch::lazy::hash_t ShapeHash(const xla::Shape& shape) { + torch::lazy::hash_t hash = 0xa5d2d6916; xla::ShapeUtil::ForEachSubshape( shape, [&](const xla::Shape& subshape, const xla::ShapeIndex&) { hash = SingleShapeHash(subshape, hash); diff --git a/torch_xla/csrc/runtime/xla_util.h b/torch_xla/csrc/runtime/xla_util.h index 32b76f69eb9..3163d5ba8c4 100644 --- a/torch_xla/csrc/runtime/xla_util.h +++ b/torch_xla/csrc/runtime/xla_util.h @@ -1,6 +1,8 @@ #ifndef XLA_CLIENT_XLA_UTIL_H_ #define XLA_CLIENT_XLA_UTIL_H_ +#include + #include #include "absl/types/span.h" @@ -35,7 +37,7 @@ void CheckComputationStatus( absl::Span computations, absl::Span output_shapes); -hash_t ShapeHash(const xla::Shape& shape); +torch::lazy::hash_t ShapeHash(const xla::Shape& shape); } // namespace util } // namespace runtime diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 065f7db8a33..96465abf44c 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -19,7 +19,9 @@ #include #include +#include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/debug_util.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/ops/arithmetic_ir_ops.h" @@ -36,8 +38,6 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" @@ -139,9 +139,9 @@ XLATensor::XLATensor(std::shared_ptr view, XLATensor::XLATensor(std::shared_ptr data) : torch::lazy::LazyTensor(data), data_(std::move(data)), - storage_(c10::Storage( - {}, 0, - c10::DataPtr(nullptr, backendDeviceToAtenDevice(data_->device)))) {} + storage_(c10::Storage({}, 0, + c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice( + data_->device)))) {} auto XLATensor::data() const -> const std::shared_ptr& { XLA_CHECK(data_ != nullptr) << "Trying to access a null cursor"; @@ -158,7 +158,7 @@ int64_t XLATensor::size(int64_t dim) const { at::ScalarType XLATensor::dtype() const { return data()->logical_element_type ? *data()->logical_element_type - : TensorTypeFromXlaType(shape().get().element_type()); + : MaybeUpcastToHostTorchType(shape().get().element_type()); } c10::optional XLATensor::dtype_optional() const { @@ -640,7 +640,7 @@ c10::SymNode XLASymNodeImpl::add(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::sub(const c10::SymNode& other) { - TORCH_LAZY_FN_COUNTER("xla::size_"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_"); torch_xla::XLASymNodeImpl* p_other = dynamic_cast(other.get()); @@ -679,7 +679,7 @@ c10::SymNode XLASymNodeImpl::floordiv(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::mod(const c10::SymNode& other) { - TORCH_LAZY_FN_COUNTER("xla::size_"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_"); torch_xla::XLASymNodeImpl* p_other = dynamic_cast(other.get()); XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI"; @@ -698,7 +698,7 @@ c10::SymNode XLASymNodeImpl::eq(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::ne(const c10::SymNode& other) { - TORCH_LAZY_FN_COUNTER("xla::size_"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_"); auto p_other = dynamic_cast(other.get()); XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI"; XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI"; @@ -712,7 +712,7 @@ c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) { - TORCH_LAZY_FN_COUNTER("xla::size_"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_"); auto p_other = dynamic_cast(other.get()); XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI"; XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI"; @@ -726,7 +726,7 @@ c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) { } c10::SymNode XLASymNodeImpl::ge(const c10::SymNode& other) { - TORCH_LAZY_FN_COUNTER("xla::size_"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_"); auto p_other = dynamic_cast(other.get()); XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI"; XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI"; @@ -813,7 +813,7 @@ c10::SymNode XLASymNodeImpl::is_non_overlapping_and_dense( } c10::SymNode XLASymNodeImpl::clone() { - TORCH_LAZY_FN_COUNTER("xla::size_"); + TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_"); return c10::make_intrusive(node(), pytype_); } @@ -891,4 +891,9 @@ int64_t XLATensor::GetHandle() const { } } +void XLATensor::MarkDynamicDimension(uint32_t dim) { + auto* xla_node = dynamic_cast(GetIrValue().node.get()); + xla_node->MarkDynamicDimension(dim); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 8564729bb71..f73aed5ce5f 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -201,6 +201,7 @@ class XLATensor : public torch::lazy::LazyTensor { // Set logical_element_type which is visible to upstream PyTorch. void SetScalarType(c10::optional logical_element_type); + void MarkDynamicDimension(uint32_t dim); // We don't use the upstream shape to provide xla::shape. runtime::util::MaybeRef shape() const; diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 04bd60ce9a0..6322f052265 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -75,7 +75,9 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor) // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU // so we must manually update Autocast to AutocastCUDA on XLA:GPU. torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice(); - if (static_cast(current_device.type()) == XlaDeviceType::GPU) { + auto dev_type = static_cast(current_device.type()); + if (dev_type == XlaDeviceType::GPU || dev_type == XlaDeviceType::CUDA || + dev_type == XlaDeviceType::ROCM) { auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA); auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA); key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks; diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 6f2cfbf6b8e..20890d0be37 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -13,6 +13,7 @@ #include "torch_xla/csrc/LazyIr.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/data_ops.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/lowering_context.h" @@ -165,7 +166,7 @@ MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor, const c10::optional& max) { XLA_CHECK(min || max) << "At least one of \'min\' or \'max\' must not be None"; - xla::PrimitiveType raw_element_type = TensorTypeToRawXlaType(tensor->dtype()); + xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(tensor->dtype()); XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type); auto shape = tensor->shape(); return {XLAGraphExecutor::Get()->GetIrValueForScalar( @@ -525,7 +526,9 @@ void adam_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step, param->SetInPlaceIrValue(torch::lazy::Value(node, 1)); exp_avg->SetInPlaceIrValue(torch::lazy::Value(node, 2)); exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 3)); - max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4)); + if (amsgrad) { + max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4)); + } } std::vector user_computation( @@ -1025,9 +1028,9 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other, bool input_is_float = xla::primitive_util::IsFloatingPointType(input_type); bool other_is_float = xla::primitive_util::IsFloatingPointType(other_type); if (input_is_float && !other_is_float) { - scalar_type = TensorTypeFromXlaType(input_type); + scalar_type = MaybeUpcastToHostTorchType(input_type); } else if (!input_is_float && other_is_float) { - scalar_type = TensorTypeFromXlaType(other_type); + scalar_type = MaybeUpcastToHostTorchType(other_type); } // We need to cast both input and other to float to perform true divide, floor // divide and trunc divide. @@ -1072,7 +1075,7 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) { xla::PrimitiveType input_type = input->shape().get().element_type(); bool input_is_float = xla::primitive_util::IsFloatingPointType(input_type); if (input_is_float) { - scalar_type = TensorTypeFromXlaType(input_type); + scalar_type = MaybeUpcastToHostTorchType(input_type); } torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type); torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar( @@ -1180,8 +1183,8 @@ XLATensorPtr eye(int64_t lines, int64_t cols, void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols) { out->SetIrValue( Identity(lines, cols >= 0 ? cols : lines, - GetDevicePrimitiveType(out->shape().get().element_type(), - &out->GetDevice()))); + MaybeDowncastToXlaDeviceType(out->shape().get().element_type(), + out->GetDevice()))); } void fill_(XLATensorPtr& input, const at::Scalar& value) { @@ -2054,8 +2057,9 @@ XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) { torch::lazy::NodePtr pow_node = Pow(input_node, exponent->GetIrValue()); at::ScalarType input_dtype = GetScalarType(input); at::ScalarType exp_dtype = exponent->dtype(); - at::ScalarType promoted_dtype = TensorTypeFromXlaType(XlaHelpers::PromoteType( - TensorTypeToRawXlaType(input_dtype), TensorTypeToRawXlaType(exp_dtype))); + at::ScalarType promoted_dtype = + MaybeUpcastToHostTorchType(XlaHelpers::PromoteType( + XlaTypeFromTorchType(input_dtype), XlaTypeFromTorchType(exp_dtype))); return exponent->CreateFrom(pow_node, promoted_dtype); } @@ -2063,6 +2067,15 @@ XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight) { return input->CreateFrom(Prelu(input->GetIrValue(), weight->GetIrValue())); } +std::tuple prelu_backward( + const XLATensorPtr& grad, const XLATensorPtr& input, + const XLATensorPtr& weight) { + torch::lazy::NodePtr node = PreluBackward( + grad->GetIrValue(), input->GetIrValue(), weight->GetIrValue()); + return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)), + input->CreateFrom(torch::lazy::Value(node, 1))); +} + XLATensorPtr prod(const XLATensorPtr& input, std::vector dimensions, bool keep_reduced_dimensions, c10::optional dtype) { @@ -2460,16 +2473,17 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector dims) { std::vector input_dimensions = torch_xla::runtime::util::ToVector( input_shape.get().dimensions()); - std::vector output_dimensions; + std::vector squeeze_dims; for (int64_t dim : dims) { - if (dim >= input_dimensions.size()) { - continue; - } int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_dimensions.size()); - output_dimensions = BuildSqueezedDimensions(input_dimensions, squeeze_dim); - input_dimensions = output_dimensions; + if (squeeze_dim >= input_dimensions.size()) { + continue; + } + squeeze_dims.push_back(squeeze_dim); } + std::vector output_dimensions = + BuildSqueezedDimensions(input_dimensions, squeeze_dims); return view(input, output_dimensions); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 88d6e8b4496..5a714170300 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -645,6 +645,10 @@ XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent); XLATensorPtr prelu(const XLATensorPtr& input, const XLATensorPtr& weight); +std::tuple prelu_backward( + const XLATensorPtr& grad_out, const XLATensorPtr& input, + const XLATensorPtr& weight); + XLATensorPtr prod(const XLATensorPtr& input, std::vector dimensions, bool keep_reduced_dimensions, c10::optional dtype); diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index a419bd98b7e..e46bf7e022c 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -12,17 +12,18 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" #include "torch_xla/csrc/xla_sharding_util.h" @@ -34,121 +35,11 @@ namespace torch_xla { namespace { struct DataAsync { - std::vector source_tensors; + std::vector> source_tensors; std::vector async_datas; - std::vector handle_unlockers; + std::vector handle_unlockers; }; -bool ShouldUseBF16() { - bool use_bf16 = runtime::sys_util::GetEnvBool("XLA_USE_BF16", false); - if (use_bf16) { - TF_LOG(INFO) << "Using BF16 data type for floating point values"; - } - return use_bf16; -} - -bool ShouldUseF16() { - bool use_fp16 = runtime::sys_util::GetEnvBool("XLA_USE_FP16", false); - if (use_fp16) { - TF_LOG(INFO) << "Using F16 data type for floating point values"; - } - return use_fp16; -} - -bool ShouldDowncastToBF16() { - bool downcast_bf16 = - runtime::sys_util::GetEnvBool("XLA_DOWNCAST_BF16", false); - if (downcast_bf16) { - TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->BF16"; - } - return downcast_bf16; -} - -bool ShouldDowncastToF16() { - bool downcast_fp16 = - runtime::sys_util::GetEnvBool("XLA_DOWNCAST_FP16", false); - if (downcast_fp16) { - TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->FP16"; - } - return downcast_fp16; -} - -bool ShouldUse32BitLong() { - bool use_32bit_long = - runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false); - if (use_32bit_long) { - TF_LOG(INFO) << "Using 32bit integers for kLong values"; - } - return use_32bit_long; -} - -bool UseBF16() { - static bool use_bf16 = ShouldUseBF16(); - return use_bf16; -} - -bool UseF16() { - static bool use_fp16 = ShouldUseF16(); - return use_fp16; -} - -bool DowncastBF16() { - static bool downcast_bf16 = ShouldDowncastToBF16(); - return downcast_bf16; -} - -bool DowncastF16() { - static bool downcast_fp16 = ShouldDowncastToF16(); - return downcast_fp16; -} - -bool Use32BitLong() { - static bool use_32bit_long = ShouldUse32BitLong(); - return use_32bit_long; -} - -bool IsTpuDevice(XlaDeviceType hw_type) { - static bool spmd_device_is_tpu = - (hw_type == XlaDeviceType::SPMD) && - runtime::GetComputationClient()->GetDefaultDevice().find("TPU") == 0; - return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; -} - -xla::PrimitiveType XlaTypeFromTensorType( - at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) { - XlaDeviceType hw_type = static_cast(device.type()); - switch (scalar_type) { - case at::ScalarType::Double: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::F64 - : xla::PrimitiveType::F32; - case at::ScalarType::Float: - return xla::PrimitiveType::F32; - case at::ScalarType::BFloat16: - return xla::PrimitiveType::BF16; - case at::ScalarType::Half: - return xla::PrimitiveType::F16; - case at::ScalarType::Bool: - return xla::PrimitiveType::PRED; - case at::ScalarType::Byte: - return xla::PrimitiveType::U8; - case at::ScalarType::Char: - return xla::PrimitiveType::S8; - case at::ScalarType::Short: - return xla::PrimitiveType::S16; - case at::ScalarType::Int: - return xla::PrimitiveType::S32; - case at::ScalarType::Long: - return xla::PrimitiveType::S64; - case at::ScalarType::ComplexFloat: - return xla::PrimitiveType::C64; - case at::ScalarType::ComplexDouble: - return xla::PrimitiveType::C128; - default: - XLA_ERROR() << "Type not supported: " << scalar_type; - } -} - template struct Caster { template @@ -475,16 +366,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape, std::vector iter_dims = GetIterationDimensions(dest_shape); std::vector parts = CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front()); - auto mwait = std::make_shared(parts.size()); + absl::BlockingCounter counter(parts.size()); for (size_t i = 0; i < parts.size(); ++i) { auto copy_fn = [&, i]() { SlicedCopy(dest_shape.dimensions(), src_data, src_strides, dest_data, dest_strides, iter_dims, parts[i]); + counter.DecrementCount(); }; - runtime::env::ScheduleClosure( - runtime::util::MultiWait::Completer(mwait, std::move(copy_fn))); + thread::Schedule(std::move(copy_fn)); } - mwait->Wait(); + counter.Wait(); } } @@ -495,7 +386,8 @@ void TensorToBuffer(const at::Tensor& tensor, const xla::Shape& dest_shape, at::Tensor contiguous_tensor = tensor.contiguous(); xla::Shape src_shape = MakeTorchTensorLayout( XlaHelpers::I64List(contiguous_tensor.sizes()), /*dynamic_dimensions=*/{}, - XlaTypeFromTensorType(contiguous_tensor.type().scalarType(), device)); + MaybeDowncastToXlaDeviceType(contiguous_tensor.type().scalarType(), + device)); CopyTensors(contiguous_tensor.data_ptr(), src_shape, dest_buffer, dest_buffer_size, dest_shape); } @@ -587,15 +479,9 @@ torch::lazy::BackendDataPtr TensorToXlaData( sharding_spec); } - auto populate_fn = - [&](const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - - std::vector source_tensors; - source_tensors.emplace_back(shape, device.toString(), std::move(populate_fn)); + std::vector> source_tensors; + source_tensors.push_back( + std::make_shared(tensor, shape, device.toString())); auto handles = runtime::GetComputationClient()->TransferToServer(source_tensors); @@ -817,19 +703,12 @@ std::vector CreateTensorsData( return WrapXlaData(handles); } - std::vector source_tensors; + std::vector> source_tensors; for (size_t i = 0; i < tensors.size(); ++i) { torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + tensors[i], std::move(shape), devices[i])); } return WrapXlaData( runtime::GetComputationClient()->TransferToServer(source_tensors)); @@ -848,7 +727,8 @@ std::vector CreateTensorsData( torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - std::vector source_tensors; // in + std::vector> + source_tensors; // in std::vector new_handles; // out if (static_cast(device.type()) == XlaDeviceType::SPMD) { // GetLocalDevices returns the list of local devices specified by their @@ -864,15 +744,8 @@ std::vector CreateTensorsData( new_handles.push_back(ShardingUtil::CreateShardedData( local_shards, local_devices, shardings[i])); } else { - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + tensors[i], std::move(shape), devices[i])); new_handles = runtime::GetComputationClient()->TransferToServer(source_tensors); } @@ -889,7 +762,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape, auto dimensions = XlaHelpers::I64List(tensor.sizes()); computed_shape = MakeTorchTensorLayout( dimensions, /*dynamic_dimensions=*/{}, - XlaTypeFromTensorType(tensor.type().scalarType(), xla_device)); + MaybeDowncastToXlaDeviceType(tensor.type().scalarType(), xla_device)); shape = &computed_shape; } xla::Literal literal(*shape); @@ -898,12 +771,33 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape, return literal; } -std::vector XlaDataToTensors( - absl::Span xla_data, - at::ScalarType dest_element_type) { +std::vector ReleaseGilAndTransferData( + absl::Span xla_data) { + // HACK: This method may be called outside of python (mainly in C++ tests) or + // when the GIL is already released, so we must check both cases here. If + // possible, prefer to release the GIL in the python bindings before copying + // this pattern. + PyThreadState* save = nullptr; + // TODO(wcromar): Remove this setting when we are more confident + static const bool release_gil = + runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true); + if (release_gil && Py_IsInitialized() && PyGILState_Check()) { + save = PyEval_SaveThread(); + } std::vector literals = runtime::GetComputationClient()->TransferFromServer( UnwrapXlaData(xla_data)); + if (save) { + PyEval_RestoreThread(save); + } + + return literals; +} + +std::vector XlaDataToTensors( + absl::Span xla_data, + at::ScalarType dest_element_type) { + std::vector literals = ReleaseGilAndTransferData(xla_data); std::vector tensors; tensors.reserve(literals.size()); for (auto& literal : literals) { @@ -984,148 +878,16 @@ xla::Shape CreateComputationShapeFromTensor( static_cast(xla_device.type())); } -at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) { - switch (xla_type) { - case xla::PrimitiveType::BF16: - return UseBF16() || DowncastBF16() ? at::ScalarType::Float - : at::ScalarType::BFloat16; - case xla::PrimitiveType::F16: - return UseF16() || DowncastF16() ? at::ScalarType::Float - : at::ScalarType::Half; - case xla::PrimitiveType::F32: - return DowncastBF16() || DowncastF16() ? at::ScalarType::Double - : at::ScalarType::Float; - case xla::PrimitiveType::F64: - return at::ScalarType::Double; - case xla::PrimitiveType::PRED: - return at::ScalarType::Bool; - case xla::PrimitiveType::U8: - return at::ScalarType::Byte; - case xla::PrimitiveType::S8: - return at::ScalarType::Char; - case xla::PrimitiveType::S16: - case xla::PrimitiveType::U16: - return at::ScalarType::Short; - case xla::PrimitiveType::S32: - case xla::PrimitiveType::U32: - return at::ScalarType::Int; - case xla::PrimitiveType::S64: - case xla::PrimitiveType::U64: - return at::ScalarType::Long; - case xla::PrimitiveType::C64: - return at::ScalarType::ComplexFloat; - case xla::PrimitiveType::C128: - return at::ScalarType::ComplexDouble; - default: - XLA_ERROR() << "XLA type not supported: " << xla_type; - } -} - -xla::PrimitiveType TensorTypeToRawXlaType(at::ScalarType scalar_type) { - switch (scalar_type) { - case at::ScalarType::Double: - return xla::PrimitiveType::F64; - case at::ScalarType::Float: - return xla::PrimitiveType::F32; - case at::ScalarType::BFloat16: - return xla::PrimitiveType::BF16; - case at::ScalarType::Half: - return xla::PrimitiveType::F16; - case at::ScalarType::Bool: - return xla::PrimitiveType::PRED; - case at::ScalarType::Byte: - return xla::PrimitiveType::U8; - case at::ScalarType::Char: - return xla::PrimitiveType::S8; - case at::ScalarType::Short: - return xla::PrimitiveType::S16; - case at::ScalarType::Int: - return xla::PrimitiveType::S32; - case at::ScalarType::Long: - return xla::PrimitiveType::S64; - case at::ScalarType::ComplexFloat: - return xla::PrimitiveType::C64; - case at::ScalarType::ComplexDouble: - return xla::PrimitiveType::C128; - default: - XLA_ERROR() << "Type not supported: " << scalar_type; - } -} - -xla::PrimitiveType GetDevicePrimitiveType( - xla::PrimitiveType type, const torch::lazy::BackendDevice* device) { - torch::lazy::BackendDevice xla_device = bridge::GetDeviceOrCurrent(device); - XlaDeviceType hw_type = static_cast(xla_device.type()); - switch (type) { - case xla::PrimitiveType::F64: - if (UseF16()) { - return xla::PrimitiveType::F16; - } - if (UseBF16()) { - return xla::PrimitiveType::BF16; - } - if (DowncastBF16() || DowncastF16()) { - return xla::PrimitiveType::F32; - } - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::F64 - : xla::PrimitiveType::F32; - case xla::PrimitiveType::F32: - if (UseF16() || DowncastF16()) { - return xla::PrimitiveType::F16; - } - return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 - : xla::PrimitiveType::F32; - case xla::PrimitiveType::U16: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::U16 - : xla::PrimitiveType::U32; - case xla::PrimitiveType::S16: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::S16 - : xla::PrimitiveType::S32; - case xla::PrimitiveType::S64: - return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64; - case xla::PrimitiveType::U64: - return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64; - case xla::PrimitiveType::C128: - return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128 - : xla::PrimitiveType::C64; - default: - return type; - } +xla::PrimitiveType GetXlaPrimitiveTypeForCurrentDevice( + xla::PrimitiveType xla_type) { + torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); + return MaybeDowncastToXlaDeviceType(xla_type, xla_device); } xla::PrimitiveType MakeXlaPrimitiveType( at::ScalarType scalar_type, const torch::lazy::BackendDevice* device) { - switch (scalar_type) { - case at::ScalarType::Double: - return GetDevicePrimitiveType(xla::PrimitiveType::F64, device); - case at::ScalarType::Float: - return GetDevicePrimitiveType(xla::PrimitiveType::F32, device); - case at::ScalarType::BFloat16: - return GetDevicePrimitiveType(xla::PrimitiveType::BF16, device); - case at::ScalarType::Half: - return GetDevicePrimitiveType(xla::PrimitiveType::F16, device); - case at::ScalarType::Bool: - return GetDevicePrimitiveType(xla::PrimitiveType::PRED, device); - case at::ScalarType::Byte: - return GetDevicePrimitiveType(xla::PrimitiveType::U8, device); - case at::ScalarType::Char: - return GetDevicePrimitiveType(xla::PrimitiveType::S8, device); - case at::ScalarType::Short: - return GetDevicePrimitiveType(xla::PrimitiveType::S16, device); - case at::ScalarType::Int: - return GetDevicePrimitiveType(xla::PrimitiveType::S32, device); - case at::ScalarType::Long: - return GetDevicePrimitiveType(xla::PrimitiveType::S64, device); - case at::ScalarType::ComplexFloat: - return GetDevicePrimitiveType(xla::PrimitiveType::C64, device); - case at::ScalarType::ComplexDouble: - return GetDevicePrimitiveType(xla::PrimitiveType::C128, device); - default: - XLA_ERROR() << "Type not supported: " << scalar_type; - } + torch::lazy::BackendDevice xla_device = bridge::GetDeviceOrCurrent(device); + return MaybeDowncastToXlaDeviceType(scalar_type, xla_device); } xla::Shape MakeXlaShapeFromLazyShape(torch::lazy::Shape shape, @@ -1144,7 +906,7 @@ bool RequiresRawTypeCasting(at::ScalarType scalar_type, case at::ScalarType::Char: case at::ScalarType::Short: return MakeXlaPrimitiveType(scalar_type, device) != - TensorTypeToRawXlaType(scalar_type); + XlaTypeFromTorchType(scalar_type); default: return false; } diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index a0aaadbc75f..81b4cd9a565 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -25,6 +25,12 @@ std::vector ComputeShapeStrides(const xla::Shape& shape); at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal, at::ScalarType dest_element_type); +// Execution and data transfer are async in PJRT, so TransferFromServer may +// block until `DataPtr`s are ready. Release the GIL so other threads can +// proceed and unblock any transfers or collective computations. +std::vector ReleaseGilAndTransferData( + absl::Span xla_data); + // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice std::vector XlaDataToTensors( absl::Span xla_data, @@ -80,14 +86,9 @@ void PopulateTensorBuffer(const at::Tensor& tensor, xla::Shape CreateComputationShapeFromTensor( const at::Tensor& tensor, const torch::lazy::BackendDevice* device); -at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type); - -xla::PrimitiveType TensorTypeToRawXlaType(at::ScalarType scalar_type); - -// Maps an XLA type to the one which can be used on the given device (or the -// default device, id device is nullptr). -xla::PrimitiveType GetDevicePrimitiveType( - xla::PrimitiveType type, const torch::lazy::BackendDevice* device); +// Make a compatible dtype for the current device +xla::PrimitiveType GetXlaPrimitiveTypeForCurrentDevice( + xla::PrimitiveType xla_type); // Converts the given scalar type to an XLA primitive type. xla::PrimitiveType MakeXlaPrimitiveType( diff --git a/torch_xla/csrc/thread_pool.cc b/torch_xla/csrc/thread_pool.cc new file mode 100644 index 00000000000..e440afce7bd --- /dev/null +++ b/torch_xla/csrc/thread_pool.cc @@ -0,0 +1,21 @@ +#include "torch_xla/csrc/thread_pool.h" + +#include + +#include "torch_xla/csrc/runtime/sys_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool.h" + +namespace torch_xla { +namespace thread { + +void Schedule(std::function fn) { + static size_t num_threads = torch_xla::runtime::sys_util::GetEnvInt( + "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency()); + static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", + num_threads); + pool.Schedule(std::move(fn)); +} + +} // namespace thread +} // namespace torch_xla diff --git a/torch_xla/csrc/thread_pool.h b/torch_xla/csrc/thread_pool.h new file mode 100644 index 00000000000..22074e6886f --- /dev/null +++ b/torch_xla/csrc/thread_pool.h @@ -0,0 +1,16 @@ +#ifndef XLA_CLIENT_THREAD_POOL_H_ +#define XLA_CLIENT_THREAD_POOL_H_ + +#include + +namespace torch_xla { +namespace thread { + +// Schedules a closure to be run. The closure should not block waiting for other +// events. +void Schedule(std::function fn); + +} // namespace thread +} // namespace torch_xla + +#endif // XLA_CLIENT_THREAD_POOL_H_ diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp index 2148478006f..1d5e3616643 100644 --- a/torch_xla/csrc/torch_util.cpp +++ b/torch_xla/csrc/torch_util.cpp @@ -78,9 +78,7 @@ at::Tensor MaybeWrapTensorToFunctional(const at::Tensor& tensor) { namespace torch { namespace lazy { torch::lazy::hash_t Hash(const xla::Shape& shape) { - auto shape_hash = torch_xla::runtime::util::ShapeHash(shape); - return c10::uint128(absl::Uint128High64(shape_hash), - absl::Uint128Low64(shape_hash)); + return torch_xla::runtime::util::ShapeHash(shape); } } // namespace lazy } // namespace torch diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 827d72ff6c7..0033176a172 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -27,6 +28,7 @@ #include "absl/strings/str_join.h" #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/layout_manager.h" @@ -46,11 +48,10 @@ #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" -#include "torch_xla/csrc/runtime/thread_pool.h" -#include "torch_xla/csrc/runtime/unique.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" #include "torch_xla/csrc/xla_sharding_util.h" @@ -217,7 +218,7 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue( const at::Scalar& value, xla::PrimitiveType type, const torch::lazy::BackendDevice& device) { torch::lazy::BackendDataPtr data = - GetDeviceData(value, TensorTypeFromXlaType(type), device); + GetDeviceData(value, MaybeUpcastToHostTorchType(type), device); data->SetInfo( std::make_shared( /*tensor_id=*/-1, /*read_only=*/true)); @@ -410,7 +411,21 @@ std::vector XLAGraphExecutor::GetTensors( std::vector* tensors) { TF_VLOG(4) << "Trying to get the value of " << tensors->size() << " tensor(s)"; - return GetTensorsFused(tensors); + SyncTensorsConfig config; + config.force_ltc_data = false; + auto async = SyncTensorsGraphInternal(tensors, {}, config); + if (async != nullptr) { + async->mwait.Wait(); + } + std::vector tensors_data = GatherTensorsXlaData( + *tensors, async != nullptr ? async->indices : absl::Span(), + async != nullptr ? async->tensors_data + : absl::Span()); + + std::vector literals = ReleaseGilAndTransferData(tensors_data); + + return FetchTensors(tensors, literals, + async != nullptr ? &async->indices : nullptr); } torch::lazy::hash_t XLAGraphExecutor::GetGraphHash( @@ -501,7 +516,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( const std::vector& tensors, const SyncTensorsConfig& config) { tsl::profiler::TraceMe activity("CollectSyncTensors", tsl::profiler::TraceMeLevel::kInfo); - runtime::util::Unique unique_device; + torch::lazy::Unique unique_device; for (size_t i = 0; i < tensors.size(); ++i) { unique_device.set(tensors[i]->GetDevice()); } @@ -596,6 +611,10 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier", tsl::profiler::TraceMeLevel::kInfo); MaybeDumpGraph("dynamo", hash); + if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) { + DebugUtil::analyze_graph_execution_python_frame( + /*from_dynamo_executation=*/true); + } auto cachedComputation = XLAGraphExecutor::Get()->GetComputationCache()->Get(hash); TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash) @@ -738,7 +757,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( } }; - runtime::env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn))); + thread::Schedule(async->mwait.Completer(std::move(syncfn))); return placeholders; } @@ -798,44 +817,6 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( return WrapXlaData(result_data); } -std::vector XLAGraphExecutor::GetTensorsFused( - std::vector* tensors) { - SyncTensorsConfig config; - config.force_ltc_data = false; - auto async = SyncTensorsGraphInternal(tensors, {}, config); - if (async != nullptr) { - async->mwait.Wait(); - } - std::vector tensors_data = GatherTensorsXlaData( - *tensors, async != nullptr ? async->indices : absl::Span(), - async != nullptr ? async->tensors_data - : absl::Span()); - - // Execution is async in PJRT, so TransferFromServer may block until execution - // completes. Release the GIL so other threads can proceed and unblock any - // collective computations. - // HACK: This method may be called outside of python (mainly in C++ tests) or - // when the GIL is already released, so we must check both cases here. If - // possible, prefer to release the GIL in the python bindings before copying - // this pattern. - PyThreadState* save = nullptr; - // TODO(wcromar): Remove this setting when we are more confident - static const bool release_gil = - runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true); - if (release_gil && Py_IsInitialized() && PyGILState_Check()) { - save = PyEval_SaveThread(); - } - std::vector literals = - runtime::GetComputationClient()->TransferFromServer( - UnwrapXlaData(tensors_data)); - if (save) { - PyEval_RestoreThread(save); - } - - return FetchTensors(tensors, literals, - async != nullptr ? &async->indices : nullptr); -} - std::vector XLAGraphExecutor::GatherTensorsXlaData( const std::vector& tensors, absl::Span indices, absl::Span tensors_data) { @@ -1048,7 +1029,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( } }; - runtime::env::ScheduleIoClosure(async->mwait.Completer(std::move(syncfn))); + thread::Schedule(async->mwait.Completer(std::move(syncfn))); return async; } @@ -1327,6 +1308,9 @@ XLAGraphExecutor::SyncTensorsGraphInternal( const SyncTensorsConfig& config, bool warm_up_cache_only) { tsl::profiler::TraceMe activity("SyncTensorsGraphInternal", tsl::profiler::TraceMeLevel::kInfo); + if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) { + DebugUtil::analyze_graph_execution_python_frame(); + } SyncTensorCollection coll = CollectSyncTensors(*tensors, config); if (coll.indices.empty()) { // Enure previous execution is complete before exiting this diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index 0798853ecf0..90eec4012d6 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -10,16 +10,15 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch_xla/csrc/cross_replica_reduces.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/runtime/async_task.h" #include "torch_xla/csrc/runtime/cache.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/torch_util.h" @@ -258,9 +257,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // Override to enable SPMD. void TensorCollectionBarrier(SyncTensorCollection* coll) final; - // We don't use upstream GetTensorsFused as we have xla::Literal. - std::vector GetTensorsFused(std::vector* tensors); - // Gathers the XLA device data for all the input tensors, after an // asynchronous operation. // TODO(alanwaketan): Reuse the upstream one once Functionalization is done. diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 8d87ac02e1a..1faa31ce267 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -350,9 +350,9 @@ std::vector CreateKthValue(xla::XlaOp input, int64_t k, int64_t dim, indices = XlaHelpers::DynamicReshape(indices, reshape_sizes); } // aten::kthvalue() wants Long tensors as indices. - return {values, xla::ConvertElementType( - indices, GetDevicePrimitiveType(xla::PrimitiveType::S64, - /*device=*/nullptr))}; + return {values, + xla::ConvertElementType(indices, GetXlaPrimitiveTypeForCurrentDevice( + xla::PrimitiveType::S64))}; } std::vector CreateTopK(xla::XlaOp input, int64_t k, int64_t dim, @@ -383,9 +383,9 @@ std::vector CreateTopK(xla::XlaOp input, int64_t k, int64_t dim, xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); // aten::topk() wants Long tensors as indices. - return {values, xla::ConvertElementType( - indices, GetDevicePrimitiveType(xla::PrimitiveType::S64, - /*device=*/nullptr))}; + return {values, + xla::ConvertElementType(indices, GetXlaPrimitiveTypeForCurrentDevice( + xla::PrimitiveType::S64))}; } xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) { @@ -1013,10 +1013,11 @@ std::vector BuildAdamOptimizerStep( xla::XlaOp new_exp_avg_sq = xla::Select(found_inf_cond, exp_avg_sq, exp_avg_sq * beta2 + new_grad * new_grad * (one - beta2)); - xla::XlaOp new_max_exp_avg_sq = xla::Select( - found_inf_cond, max_exp_avg_sq, xla::Max(max_exp_avg_sq, new_exp_avg_sq)); + xla::XlaOp new_max_exp_avg_sq; xla::XlaOp denom; if (use_amsgrad) { + new_max_exp_avg_sq = xla::Select(found_inf_cond, max_exp_avg_sq, + xla::Max(max_exp_avg_sq, new_exp_avg_sq)); denom = xla::Sqrt(new_max_exp_avg_sq) / xla::Sqrt(bias_correction2) + eps; } else { denom = xla::Sqrt(new_exp_avg_sq) / xla::Sqrt(bias_correction2) + eps; @@ -1031,7 +1032,9 @@ std::vector BuildAdamOptimizerStep( results.push_back(new_param); results.push_back(new_exp_avg); results.push_back(new_exp_avg_sq); - results.push_back(new_max_exp_avg_sq); + if (use_amsgrad) { + results.push_back(new_max_exp_avg_sq); + } return results; } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index f7da463fb64..36fc1810d8b 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -5,16 +5,21 @@ #include #include +#include "absl/synchronization/blocking_counter.h" #include "torch/csrc/lazy/core/ir_util.h" +#include "torch_xla/csrc/aten_autograd_ops.h" +#include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" +#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/runtime/computation_client.h" -#include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/runtime.h" -#include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/tensor.h" +#include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/thread_pool.h" +#include "torch_xla/csrc/xla_graph_executor.h" #include "tsl/profiler/lib/traceme.h" #include "xla/execution_options_util.h" #include "xla/hlo/ir/hlo_module.h" @@ -321,7 +326,7 @@ ShardingUtil::InputHandler( // the first local index with the first global device ordinal. auto device_index = build_index_map(devices); - auto mwait = std::make_shared(devices.size()); + absl::BlockingCounter counter(devices.size()); for (int i = 0; i < devices.size(); i++) { auto argument_setter = [&, i]() { @@ -334,11 +339,11 @@ ShardingUtil::InputHandler( int device_i = device_index[global_ordinal]; arguments_by_device[device_i][argument_i] = shard; } + counter.DecrementCount(); }; - runtime::env::ScheduleIoClosure( - runtime::util::MultiWait::Completer(mwait, std::move(argument_setter))); + thread::Schedule(std::move(argument_setter)); } - mwait->Wait(); + counter.Wait(); return arguments_by_device; } @@ -358,7 +363,7 @@ std::vector ShardingUtil::OutputHandler( // Reshards replicated output if `sharding` is present. std::vector tensors = XlaDataToTensors( {sharded_results[0][i]}, - TensorTypeFromXlaType(sharding->shape.element_type())); + MaybeUpcastToHostTorchType(sharding->shape.element_type())); outputs.push_back( std::dynamic_pointer_cast( CreateTensorsData( @@ -706,11 +711,12 @@ void ShardingUtil::PrepareOutputShardingPropagation( } runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( - std::vector& local_shards, std::vector& devices, + const std::vector& local_shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; - std::vector source_tensors; + std::vector> source_tensors; xla::Shape global_shape; xla::OpSharding sharding; if (sharding_spec == nullptr) { @@ -727,18 +733,142 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = CreateComputationShapeFromTensor(local_shards[j], &shard_device); - auto populate_fn = - [&, j, shard_device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(local_shards[j], source_tensor.shape, - dest_buffer, dest_buffer_size, shard_device); - }; - source_tensors.emplace_back(shard_shape, devices[j], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + local_shards[j], shard_shape, devices[j])); } return runtime::GetComputationClient()->TransferShardsToServer( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } +void ShardingUtil::xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding) { + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; + } + } + + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. But sometimes, + // the same input is called for sharding annotation over multiple steps, + // in which case we can skip if it's the same sharding; however, if it's + // the same input with a different sharding then we block & ask the user + // to clear the existing sharding first. + auto current_sharding_spec = xtensor->sharding_spec(); + if (current_sharding_spec && (current_sharding_spec->sharding.type() != + xla::OpSharding::REPLICATED)) { + XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, + *current_sharding_spec)) + << "Existing annotation must be cleared first."; + return; + } + + // If the at::Tensor data is not present, we need to re-download the + // tensor from the physical device to CPU. In that case, the value + // must be present on the backend device. + XLA_CHECK((xtensor->CurrentDataHandle() && + xtensor->CurrentDataHandle()->HasValue()) || + device_data_node != nullptr) + << "Cannot shard tensor. Data does not present on any device."; + std::vector xla_tensors{xtensor}; + cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + } + auto xla_data = CreateTensorsData( + std::vector{cpu_tensor}, + std::vector{new_sharding_spec}, + std::vector{GetVirtualDevice().toString()})[0]; + xtensor->SetXlaData(xla_data); + xtensor->SetShardingSpec(*new_sharding_spec); + + // Register sharded tensor data. + XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); +} + +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type) { + py::list tile_assignment_py = py::list(); + for (int i = 0; i < tile_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : tile_assignment[i].get().toIntList()) { + pylist.append(t); + } + tile_assignment_py.append(pylist); + } + + py::list group_assignment_py = py::list(); + for (int i = 0; i < group_assignment.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : group_assignment[i].get().toIntList()) { + pylist.append(t); + } + group_assignment_py.append(pylist); + } + + py::list replication_groups_py = py::list(); + for (int i = 0; i < replication_groups.size(); i++) { + py::list pylist = py::list(); + for (int64_t t : replication_groups[i].get().toIntList()) { + pylist.append(t); + } + replication_groups_py.append(pylist); + } + + xla::OpSharding op_sharding = ShardingUtil::CreateOpSharding( + tile_assignment_py, group_assignment_py, replication_groups_py, + ShardingUtil::ShardingType(sharding_type)); + + ShardingUtil::xla_mark_sharding(input, op_sharding); +} + +// Macro for defining a function that will be run at static initialization time +// to define a library of operators in the namespace. Used to define a new set +// of custom operators that do not already exist in PyTorch. +TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " + "tile_assignment, int[][] group_assignment, int[][] replication_groups, " + "int sharding_type) -> ()", + torch::dispatch(c10::DispatchKey::XLA, + TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op))); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 4a595f4e99b..697f320f575 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -15,7 +15,7 @@ namespace torch_xla { class ShardingUtil { public: - // This maps to `torch_xla.experimental.xla_sharding.ShardingType` enum type. + // This maps to `torch_xla.distributed.spmd.ShardingType` enum type. enum ShardingType { REPLICATED = 0, MAXIMAL = 1, @@ -147,10 +147,19 @@ class ShardingUtil { // Transfers the individual shards to the devices and returns a DataPtr for // the PjRtShardedData wrapping the shards. static runtime::ComputationClient::DataPtr CreateShardedData( - std::vector& shards, std::vector& devices, + const std::vector& shards, + const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); + + static void xla_mark_sharding(const at::Tensor& input, + xla::OpSharding sharding); }; +void xla_mark_sharding_dynamo_custom_op( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_ diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py new file mode 100644 index 00000000000..3cd50e1e7c0 --- /dev/null +++ b/torch_xla/distributed/spmd/__init__.py @@ -0,0 +1,12 @@ +from .xla_sharded_tensor import XLAShard, XLAShardedTensor +from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec, + XLAPatchedLinear, mark_sharding, clear_sharding, + wrap_if_sharded, xla_patched_nn_linear_forward) +from .api import xla_distribute_tensor, xla_distribute_module + +__all__ = [ + "XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType", + "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding", + "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module", + "xla_patched_nn_linear_forward" +] diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py new file mode 100644 index 00000000000..bea4415db57 --- /dev/null +++ b/torch_xla/distributed/spmd/api.py @@ -0,0 +1,182 @@ +import logging +import os +from functools import wraps +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + +import torch + +import torch.nn as nn +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import Placement, Replicate + +import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401 +import torch_xla.runtime as xr # type:ignore[import] +from torch_xla.distributed.spmd import ( # type:ignore[import] + XLAShardedTensor, mark_sharding, Mesh, ShardingType, +) + +log = logging.getLogger(__name__) + + +# wrapper to check xla test requirements +def with_xla(func: Callable) -> Callable: + assert func is not None + + @wraps(func) # pyre-ignore[6] + def wrapper( + self, + *args: Tuple[object], + **kwargs: Dict[str, Any] # type: ignore[misc] + ) -> None: + os.environ["XLA_USE_SPMD"] = "1" + return func(self, *args, **kwargs) # type: ignore[misc] + + return wrapper + + +@with_xla +def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh": + """ + Convert DTensor `dt_mesh` to XLAShardedTensor `partition_spec`. + + Example (1x4 logical device mesh topology): + ``` + dt_mesh = DeviceMesh("xla", [[1, 2, 3, 4]]) + dt_mesh.shape + >> torch.Size([1, 4]) + + mesh = convert_to_xla_mesh(dt_mesh) + mesh_shape + >> [1, 4] + ``` + """ + assert dt_mesh.size() == xr.global_runtime_device_count() + return Mesh(dt_mesh.mesh.flatten(), tuple(dt_mesh.mesh.size()), + dt_mesh.mesh_dim_names) + + +@with_xla +def convert_to_xla_partition_spec( + tensor: torch.Tensor, + placements: Sequence[Placement]) -> Tuple[Union[Tuple, int, None]]: + """ + Convert DTensor `placements` to XLAShardedTensor `partitoin_spec`. + This supports Shard and Replicate Placement types. + + Example: + ``` + # Mesh partitioning, 1/4-th of the input with replicated overlaps. + # The first input tensor dimension is sharded across the second mesh + # dimension, and the rest is replicated over the first mesh dimension. + t = torch.randn(4, 8, 8) + dt_mesh = DeviceMesh("xla", torch.arange(8).reshape(2,4)) + placements = [Replicate(), Shard(0)] + my_dtensor = distribute_tensor(t, dt_mesh, placements) + + # `placements = [Replicate(), Shard(0)]` describes sharding per mesh dim, + # and this is equivalent to `partition_spec = (1, None, None)` which is + # sharding per input tensor dimension. + partition_spec = convert_to_xla_partition_spec(t, placements) + >> (1, None, None) + ``` + """ + # per tensor dimension sharding + sharding_spec = [None] * len(tensor.shape) + for mesh_idx, spec in enumerate(placements): + if spec.is_shard(): # type:ignore[truthy-function] + # mesh_idx to tensor_idx (spec.dim) + tensor_idx = spec.dim # type:ignore[attr-defined] + sharding_spec[tensor_idx] = mesh_idx # type:ignore[call-overload] + elif spec.is_replicate(): + # spec.dim is already set to None by default + continue + else: + raise ValueError(f"Unsupported placement type: {type(spec).__name__}") + return tuple(sharding_spec) # type:ignore[return-value] + + +@with_xla +def xla_distribute_tensor( + tensor: torch.Tensor, + device_mesh: DeviceMesh, + placements: Optional[Sequence[Placement]] = None, +) -> "XLAShardedTensor": + """ + Distribute a torch.Tensor to the `device_mesh` according to the `placements` + specified. The rank of `device_mesh` and `placements` must be the same. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use `torch.chunk` + semantic to shard the tensor and scatter the shards. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as `device_mesh.ndim`. If not specified, we will + by default replicate the tensor across the `device_mesh` from the + first rank of each dimension of the `device_mesh`. + + Returns: + A :class:`XLAShardedTensor` object + + .. note:: We return a XLAShardedTensor with a global view and access to local shards. + The successive ops would be programmed as if on a single-device and without calling + any explicit collective ops. The actual sharded computation on the sharding annotated tensor + happens lazily, is transparent to the user. In the future, we will introduce + a new DTensor type for this kind of programming-mode (single-controller) and return. + """ + # device_mesh is not optional in xla_distribute_tensor + dt_mesh = device_mesh + assert dt_mesh.device_type == "xla" + + # convert to XLA device mesh + xla_mesh = convert_to_xla_mesh(dt_mesh) + assert xla_mesh.mesh_shape == tuple(dt_mesh.mesh.size()) + + # convert tensor to the corresponding device type if it's not in that device type + if not tensor.is_meta: + tensor = tensor.to(dt_mesh.device_type) + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(dt_mesh.ndim)] + assert (len(placements) == dt_mesh.ndim + ), "`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {dt_mesh.ndim}." + # convert placements to xla partition spec + partition_spec = convert_to_xla_partition_spec(tensor, placements) + assert len(tensor.shape) == len( + partition_spec + ), "`partition_spec` from `placements` must have the same length as `tensor.length`! " + f"Found tensor shape length: {len(tensor.shape)}, and partition_spec length: {len(partition_spec)}." + + global_tensor = tensor + if type(tensor).__name__ == "DTensor": + raise ValueError( + "Cannot distribute a DTensor with local tensor on xla devices." + "The input tensor must be global.") + if type(tensor).__name__ == "XLAShardedTensor": + sharding_type = tensor.sharding_type # type:ignore[attr-defined] + assert ( + sharding_type is None or sharding_type == ShardingType.REPLICATED + ), "XLAShardedTensor `tensor` is already annotated with non-replication sharding. " + "Clear the existing sharding annotation first, by callling torch_xla.distributed.spmd.clear_sharding API." + global_tensor = tensor.global_tensor # type:ignore[attr-defined] + assert global_tensor is not None, "distributing a tensor should not be None" + + # Annotates sharding and returns an XLAShardedTensor + xla_tensor = mark_sharding(global_tensor, xla_mesh, partition_spec) + return xla_tensor + + +@with_xla +def xla_distribute_module( + module: nn.Module, + device_mesh: Optional[DeviceMesh] = None, + partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, + input_fn: Optional[Callable[..., None]] = None, + output_fn: Optional[Callable[..., None]] = None, +) -> nn.Module: + raise NotImplementedError diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py new file mode 100644 index 00000000000..2945502dcc2 --- /dev/null +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -0,0 +1,171 @@ +import torch +from torch.utils._pytree import tree_map +import torch_xla + +from dataclasses import dataclass +from typing import List, Tuple, Iterator, Union +import contextlib +import collections + + +@dataclass +class XLAShard: + # A snapshot of the shard data from the time of XLAShard creation. + data: torch.Tensor + + # The indices of the shard into the global tensor. If the tensor is replicated + # across local devices, the value of `indices` is Ellipsis. Otherwise, it is a + # list of the index slices across each dimension. + # The indices do not reflect padding, since the padding does not exist on the + # global tensor. + indices: Union[type(Ellipsis), List[slice]] + + # The device this shard's data originated from. + shard_device: str + + # The replica this shard belongs to, as determined by the sharding. The + # replica is determined differently for each sharding type: + # - TILED: Since the tensor isn't replicated, replica_id is always 0. + # - PARTIAL: replica_id is taken from the OpSharding and is a value in + # the range [0, num_replica). + # - REPLICATED: Since the tensor is fully replicated, replica_id is the + # device's global ordinal. + replica_id: int + + @property + def unpadded_data(self) -> torch.Tensor: + ''' Returns a copy of `data` with padding removed ''' + unpadded_indices = self.indices + # Replicated data has Ellipsis as indices + if self.indices != Ellipsis: + unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices] + return self.data[unpadded_indices] + + @unpadded_data.setter + def unpadded_data(self, t: torch.Tensor): + unpadded_indices = self.indices + if self.indices != Ellipsis: + unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices] + self.data[unpadded_indices] = t + + +@contextlib.contextmanager +def no_dispatch() -> Iterator[None]: + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + try: + yield + finally: + del guard + + +class XLAShardedTensor(torch.Tensor): + """ + A wrapper around `torch.Tensor` with sharding annotation + for XLA SPMD auto-sharding. The wrapped tensors are unwrapped + for IR tracing and converted to HLO graph with sharding annotations; + XLA SPMDPartitioner takes a pass, propagating and injecting collectives + to the graph before compilation. + """ + + # XLAShardedTensor behaves like a unpartitioned, + # combined tensor on the host machine. When user annotates, + # this is simply set to the input tensor. When an XLA partitioned + # output tensor returns (or sharding propagated intermediate tensors) + # as XLAShardedTensor, the backend gathers global data across devices + # and materialize and set `global_tensor` on the host; the actual device + # data still remain on individual device as sharded or replicated. + # Note: we should drop this reference, and force all gather on each access. + global_tensor: torch.Tensor + # A logical device topology, each element describes + # a number of devices in the corresponding axis. + # NOTE: we could use more specific device-rank mapping, e.g., ShardingSpec, + # if needed. The change shouldn't be difficult, or create another constructor. + mesh_shape: Tuple[int] # TODO: create a wrapper for named axes + # Specifies how each input rank is sharded (index to mesh_shape) + # or replicated (None). For example, we can shard an 8x10 tensor + # 4-way row-wise, and replicate column-wise. + # >> input = torch.randn(8, 10) + # >> mesh_shape = (4, 2) + # >> assert np.prod(mesh_shape) == len(xm.get_xla_supported_devices()) + # >> partition_spec = (0, None) + # >> assert len(input.shape) == len(partition_spec) + partition_spec: Tuple[int, None] + + __slots__ = ['global_tensor'] + + @staticmethod + def __new__(cls, elem: torch.Tensor, *args, **kwargs): + # TODO(yeounoh) wrapper can take different arguments + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=kwargs.get("requires_grad", False)) + r.global_tensor = elem.detach() if r.requires_grad else elem + return r + + # Shards on the devices are materialized/available after the lazy + # execution of the partitioned HLO graph. Each XLAShard points + # to torch.Tensor. The shards represent a snapshot on CPU, detached + # from the global tensor. The shard data will contain any padding + # which results from the sharding. + @property + def local_shards(self) -> List[XLAShard]: + shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor) + replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices( + self.global_tensor) + zipped = zip(shards, replica_and_indices, devices) + return [ + XLAShard(data, indices, dev, replica) + for data, (replica, indices), dev in zipped + ] + + # Load the given list of local shards into the underlying tensor's data + # on the local devices. + def load_local_shards_(self, shards: List[XLAShard]): + data = [s.data for s in shards] + devices = [s.shard_device for s in shards] + torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices) + + @property + def sharding_spec(self): + return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor) + + @property + def sharding_type(self) -> 'ShardingType': + from torch_xla.distributed.spmd import ShardingType + sharding_type = torch_xla._XLAC._get_xla_sharding_type(self.global_tensor) + return ShardingType(sharding_type) + + def __repr__(self): + if not hasattr(self, "global_tensor"): + # materialize a copy of sharded global_tensnor and keep the actual data + # sharded on the XLA devices. + return str(self.cpu()) + return f"XLAShardedTensor({self.global_tensor})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + """ + The dispatcher allows the unwrapped torch.Tensor to re-dispatched to the + `xla` backend as XlaTensor, and the XlaTensor with an associated sharding spec + to be received and wrapped as XLAShardedTensor. + """ + + def unwrap(elem): + return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem + + def wrap(elem): + return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem + + # no_dispatch is only needed if you use enable_python_mode. + # It prevents infinite recursion. + with no_dispatch(): + # re-dispatch to C++ + rs = tree_map(wrap, + func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + return rs diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py new file mode 100644 index 00000000000..d96531a5616 --- /dev/null +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -0,0 +1,642 @@ +import os +from collections import OrderedDict, defaultdict +from dataclasses import dataclass, field +import torch +import torch_xla +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard +import torch_xla.runtime as xr + +import numpy as np +import functools +import itertools +from typing import Tuple, Union, List, Sequence, Any, Optional, Set +from enum import IntEnum + + +class Mesh: + """Describe the logical XLA device topology mesh and the underlying resources. + + Args: + device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped + to an `mesh_shape` array, filling the elements using C-like index order. + + mesh_shape (Tuple[int, ...]): A int tuple describing the logical topology shape + of the device mesh, and each element describes the number of devices in + the corresponding axis. + + axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to the dimensions + of the `devices` argument. Its length should match the rank of `devices`. + + Example: + —------------------------------ + mesh_shape = (4, 2) + num_devices = len(xm.get_xla_supported_devices()) + device_ids = np.array(range(num_devices)) + mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh.get_logical_mesh() + >> array([[0, 1], + [2, 3], + [4, 5], + [6, 7]]) + mesh.shape() + >> OrderedDict([('x', 4), ('y', 2)]) + """ + + device_ids: np.ndarray + mesh_shape: Tuple[int, ...] + axis_names: Tuple[str, ...] + + def __init__(self, + device_ids: Union[np.ndarray, List], + mesh_shape: Tuple[int, ...], + axis_names: Tuple[str, ...] = None): + if not isinstance(device_ids, np.ndarray): + device_ids = np.array(device_ids) + assert (axis_names is None) or (len(mesh_shape) == len(axis_names)) + assert axis_names is None or (len(set(axis_names)) == len(axis_names)) + assert (len(device_ids) == np.prod(mesh_shape)) + assert len(device_ids) == len(np.unique(device_ids)) + self.device_ids = device_ids + self.mesh_shape = mesh_shape + self.axis_names = axis_names + assert all(d < self.size() for d in device_ids) + + def size(self): + return np.prod(self.mesh_shape) + + def shape(self): + if self.axis_names is None: + return OrderedDict( + (dim, size) for dim, size in enumerate(self.mesh_shape)) + return OrderedDict( + (name, size) for name, size in zip(self.axis_names, self.mesh_shape)) + + def get_logical_mesh(self): + return self.device_ids.reshape(self.mesh_shape) + + def get_axis_name_idx(self, name: str) -> int: + if name not in self.axis_names: + return None + return self.axis_names.index(name) + + @functools.lru_cache(maxsize=None) + def get_op_sharding(self, + partition_spec: Tuple, + flatten_opsharding=False) -> torch_xla._XLAC.OpSharding: + """ + Return the OpSharding for the given partition spec. This is an expensive + operation as the mesh grows, so the value is cached for reuse. + """ + partition_spec = _translate_named_partition_spec(self, partition_spec) + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." + assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." + + tile_assignment = _get_tile_assignment(self, partition_spec) + if len(tile_assignment.shape) > len(partition_spec): + # Use partial replication for sharding a tensor over a higher-rank mesh + sharding_type = ShardingType.PARTIAL + else: + sharding_type = _get_sharding_type(partition_spec, self.size()) + replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} + group_assignment, replication_groups = _get_group_assignment( + sharding_type, tile_assignment, len(partition_spec), replicate_dims) + + # If flatten_opsharding = True, return the flattened version of OpSharding + if flatten_opsharding: + return (tile_assignment.tolist(), group_assignment, replication_groups, + int(sharding_type)) + else: + return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), + group_assignment, replication_groups, + int(sharding_type)) + + +# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4 + + +class HybridMesh(Mesh): + """Creates a hybrid device mesh of devices connected with ICI and DCN networks. + The shape of logical mesh should be ordered by increasing network-intensity + e.g. [replica, data, model] where mdl has the most network communication + requirements. + + Args: + ici_mesh_shape: shape of the logical mesh for inner connected devices. + dcn_mesh_shape: shape of logical mesh for outer connected devices. + + Example: + # This example is assuming 2 slices of v4-8. + ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor) + dcn_mesh_shape = (2, 1, 1) + + mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor')) + print(mesh.shape()) + >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)]) + """ + ici_mesh_shape: Tuple[int, ...] + dcn_mesh_shape: Tuple[int, ...] + + def __init__(self, + *, + ici_mesh_shape: Tuple[int, ...], + dcn_mesh_shape: Tuple[int, ...] = None, + axis_names: Tuple[str, ...] = None): + if dcn_mesh_shape == None: + dcn_mesh_shape = tuple([1] * len(ici_mesh_shape)) + assert len(ici_mesh_shape) == len(dcn_mesh_shape) + mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)]) + self.device_attributes = xr.global_runtime_device_attributes() + self.device_attributes.sort( + key=lambda attr: xm.parse_xla_device(attr['name'])[1]) + + if 'slice_index' in self.device_attributes[0] and np.prod( + dcn_mesh_shape) == 1: + raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice') + if 'slice_index' not in self.device_attributes[0] and np.prod( + dcn_mesh_shape) > 1: + raise ValueError('Invalid dcn_mesh_shape for single slice mesh') + self.ici_mesh_shape = ici_mesh_shape + self.dcn_mesh_shape = dcn_mesh_shape + if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[ + 0]: # multislice + mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape, + self.dcn_mesh_shape) + else: + mesh = self._create_device_mesh(self.ici_mesh_shape) + device_ids = mesh.flatten() + super().__init__(device_ids, mesh_shape, axis_names) + + # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172 + def _get_physical_tpu_mesh(self, devices: Sequence[int]) -> np.ndarray: + r"""Rearrange TPU devices in a slice into a physical mesh. + + Args: + devices: A list of device logical ordinals in a TPU slice. + + Returns: + A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On + v2 and v3, global_z is instead cores_per_chip (i.e., 2). + """ + assert xm.xla_device_hw(xm.xla_device()) == 'TPU' + # coords is a 3-dims tuple representing the device in physical mesh + device_coords = [self.device_attributes[d]['coords'] for d in devices] + dims = tuple(d + 1 for d in max(device_coords)) + out = np.empty(dims, dtype=int) + for coords, d in zip(device_coords, devices): + out[coords[0], coords[1], coords[2]] = d + return out + + # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64. + def _create_device_mesh_for_nd_torus( + self, physical_mesh: np.ndarray, + mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]: + """Assigns logical parallelism axes to physical axes of an N-D torus network. + + Given logical parallelism axes with sizes in `mesh_shape` and devices in an + N-dimensional torus network represented by `physical_mesh`, maps each logical + axis to one or more physical axes. Prefer to map more-performance-sensitive + logical axes to larger numbers of physical axes to maximize the bandwidth + available to them. Also prefer to assign logical axes to multiple physical + axes of the same size (e.g., a 2D square) rather than multiple physical axes + of different sizes when possible. + + Note that this routine will never split a physical axis over more than one + logical axis (which would reduce total usable bandwidth but may sometimes be + desired anyway). As a result, it will error out in cases where this is + necessary to produce a valid mapping. + + Let's use a concrete example to explain the concepts and considerations. + + As an example, suppose the logical mesh is [data, model], for data and model + parallelism respectively. Also suppose that data parallelism is less + performance sensitive than model parallelism. Consider a 3D TPU pod slice of + shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). + + A TPU pod slice has equal bandwidth along all axes with wraparound links, but + a 2D plane of size 4x4 may have faster XLA collective implementations than a + non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want + the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. + + Args: + physical_mesh: a np.ndarray of devices in the shape of the N-D torus + physical topology. + mesh_shape: shape of the logical mesh (size of the various logical + parallelism axes), with axes ordered by increasing network intensity. + + Returns: + An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with + each logical parallelism axis mapped to one or more physical mesh axes. + The axis assignment (a list of length num_logical_axes, whose elements + are tuples representing physical axis indices). + """ + # Remaining physical axes to be assigned to logical axes. + assignable_physical_mesh = list(physical_mesh.shape) + # Map each logical axis to a subset of physical axes. + assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape] + # Assign logical axes from highest network intensity to lowest. + # `mesh_shape` is assumed to ordered by lowest network intensity first, so + # reverse it first. + # Assigns devices to 2D or 3D logical mesh. + for logical_axis_index, logical_axis_size in reversed( + list(enumerate(mesh_shape))): + for num_axes in range(3, 0, -1): + # map a combination of devices in physical axes to the logical axis. + axes = itertools.combinations(assignable_physical_mesh, num_axes) + indices = itertools.combinations( + range(len(assignable_physical_mesh)), num_axes) + for c_axes, c_indices in zip(axes, indices): + if np.product(c_axes) == logical_axis_size: + assignment[logical_axis_index] = c_indices + # Zero the assigned physical axes. + assignable_physical_mesh = [ + 0 if i in c_indices else v + for i, v in enumerate(assignable_physical_mesh) + ] + break + if assignment[logical_axis_index]: + # We already found an assignment from one candidate above. + break + else: + # If the num_axes for loop did not break, i.e. none of the candidates work + # goto here with this while-else construct. + if logical_axis_size > 1: + raise NotImplementedError( + 'Failed to find assignment for logical_axis_index' + f' {logical_axis_index} of size {logical_axis_size} with remaining' + f' assignable mesh {assignable_physical_mesh}. The size of each' + ' axis in your logical mesh must be equal to the product of' + ' some subset of the physical mesh axis sizes. E.g logical mesh (4,' + ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.' + ) + # Flatten the assignment + transpose: List[int] = [] + for x in assignment: + for y in x: + transpose.append(int(y)) + return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment + + def _create_device_mesh(self, + mesh_shape: Sequence[int], + devices: Sequence[Any] = None) -> Sequence[int]: + """Creates a performant device mesh. + + Args: + mesh_shape: shape of logical mesh, ordered by increasing network-intensity + e.g. [replica, data, mdl] where mdl has the most network communication + requirements. + devices: optionally, the devices to construct a mesh for. + + Returns: + A np.ndarray of devices with mesh_shape as its shape. + """ + + if devices is None: + devices = np.arange(xr.global_runtime_device_count()) + if np.prod(mesh_shape) != len(devices): + raise ValueError( + f'Number of devices {len(devices)} must equal the product ' + f'of mesh_shape {mesh_shape}') + physical_mesh = self._get_physical_tpu_mesh(devices) + device_mesh, assignment = self._create_device_mesh_for_nd_torus( + physical_mesh, mesh_shape) + return device_mesh + + # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288. + def _create_hybrid_device_mesh( + self, ici_mesh_shape: Sequence[int], + dcn_mesh_shape: Sequence[int]) -> Sequence[int]: + """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. + + Args: + ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered + by increasing network intensity, e.g. [replica, data, mdl] where mdl has + the most network communication requirements. + dcn_mesh_shape: shape of the logical mesh for the slower/outer network, + in the same order as mesh_shape. + + Returns: + A np.ndarray of device logical ordinal with ici_mesh_shape * dcn_mesh_shape as its shape + that can be fed into HybridMesh for hybrid parallelism. + """ + granule_dict = defaultdict(list) + for d, dev in enumerate(self.device_attributes): + granule_dict[dev['slice_index']].append(d) + # sorts devices based on slice_index. + granules = list(granule_dict[key] for key in sorted(granule_dict.keys())) + if np.prod(dcn_mesh_shape) != len(granules): + raise ValueError( + f'Number of slices {len(granules)} must equal the product of ' + f'dcn_mesh_shape {dcn_mesh_shape}') + # creates a seperate internal mesh for each slice. + per_granule_meshes = [ + self._create_device_mesh(ici_mesh_shape, granule) + for granule in granules + ] + granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) + blocks = np.vectorize( + lambda i: per_granule_meshes[i], otypes=[object])( + granule_mesh) + device_mesh = np.block(blocks.tolist()) + return device_mesh + + +class ShardingType(IntEnum): + # ShardingType enum ID maps to OpSharidng.Type (https://shorturl.at/pvAJX) + REPLICATED = 0 + MAXIMAL = 1 + TUPLE = 2 + TILED = 3 + MANUAL = 4 + PARTIAL = 5 + + +def _get_sharding_type(partition_spec: Tuple[Union[int, None]], + num_devices: int) -> ShardingType: + sharding_type = ShardingType.TILED + if num_devices == 1: + sharding_type = ShardingType.MAXIMAL + elif all(d is None for d in partition_spec): + sharding_type = ShardingType.REPLICATED + elif any(d is None for d in partition_spec): + sharding_type = ShardingType.PARTIAL + return sharding_type + + +def _get_tile_assignment( + mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int, + None]]) -> np.ndarray: + """ + Permute the given mesh to create the tile assignment based on the partition + spec. Returns the tiling assignment as a numpy ndarray. + + If the input partition_spec combines multiple logical mesh axes over a single + tensor axis, the resulting tiling assignment will combine the specified axes + into a single axis. + """ + # Flatten the partition spec and ensure that it is fully specified over the + # mesh for permutation. + tiled_dims = [x for x in partition_spec if x is not None] + permutation = np.hstack(tiled_dims).tolist() if tiled_dims else [] + missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation)) + tile_assignment = mesh.get_logical_mesh().transpose(permutation + + missing_axes) + + # For any tuples in the partition_spec, the grouped axes will be adjacent + # after the permutation. Combine these dimensions into a single axis. + for i, spec in enumerate(tiled_dims): + if isinstance(spec, tuple): + shape = tile_assignment.shape + tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) + + shape[i + len(spec):]) + + return tile_assignment + + +# Produce group assignment for partial replication. Partial replication tiles +# groups (a.k.a. sub-groups) where the shards are fully replicated within each +# sub-group. `replication_groups` is a list of groups as lists, where each group +# contains the participating device IDs. `group_assignment` describes the group +# placement and the overall mesh, where each element is the group ID. +# The tile_assignment should be the result of `_get_tile_assignment` so that all +# tiled dimensions are in the first axes and replicated dimensions are in the +# remaining axes. +def _get_group_assignment(sharding_type: ShardingType, + tile_assignment: np.ndarray, tensor_rank: int, + replicate_dims: Set[int]) -> Tuple[List, List]: + group_assignment = list() + replication_groups = list() + if sharding_type is ShardingType.PARTIAL: + # Shard across groups and replicate within subgroups; replicated dims + # will be used to group replication devices. + tile_shape = tile_assignment.shape + # When creating the tile assignment, the mesh is permuted so that the first + # few axes are used for tiling. + tile_dims = range(tensor_rank - len(replicate_dims)) + group_list = [tile_assignment] + for d in tile_dims: + _group_list = list() + for group_members in group_list: + _group_list += np.split(group_members, tile_shape[d], d) + group_list = _group_list + replication_groups = [group.flatten().tolist() for group in group_list] + + mesh_axis = itertools.count() + group_tile_shape = [ + 1 if d in replicate_dims else tile_shape[next(mesh_axis)] + for d in range(tensor_rank) + ] + group_assignment = np.arange(len(replication_groups)).reshape( + tuple(group_tile_shape)).tolist() + return group_assignment, replication_groups + + +def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): + _partition_spec = list() + for p in partition_spec: + if type(p) is tuple: + assert not any(type(x) is tuple + for x in p), 'Partition spec cannot contain nested tuples' + _partition_spec.append(_translate_named_partition_spec(mesh, p)) + elif (p is None) or (type(p) is int): + _partition_spec.append(p) + elif type(p) is str: + idx = mesh.get_axis_name_idx(p) + if idx is None: + raise ValueError(f"Axis name {p} is not defined in the given mesh") + _partition_spec.append(idx) + else: + raise ValueError( + f"Spec type {type(p)} is not supported in partition spec") + return tuple(_partition_spec) + + +@xr.requires_pjrt +def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], + mesh: Mesh, + partition_spec: Tuple[Union[Tuple, int, str, None]], + use_dynamo_custom_op: bool = False) -> XLAShardedTensor: + """ + Annotates the tensor provided with XLA partition spec. Internally, + it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. + Args: + t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec. + + mesh (Mesh): describes the logical XLA device topology and the underlying device IDs. + + partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or + `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str. + This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). + When a tuple is specified, the corresponding input tensor axis will be sharded along all + logical axes in the tuple. Note that the order the mesh axes are specified in the tuple + will impact the resulting sharding. + For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise. + >> input = torch.randn(8, 10) + >> mesh_shape = (4, 2) + >> partition_spec = (0, None) + + dynamo_custom_op (bool): if set to True, it calls the dynamo custom op variant of mark_sharding + to make itself recognizeable and traceable by dynamo. + + Examples + —------------------------------ + mesh_shape = (4, 2) + num_devices = xr.global_runtime_device_count() + device_ids = np.array(range(num_devices)) + mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + + # 4-way data parallel + input = torch.randn(8, 32).to(xm.xla_device()) + xs.mark_sharding(input, mesh, (0, None)) + + # 2-way model parallel + linear = nn.Linear(32, 10).to(xm.xla_device()) + xs.mark_sharding(linear.weight, mesh, (None, 1)) + """ + num_devices = xr.global_runtime_device_count() + assert num_devices > 0, "This requires XLA supported device(s)." + assert mesh.size() == num_devices, \ + f"{mesh.mesh_shape} is not mappable over {num_devices} devices." + # We only allow fully specified `partition_spec` to be applicable, as opposed + # to filling in the unspecified replicated dims. Fully specified `partiion_spec` + # should be of the same rank as `t`. This is to support partial replication + # where the group assignment may vary with different input ranks. + assert len(t.shape) == len(partition_spec), \ + f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." + + if use_dynamo_custom_op: + tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding( + partition_spec, flatten_opsharding=True) + + if isinstance(t, XLAShardedTensor): + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( + t.global_tensor, tile_assignment, group_assignment, + replication_groups, sharding_type) + return t + else: + torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( + t, tile_assignment, group_assignment, replication_groups, + sharding_type) + return XLAShardedTensor(t) + else: + op_sharding = mesh.get_op_sharding(partition_spec) + + if isinstance(t, XLAShardedTensor): + torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) + return t + else: + torch_xla._XLAC._xla_mark_sharding(t, op_sharding) + return XLAShardedTensor(t) + + +def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: + """Clear sharding annotation from the input tensor and return a `cpu` casted tensor.""" + torch_xla._XLAC._xla_clear_sharding(t) + if isinstance(t, XLAShardedTensor): + return t.global_tensor + return t + + +def wrap_if_sharded(x: Any) -> Any: + """ + If the input is a sharded tensor, return an XLAShardedTensor wrapping it. + Otherwise, returns the input. + """ + if (isinstance(x, torch.Tensor) and not isinstance(x, XLAShardedTensor) and + x.device.type == 'xla' and + torch_xla._XLAC._get_xla_sharding_type(x) is not None): + return XLAShardedTensor(x) + return x + + +@dataclass +class ShardingSpec: + mesh: Mesh + partition_spec: Tuple[Union[int, None]] + minibatch: Optional[bool] = False + + # Derived fields + _tile_assignment: List[int] = field(init=False) + _group_assignment: List[int] = field(init=False) + _replication_groups: List[int] = field(init=False) + _sharding_type: ShardingType = field(init=False) + + @xr.requires_pjrt + def __post_init__(self): + mesh = self.mesh + partition_spec = _translate_named_partition_spec(mesh, self.partition_spec) + tile_assignment = _get_tile_assignment(mesh, partition_spec) + self._tile_assignment = tile_assignment.tolist() + self._sharding_type = _get_sharding_type(partition_spec, + xr.global_runtime_device_count()) + replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} + self._group_assignment, self._replication_groups = _get_group_assignment( + self._sharding_type, tile_assignment, len(partition_spec), + replicate_dims) + + def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: + """ + Create an XlaShardingSpec for the given tensor. If the tensor is + incompatible with the ShardingSpec, returns None. + """ + if not self.can_apply(t): + return None + return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment, + self._group_assignment, + self._replication_groups, + int(self._sharding_type), + self.minibatch) + + def can_apply(self, t: torch.Tensor) -> bool: + """ + Test whether the ShardingSpec is compatible with the given torch.Tensor. + """ + return len(t.shape) == len(self.partition_spec) + + def apply(self, t: torch.Tensor): + # TODO(yeounoh) use virtual device interface when available. + assert (t.device == xm.xla_device()) + mark_sharding(t, self.mesh, self.partition_spec) + + +class XLAPatchedLinear(torch.autograd.Function): + """ + A patched version of `torch.nn.functional.linear` that uses einsum instead + of torch.matmul which will flatten the tensors to 2D and collide the sharded + dimensions. The torch.matmul default behavior makes it very hard for XLA compiler + to propagate the sharding annotation. + + TODO (alanwaketan): Let's patch it on the dispatcher level. + """ + + @staticmethod + def forward(ctx, input, weight, bias=None): + # bias is an optional argument + ctx.save_for_backward(input, weight, bias) + with torch.no_grad(): + product = torch.einsum('...n,mn->...m', input, weight) + if bias is None: + return product + return product + bias + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = torch.einsum('...m,mn->...n', grad_output, weight) + if ctx.needs_input_grad[1]: + grad_weight = torch.einsum('...m,...n->mn', grad_output, input) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = torch.einsum('...m->m', grad_output) + + return grad_input, grad_weight, grad_bias + + +def xla_patched_nn_linear_forward(m, input): + return XLAPatchedLinear.apply(input, m.weight, m.bias) diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index d448b09dd84..aa2769cb94d 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -1,6 +1,7 @@ import torch import torch.distributed as dist import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr from torch_xla._internal import rendezvous import logging import os @@ -8,6 +9,8 @@ def _create_xla_process_group(prefix_store, rank, size, timeout): + assert not xr.is_spmd( + ), "XLA backend is not supported with SPMD. Please use a CPU process group instead." return ProcessGroupXla(prefix_store, rank, size, timeout) diff --git a/torch_xla/experimental/distributed_checkpoint/__init__.py b/torch_xla/experimental/distributed_checkpoint/__init__.py index 7c91aba0126..cad57c3a405 100644 --- a/torch_xla/experimental/distributed_checkpoint/__init__.py +++ b/torch_xla/experimental/distributed_checkpoint/__init__.py @@ -1,6 +1,8 @@ +from .manager import CheckpointManager from .planners import SPMDSavePlanner, SPMDLoadPlanner __all__ = [ + "CheckpointManager", "SPMDSavePlanner", "SPMDLoadPlanner", ] diff --git a/torch_xla/experimental/distributed_checkpoint/_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py index b49e7419dcd..62c3c6f2ee0 100644 --- a/torch_xla/experimental/distributed_checkpoint/_helpers.py +++ b/torch_xla/experimental/distributed_checkpoint/_helpers.py @@ -5,7 +5,7 @@ import dataclasses import torch -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs from torch.distributed.checkpoint.planner import SavePlan from typing import ( @@ -23,7 +23,7 @@ ) from torch.distributed.checkpoint.metadata import (MetadataIndex, STATE_DICT_TYPE) -from torch_xla.experimental.xla_sharding import XLAShardedTensor, ShardingType +from torch_xla.distributed.spmd import XLAShardedTensor, ShardingType from torch.utils._pytree import tree_map PATH_ITEM = Union[str, int] @@ -34,8 +34,13 @@ CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] +# TODO(jonbolin): Logic here is modified from the upstream to enable async +# checkpointing. If the state_dict is comprised entirely of _CpuShards, +# flatten_state_dict will not actually flatten the dict. +# Once we can represent XLAShardedTensor on CPU, either directly or through +# DistributedTensor, we can reuse the upstream logic. def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: - return isinstance(value, torch.Tensor) + return isinstance(value, torch.Tensor) or isinstance(value, _CpuShards) def _traverse_state_dict( diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py new file mode 100644 index 00000000000..89bb20f5076 --- /dev/null +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -0,0 +1,341 @@ +import fsspec +import logging +import os +import pickle +import threading +import torch.distributed as dist +import torch.distributed.checkpoint as dist_cp +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.experimental.distributed_checkpoint as xc +import traceback + +from dataclasses import dataclass +from datetime import datetime +from collections import deque +from fsspec.core import url_to_fs +from os.path import basename +from concurrent.futures import ThreadPoolExecutor, wait +from typing import Deque, List, Optional, Union +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from ._helpers import _sharded_cpu_state_dict + +# TODO(jonbolin): Import path will change +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter + +# File to track manager-specific metadata within each checkpoint path +_MANAGER_METADATA_FILE = '.manager_metadata' + + +@dataclass +class _CheckpointMetadata: + # The step at which the checkpoint was taken + step: int + + # The time at which the checkpoint was taken + ts: datetime + + +class CheckpointManager: + """ + The CheckpointManager class provides a higher-level wrapper around the + torch.distributed.checkpoint APIs to manage checkpointing. It builds on top + of those APIs to enable a few key features: + - Per-step checkpointing: Each checkpoint taken by the CheckpointManager is + identified by the step at which it was taken, and any step tracked + by the CheckpointManager can be restored. + - Async checkpointing: The torch.distributed.checkpoint APIs are + synchronous, which will block training for the duration of the + checkpoint. The CheckpointManager's save_async method can be used to + offload checkpointing to a background thread, unblocking training + while the checkpoint is written to persistent storage. + - Automatic checkpointing: If the training process would be shut down due + to a SIGTERM, the CheckpointManager will automatically take a + checkpoint at the next step. + - Native fsspec integration: Any storage protocol compatible with fsspec + can be used with CheckpointManager. + + The intended usage of CheckpointManager is as follows: + + >>> # Create a CheckpointManager to checkpoint every 10 steps into GCS. + >>> chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10) + + >>> # Select a checkpoint to restore from, and restore if applicable + >>> tracked_steps = chkpt_mgr.all_steps() + >>> if tracked_steps: + >>> # Choose the highest step + >>> best_step = max(tracked_steps) + >>> state_dict = {'model': model.state_dict()} + >>> chkpt_mgr.restore(best_step, state_dict) + >>> model.load_state_dict(state_dict['model']) + + >>> # Call `save` or `save_async` every step within the train loop. + >>> for step, data in enumerate(dataloader): + >>> ... + >>> state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} + >>> if chkpt_mgr.save_async(step, state_dict): + >>> print(f'Checkpoint taken at step {step}') + + By calling `save` or `save_async` every step, the CheckpointManager has the + opportunity to take a checkpoint on steps which are out-of-cycle with its + step_period, as would be the case in auto checkpointing. + + This class is inspired by Orbax's CheckpointManager, which can be found here: + https://github.com/google/orbax/blob/efc079c/checkpoint/orbax/checkpoint/checkpoint_manager.py + """ + + # The base path to write checkpoints to. Each checkpoint taken by the manager + # will be written into a subdirectory of this path, identified by the + # checkpoint's step. + base_path: Union[str, os.PathLike] + + # The interval to take checkpoints, in steps. + save_interval: int + + # The maximum number of checkpoints to keep. + max_to_keep: int + + # Whether a checkpoint should be taken when a preemption is detected. + chkpt_on_preemption: bool + + def __init__(self, + path: str, + save_interval: int, + max_to_keep: Optional[int] = 0, + max_pending_async: Optional[int] = 1, + process_group: dist.ProcessGroup = None, + chkpt_on_preemption: bool = True): + """ + Create a checkpoint manager that reads and writes checkpoints into + the provided directory. + + Args: + path: The base path for the CheckpointManager to write checkpoints into. + save_interval: The number of steps between saving checkpoints. + max_to_keep: The maximum number of checkpoints to be tracked by the + CheckpointManager. When a new checkpoint will be taken, the + checkpoint for the lowest tracked step will be deleted. + Default: 0, indicating no upper bound on the number of checkpoints. + max_pending_async: The maximum number of async checkpoints which can be + pending. This should be a small value to ensure training doesn't + get too far ahead of the last finished checkpoint, but increasing + the value can unblock training when there are transient issues which + slow down the active checkpoint. + Default: 1, which only allows a single async checkpoint to be + pending at a time. + process_group: The process group to use when coordinating the checkpoint. + Default: None, in which case a subgroup of the default process + group will be created. + chkpt_on_preemption: Whether or not to take a checkpoint when a + preemption has been detected. + Default: True + """ + assert dist.is_initialized(), "A process group is required." + assert save_interval > 0, "save_interval must be positive" + assert max_pending_async > 0, "max_pending_async must be positive" + assert max_to_keep >= 0, "max_to_keep must be non-negative" + + self.base_path = os.path.join(path, '') # Ensure the base path ends in '/' + self.save_interval = save_interval + self.max_to_keep = max_to_keep + self.chkpt_on_preemption = chkpt_on_preemption + + # Create a new group if none is provided + # TODO(jonbolin): Verify subgroup on GPU backend + self.pg = process_group or dist.new_group() + + # Thread pool to run the async checkpoints. `_async_sem` is used to guard + # the number of pending checkpoints, and `_async_futures` tracks all + # futures returned by the pool. + self._async_worker_pool = ThreadPoolExecutor(max_workers=1) + self._async_sem = threading.Semaphore(max_pending_async) + self._async_futures = [] + # Mutex to ensure only a single thread can write a checkpoint at a time. + self._save_mutex = threading.Lock() + + self._tracked_chkpts = self._load_tracked_chkpts() + + if self.chkpt_on_preemption: + # Initialize the distributed runtime for preemption detection + torch_xla._XLAC._ensure_xla_coordinator_initialized( + xr.process_index(), xr.process_count(), xr.get_master_ip()) + torch_xla._XLAC._activate_preemption_sync_manager() + + def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: + """ + Loads a list of all tracked checkpoints from the storage backend. + """ + all_chkpts = [] + invalid_paths = [] + fs, raw_path = url_to_fs(self.base_path) + if not fs.exists(raw_path): + fs.mkdir(raw_path) + else: + for path in fs.ls(raw_path, detail=False): + try: + with fs.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f: + all_chkpts.append(pickle.load(f)) + except: + invalid_paths.append(path) + + if invalid_paths: + logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}') + return deque(sorted(all_chkpts, key=lambda m: m.ts)) + + def _get_path(self, step: int) -> str: + return os.path.join(self.base_path, str(step)) + + def _delete_chkpt_at_step(self, step): + path = self._get_path(step) + fs, raw_path = url_to_fs(path) + if fs.exists(raw_path): + fs.rm(raw_path, recursive=True) + + def _release_oldest_checkpoints(self): + """ + Delete oldest checkpoints until the number of tracked checkpoints is below + self.max_to_keep. This operation is only execution on the rank 0 process. + """ + if dist.get_rank(self.pg) == 0 and self.max_to_keep > 0: + while len(self._tracked_chkpts) > self.max_to_keep: + oldest_chkpt = self._tracked_chkpts.popleft() + self._delete_chkpt_at_step(oldest_chkpt.step) + + def _wait_for_data(self): + xm.mark_step() + xm.wait_device_ops() + + def _save(self, step, state_dict): + """ + The actual checkpointing logic, which is shared between async and + synchronous checkpointing. + + The caller must ensure that data is accessible within the state_dict before + calling, which can be achieved with `self._wait_for_data`. + """ + with self._save_mutex: + path = self._get_path(step) + # Delete any existing checkpoint at the current step. + self._delete_chkpt_at_step(step) + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=FsspecWriter(path), + planner=xc.SPMDSavePlanner(), + process_group=self.pg, + ) + metadata = _CheckpointMetadata(step=step, ts=datetime.now()) + self._tracked_chkpts.append(metadata) + if dist.get_rank(self.pg) == 0: + with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: + pickle.dump(metadata, f) + self._release_oldest_checkpoints() + + def should_save(self, step: int) -> bool: + """ + Returns true if a checkpoint should be saved for the current step. A + checkpoint should be taken if any of the following conditions are met: + - The step aligns with the CheckpointManager's save_interval. + - The CheckpointManager was created with the `chkpt_on_preemption` option + and a preemption has been detected. + """ + preemption_detected = False + if self.chkpt_on_preemption and self.reached_preemption(step): + logging.warn( + f"Preemption sync point reached at step {step}. Triggering a checkpoint." + ) + preemption_detected = True + return step % self.save_interval == 0 or preemption_detected + + def save(self, + step, + state_dict: STATE_DICT_TYPE, + force: Optional[bool] = False) -> bool: + """ + Take a checkpoint synchronously if `self.should_save(step)`. + + Args: + step: The current training step. + state_dict: The state dict to be checkpointed. + force: Option to force a checkpoint to be taken regardless of the result + of `should_save(step)`. + Returns: + True if a checkpoint was taken and False otherwise. + """ + if self.should_save(step) or force: + self._wait_for_data() + self._save(step, state_dict) + return True + return False + + def save_async(self, + step: int, + state_dict: STATE_DICT_TYPE, + force: Optional[bool] = False) -> bool: + """ + Take a checkpoint asynchronously if `self.should_save(step)`. The + input state_dict will be transferred to the CPU device using the + `sharded_cpu_state_dict` function. + + This function will do the following: + 1. Transfer `state_dict` to the CPU device. + 2. Dispatch the checkpoint workload to an asynchronous execution + queue. This will block training until the ongoing async + checkpoint finishes when the queue is full. + + Args: + step: The current training step. + state_dict: The state dict to be checkpointed. + force: Option to force a checkpoint to be taken regardless of the result + of `should_save(step)`. + Returns: + True if a checkpoint was taken and False otherwise. + """ + if self.should_save(step) or force: + self._wait_for_data() + # Move the state_dict to CPU + cpu_state_dict = _sharded_cpu_state_dict(state_dict) + self._async_sem.acquire() + future = self._async_worker_pool.submit(self._save, step, cpu_state_dict) + future.add_done_callback(lambda _: self._async_sem.release()) + self._async_futures.append(future) + return True + return False + + def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: + """ + Restores the checkpoint taken at the given step into the state_dict. The + caller is responsible for calling `model.load_state_dict` to restore any + non-tensor values. + + Args: + step: The step whose checkpoint is to be restored. + state_dict: The state dict to restore the checkpoint into. Values are + updated in-place within the state_dict. + """ + tracked_steps = set(x.step for x in self._tracked_chkpts) + assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}' + path = self._get_path(step) + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=FsspecReader(path), + planner=xc.SPMDLoadPlanner(), + process_group=self.pg, + ) + + def all_steps(self) -> List[int]: + """ + List all steps tracked by the CheckpointManager. + """ + return sorted(x.step for x in self._tracked_chkpts) + + def join(self): + """ Wait for any pending async checkpoints to complete. """ + wait(self._async_futures) + + def reached_preemption(self, step: int) -> bool: + """ Returns True if a preemption has been detected at the given step. """ + assert self.chkpt_on_preemption, ( + "Preemption detection not enabled. Please set `chkpt_on_preemption` " + " when creating the CheckpointManager") + return torch_xla._XLAC._sync_point_reached(step) diff --git a/torch_xla/experimental/distributed_checkpoint/planners.py b/torch_xla/experimental/distributed_checkpoint/planners.py index fbf466ff28a..6810ddb56a3 100644 --- a/torch_xla/experimental/distributed_checkpoint/planners.py +++ b/torch_xla/experimental/distributed_checkpoint/planners.py @@ -4,7 +4,7 @@ import numpy as np import torch import torch_xla -import torch_xla.experimental.xla_sharding as xs +import torch_xla.distributed.spmd as xs from collections import ChainMap from torch.distributed.checkpoint.default_planner import ( @@ -34,7 +34,7 @@ ) from torch.distributed.checkpoint.utils import find_state_dict_object from torch.utils._pytree import tree_map -from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard +from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard from torch_xla.experimental.distributed_checkpoint._helpers import ( FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor, set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards) diff --git a/torch_xla/experimental/xla_sharded_tensor.py b/torch_xla/experimental/xla_sharded_tensor.py index 1c3eaf34916..a1f2e71d98b 100644 --- a/torch_xla/experimental/xla_sharded_tensor.py +++ b/torch_xla/experimental/xla_sharded_tensor.py @@ -1,167 +1,10 @@ -import torch -from torch.utils._pytree import tree_map -import torch_xla +# Keep this for backward compatibility. +# TODO(yeounoh) remove after 2.2 release. +import warnings -from dataclasses import dataclass -from typing import List, Tuple, Iterator, Union -import contextlib -import collections +warnings.warn( + "Importing from `torch_xla.experimental.xla_sharded_tensor` will be deprecated " + "after 2.2 release. Please use `torch_xla.distributed.spmd` " + "instead.", DeprecationWarning, 2) - -@dataclass -class XLAShard: - # A snapshot of the shard data from the time of XLAShard creation. - data: torch.Tensor - - # The indices of the shard into the global tensor. If the tensor is replicated - # across local devices, the value of `indices` is Ellipsis. Otherwise, it is a - # list of the index slices across each dimension. - # The indices do not reflect padding, since the padding does not exist on the - # global tensor. - indices: Union[type(Ellipsis), List[slice]] - - # The device this shard's data originated from. - shard_device: str - - # The replica this shard belongs to, as determined by the sharding. The - # replica is determined differently for each sharding type: - # - TILED: Since the tensor isn't replicated, replica_id is always 0. - # - PARTIAL: replica_id is taken from the OpSharding and is a value in - # the range [0, num_replica). - # - REPLICATED: Since the tensor is fully replicated, replica_id is the - # device's global ordinal. - replica_id: int - - @property - def unpadded_data(self) -> torch.Tensor: - ''' Returns a copy of `data` with padding removed ''' - unpadded_indices = self.indices - # Replicated data has Ellipsis as indices - if self.indices != Ellipsis: - unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices] - return self.data[unpadded_indices] - - @unpadded_data.setter - def unpadded_data(self, t: torch.Tensor): - unpadded_indices = self.indices - if self.indices != Ellipsis: - unpadded_indices = [slice(0, s.stop - s.start) for s in self.indices] - self.data[unpadded_indices] = t - - -@contextlib.contextmanager -def no_dispatch() -> Iterator[None]: - guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] - try: - yield - finally: - del guard - - -class XLAShardedTensor(torch.Tensor): - """ - A wrapper around `torch.Tensor` with sharding annotation - for XLA SPMD auto-sharding. The wrapped tensors are unwrapped - for IR tracing and converted to HLO graph with sharding annotations; - XLA SPMDPartitioner takes a pass, propagating and injecting collectives - to the graph before compilation. - """ - - # XLAShardedTensor behaves like a unpartitioned, - # combined tensor on the host machine. When user annotates, - # this is simply set to the input tensor. When an XLA partitioned - # output tensor returns (or sharding propagated intermediate tensors) - # as XLAShardedTensor, the backend gathers global data across devices - # and materialize and set `global_tensor` on the host; the actual device - # data still remain on individual device as sharded or replicated. - # Note: we should drop this reference, and force all gather on each access. - global_tensor: torch.Tensor - # A logical device topology, each element describes - # a number of devices in the corresponding axis. - # NOTE: we could use more specific device-rank mapping, e.g., ShardingSpec, - # if needed. The change shouldn't be difficult, or create another constructor. - mesh_shape: Tuple[int] # TODO: create a wrapper for named axes - # Specifies how each input rank is sharded (index to mesh_shape) - # or replicated (None). For example, we can shard an 8x10 tensor - # 4-way row-wise, and replicate column-wise. - # >> input = torch.randn(8, 10) - # >> mesh_shape = (4, 2) - # >> assert np.prod(mesh_shape) == len(xm.get_xla_supported_devices()) - # >> partition_spec = (0, None) - # >> assert len(input.shape) == len(partition_spec) - partition_spec: Tuple[int, None] - - __slots__ = ['global_tensor'] - - @staticmethod - def __new__(cls, elem: torch.Tensor, *args, **kwargs): - # TODO(yeounoh) wrapper can take different arguments - r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, - elem.size(), - strides=elem.stride(), - storage_offset=elem.storage_offset(), - dtype=elem.dtype, - layout=elem.layout, - device=elem.device, - requires_grad=kwargs.get("requires_grad", False)) - r.global_tensor = elem.detach() if r.requires_grad else elem - return r - - # Shards on the devices are materialized/available after the lazy - # execution of the partitioned HLO graph. Each XLAShard points - # to torch.Tensor. The shards represent a snapshot on CPU, detached - # from the global tensor. The shard data will contain any padding - # which results from the sharding. - @property - def local_shards(self) -> List[XLAShard]: - shards, devices = torch_xla._XLAC._get_local_shards(self.global_tensor) - replica_and_indices = torch_xla._XLAC._get_local_shard_replica_and_indices( - self.global_tensor) - zipped = zip(shards, replica_and_indices, devices) - return [ - XLAShard(data, indices, dev, replica) - for data, (replica, indices), dev in zipped - ] - - # Load the given list of local shards into the underlying tensor's data - # on the local devices. - def load_local_shards_(self, shards: List[XLAShard]): - data = [s.data for s in shards] - devices = [s.shard_device for s in shards] - torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices) - - @property - def sharding_spec(self): - return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor) - - @property - def sharding_type(self) -> 'ShardingType': - from torch_xla.experimental.xla_sharding import ShardingType - sharding_type = torch_xla._XLAC._get_xla_sharding_type(self.global_tensor) - return ShardingType(sharding_type) - - def __repr__(self): - return f"XLAShardedTensor({self.global_tensor})" - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - """ - The dispatcher allows the unwrapped torch.Tensor to re-dispatched to the - `xla` backend as XlaTensor, and the XlaTensor with an associated sharding spec - to be received and wrapped as XLAShardedTensor. - """ - - def unwrap(elem): - return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem - - def wrap(elem): - return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem - - # no_dispatch is only needed if you use enable_python_mode. - # It prevents infinite recursion. - with no_dispatch(): - # re-dispatch to C++ - rs = tree_map(wrap, - func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) - return rs +from torch_xla.distributed.spmd.xla_sharded_tensor import * diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 95f4a88128b..7b8c5d42b57 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -1,614 +1,10 @@ -import os -from collections import OrderedDict, defaultdict -from dataclasses import dataclass, field -import torch -import torch_xla -import torch_xla.core.xla_model as xm -from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor, XLAShard -import torch_xla.runtime as xr +# Keep this for backward compatibility. +# TODO(yeounoh) remove after 2.2 release. +import warnings -import numpy as np -import functools -import itertools -from typing import Tuple, Union, List, Sequence, Any, Optional, Set -from enum import IntEnum +warnings.warn( + "Importing from `torch_xla.experimental.xla_sharding` will be deprecated " + "after 2.2 release. Please use `torch_xla.distributed.spmd` instead.", + DeprecationWarning, 2) - -class Mesh: - """Describe the logical XLA device topology mesh and the underlying resources. - - Args: - device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped - to an `mesh_shape` array, filling the elements using C-like index order. - - mesh_shape (Tuple[int, ...]): A int tuple describing the logical topology shape - of the device mesh, and each element describes the number of devices in - the corresponding axis. - - axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to the dimensions - of the `devices` argument. Its length should match the rank of `devices`. - - Example: - —------------------------------ - mesh_shape = (4, 2) - num_devices = len(xm.get_xla_supported_devices()) - device_ids = np.array(range(num_devices)) - mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) - mesh.get_logical_mesh() - >> array([[0, 1], - [2, 3], - [4, 5], - [6, 7]]) - mesh.shape() - >> OrderedDict([('x', 4), ('y', 2)]) - """ - - device_ids: np.ndarray - mesh_shape: Tuple[int, ...] - axis_names: Tuple[str, ...] - - def __init__(self, - device_ids: Union[np.ndarray, List], - mesh_shape: Tuple[int, ...], - axis_names: Tuple[str, ...] = None): - if not isinstance(device_ids, np.ndarray): - device_ids = np.array(device_ids) - assert (axis_names is None) or (len(mesh_shape) == len(axis_names)) - assert axis_names is None or (len(set(axis_names)) == len(axis_names)) - assert (len(device_ids) == np.prod(mesh_shape)) - assert len(device_ids) == len(np.unique(device_ids)) - self.device_ids = device_ids - self.mesh_shape = mesh_shape - self.axis_names = axis_names - assert all(d < self.size() for d in device_ids) - - def size(self): - return np.prod(self.mesh_shape) - - def shape(self): - if self.axis_names is None: - return OrderedDict( - (dim, size) for dim, size in enumerate(self.mesh_shape)) - return OrderedDict( - (name, size) for name, size in zip(self.axis_names, self.mesh_shape)) - - def get_logical_mesh(self): - return self.device_ids.reshape(self.mesh_shape) - - def get_axis_name_idx(self, name: str) -> int: - if name not in self.axis_names: - return None - return self.axis_names.index(name) - - @functools.lru_cache(maxsize=None) - def get_op_sharding(self, - partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: - """ - Return the OpSharding for the given partition spec. This is an expensive - operation as the mesh grows, so the value is cached for reuse. - """ - tile_assignment = _get_tile_assignment(self, partition_spec) - if len(tile_assignment.shape) > len(partition_spec): - # Use partial replication for sharding a tensor over a higher-rank mesh - sharding_type = ShardingType.PARTIAL - else: - sharding_type = _get_sharding_type(partition_spec, self.size()) - replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} - group_assignment, replication_groups = _get_group_assignment( - sharding_type, tile_assignment, len(partition_spec), replicate_dims) - return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), - group_assignment, replication_groups, - int(sharding_type)) - - -# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4 - - -class HybridMesh(Mesh): - """Creates a hybrid device mesh of devices connected with ICI and DCN networks. - The shape of logical mesh should be ordered by increasing network-intensity - e.g. [replica, data, model] where mdl has the most network communication - requirements. - - Args: - ici_mesh_shape: shape of the logical mesh for inner connected devices. - dcn_mesh_shape: shape of logical mesh for outer connected devices. - - Example: - # This example is assuming 2 slices of v4-8. - ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor) - dcn_mesh_shape = (2, 1, 1) - - mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor')) - print(mesh.shape()) - >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)]) - """ - ici_mesh_shape: Tuple[int, ...] - dcn_mesh_shape: Tuple[int, ...] - - def __init__(self, - *, - ici_mesh_shape: Tuple[int, ...], - dcn_mesh_shape: Tuple[int, ...] = None, - axis_names: Tuple[str, ...] = None): - if dcn_mesh_shape == None: - dcn_mesh_shape = tuple([1] * len(ici_mesh_shape)) - assert len(ici_mesh_shape) == len(dcn_mesh_shape) - mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)]) - self.device_attributes = xr.global_runtime_device_attributes() - self.device_attributes.sort( - key=lambda attr: xm.parse_xla_device(attr['name'])[1]) - - if 'slice_index' in self.device_attributes[0] and np.prod( - dcn_mesh_shape) == 1: - raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice') - if 'slice_index' not in self.device_attributes[0] and np.prod( - dcn_mesh_shape) > 1: - raise ValueError('Invalid dcn_mesh_shape for single slice mesh') - self.ici_mesh_shape = ici_mesh_shape - self.dcn_mesh_shape = dcn_mesh_shape - if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[ - 0]: # multislice - mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape, - self.dcn_mesh_shape) - else: - mesh = self._create_device_mesh(self.ici_mesh_shape) - device_ids = mesh.flatten() - super().__init__(device_ids, mesh_shape, axis_names) - - # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172 - def _get_physical_tpu_mesh(self, devices: Sequence[int]) -> np.ndarray: - r"""Rearrange TPU devices in a slice into a physical mesh. - - Args: - devices: A list of device logical ordinals in a TPU slice. - - Returns: - A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On - v2 and v3, global_z is instead cores_per_chip (i.e., 2). - """ - assert xm.xla_device_hw(xm.xla_device()) == 'TPU' - # coords is a 3-dims tuple representing the device in physical mesh - device_coords = [self.device_attributes[d]['coords'] for d in devices] - dims = tuple(d + 1 for d in max(device_coords)) - out = np.empty(dims, dtype=int) - for coords, d in zip(device_coords, devices): - out[coords[0], coords[1], coords[2]] = d - return out - - # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64. - def _create_device_mesh_for_nd_torus( - self, physical_mesh: np.ndarray, - mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]: - """Assigns logical parallelism axes to physical axes of an N-D torus network. - - Given logical parallelism axes with sizes in `mesh_shape` and devices in an - N-dimensional torus network represented by `physical_mesh`, maps each logical - axis to one or more physical axes. Prefer to map more-performance-sensitive - logical axes to larger numbers of physical axes to maximize the bandwidth - available to them. Also prefer to assign logical axes to multiple physical - axes of the same size (e.g., a 2D square) rather than multiple physical axes - of different sizes when possible. - - Note that this routine will never split a physical axis over more than one - logical axis (which would reduce total usable bandwidth but may sometimes be - desired anyway). As a result, it will error out in cases where this is - necessary to produce a valid mapping. - - Let's use a concrete example to explain the concepts and considerations. - - As an example, suppose the logical mesh is [data, model], for data and model - parallelism respectively. Also suppose that data parallelism is less - performance sensitive than model parallelism. Consider a 3D TPU pod slice of - shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). - - A TPU pod slice has equal bandwidth along all axes with wraparound links, but - a 2D plane of size 4x4 may have faster XLA collective implementations than a - non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want - the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. - - Args: - physical_mesh: a np.ndarray of devices in the shape of the N-D torus - physical topology. - mesh_shape: shape of the logical mesh (size of the various logical - parallelism axes), with axes ordered by increasing network intensity. - - Returns: - An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with - each logical parallelism axis mapped to one or more physical mesh axes. - The axis assignment (a list of length num_logical_axes, whose elements - are tuples representing physical axis indices). - """ - # Remaining physical axes to be assigned to logical axes. - assignable_physical_mesh = list(physical_mesh.shape) - # Map each logical axis to a subset of physical axes. - assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape] - # Assign logical axes from highest network intensity to lowest. - # `mesh_shape` is assumed to ordered by lowest network intensity first, so - # reverse it first. - # Assigns devices to 2D or 3D logical mesh. - for logical_axis_index, logical_axis_size in reversed( - list(enumerate(mesh_shape))): - for num_axes in range(3, 0, -1): - # map a combination of devices in physical axes to the logical axis. - axes = itertools.combinations(assignable_physical_mesh, num_axes) - indices = itertools.combinations( - range(len(assignable_physical_mesh)), num_axes) - for c_axes, c_indices in zip(axes, indices): - if np.product(c_axes) == logical_axis_size: - assignment[logical_axis_index] = c_indices - # Zero the assigned physical axes. - assignable_physical_mesh = [ - 0 if i in c_indices else v - for i, v in enumerate(assignable_physical_mesh) - ] - break - if assignment[logical_axis_index]: - # We already found an assignment from one candidate above. - break - else: - # If the num_axes for loop did not break, i.e. none of the candidates work - # goto here with this while-else construct. - if logical_axis_size > 1: - raise NotImplementedError( - 'Failed to find assignment for logical_axis_index' - f' {logical_axis_index} of size {logical_axis_size} with remaining' - f' assignable mesh {assignable_physical_mesh}. The size of each' - ' axis in your logical mesh must be equal to the product of' - ' some subset of the physical mesh axis sizes. E.g logical mesh (4,' - ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.' - ) - # Flatten the assignment - transpose: List[int] = [] - for x in assignment: - for y in x: - transpose.append(int(y)) - return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment - - def _create_device_mesh(self, - mesh_shape: Sequence[int], - devices: Sequence[Any] = None) -> Sequence[int]: - """Creates a performant device mesh. - - Args: - mesh_shape: shape of logical mesh, ordered by increasing network-intensity - e.g. [replica, data, mdl] where mdl has the most network communication - requirements. - devices: optionally, the devices to construct a mesh for. - - Returns: - A np.ndarray of devices with mesh_shape as its shape. - """ - - if devices is None: - devices = np.arange(xr.global_runtime_device_count()) - if np.prod(mesh_shape) != len(devices): - raise ValueError( - f'Number of devices {len(devices)} must equal the product ' - f'of mesh_shape {mesh_shape}') - physical_mesh = self._get_physical_tpu_mesh(devices) - device_mesh, assignment = self._create_device_mesh_for_nd_torus( - physical_mesh, mesh_shape) - return device_mesh - - # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288. - def _create_hybrid_device_mesh( - self, ici_mesh_shape: Sequence[int], - dcn_mesh_shape: Sequence[int]) -> Sequence[int]: - """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. - - Args: - ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered - by increasing network intensity, e.g. [replica, data, mdl] where mdl has - the most network communication requirements. - dcn_mesh_shape: shape of the logical mesh for the slower/outer network, - in the same order as mesh_shape. - - Returns: - A np.ndarray of device logical ordinal with ici_mesh_shape * dcn_mesh_shape as its shape - that can be fed into HybridMesh for hybrid parallelism. - """ - granule_dict = defaultdict(list) - for d, dev in enumerate(self.device_attributes): - granule_dict[dev['slice_index']].append(d) - # sorts devices based on slice_index. - granules = list(granule_dict[key] for key in sorted(granule_dict.keys())) - if np.prod(dcn_mesh_shape) != len(granules): - raise ValueError( - f'Number of slices {len(granules)} must equal the product of ' - f'dcn_mesh_shape {dcn_mesh_shape}') - # creates a seperate internal mesh for each slice. - per_granule_meshes = [ - self._create_device_mesh(ici_mesh_shape, granule) - for granule in granules - ] - granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) - blocks = np.vectorize( - lambda i: per_granule_meshes[i], otypes=[object])( - granule_mesh) - device_mesh = np.block(blocks.tolist()) - return device_mesh - - -class ShardingType(IntEnum): - # ShardingType enum ID maps to OpSharidng.Type (https://shorturl.at/pvAJX) - REPLICATED = 0 - MAXIMAL = 1 - TUPLE = 2 - TILED = 3 - MANUAL = 4 - PARTIAL = 5 - - -def _get_sharding_type(partition_spec: Tuple[Union[int, None]], - num_devices: int) -> ShardingType: - sharding_type = ShardingType.TILED - if num_devices == 1: - sharding_type = ShardingType.MAXIMAL - elif all(d is None for d in partition_spec): - sharding_type = ShardingType.REPLICATED - elif any(d is None for d in partition_spec): - sharding_type = ShardingType.PARTIAL - return sharding_type - - -def _get_tile_assignment( - mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int, - None]]) -> np.ndarray: - """ - Permute the given mesh to create the tile assignment based on the partition - spec. Returns the tiling assignment as a numpy ndarray. - - If the input partition_spec combines multiple logical mesh axes over a single - tensor axis, the resulting tiling assignment will combine the specified axes - into a single axis. - """ - # Flatten the partition spec and ensure that it is fully specified over the - # mesh for permutation. - tiled_dims = [x for x in partition_spec if x is not None] - permutation = np.hstack(tiled_dims).tolist() if tiled_dims else [] - missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation)) - tile_assignment = mesh.get_logical_mesh().transpose(permutation + - missing_axes) - - # For any tuples in the partition_spec, the grouped axes will be adjacent - # after the permutation. Combine these dimensions into a single axis. - for i, spec in enumerate(tiled_dims): - if isinstance(spec, tuple): - shape = tile_assignment.shape - tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) + - shape[i + len(spec):]) - - return tile_assignment - - -# Produce group assignment for partial replication. Partial replication tiles -# groups (a.k.a. sub-groups) where the shards are fully replicated within each -# sub-group. `replication_groups` is a list of groups as lists, where each group -# contains the participating device IDs. `group_assignment` describes the group -# placement and the overall mesh, where each element is the group ID. -# The tile_assignment should be the result of `_get_tile_assignment` so that all -# tiled dimensions are in the first axes and replicated dimensions are in the -# remaining axes. -def _get_group_assignment(sharding_type: ShardingType, - tile_assignment: np.ndarray, tensor_rank: int, - replicate_dims: Set[int]) -> Tuple[List, List]: - group_assignment = list() - replication_groups = list() - if sharding_type is ShardingType.PARTIAL: - # Shard across groups and replicate within subgroups; replicated dims - # will be used to group replication devices. - tile_shape = tile_assignment.shape - # When creating the tile assignment, the mesh is permuted so that the first - # few axes are used for tiling. - tile_dims = range(tensor_rank - len(replicate_dims)) - group_list = [tile_assignment] - for d in tile_dims: - _group_list = list() - for group_members in group_list: - _group_list += np.split(group_members, tile_shape[d], d) - group_list = _group_list - replication_groups = [group.flatten().tolist() for group in group_list] - - mesh_axis = itertools.count() - group_tile_shape = [ - 1 if d in replicate_dims else tile_shape[next(mesh_axis)] - for d in range(tensor_rank) - ] - group_assignment = np.arange(len(replication_groups)).reshape( - tuple(group_tile_shape)).tolist() - return group_assignment, replication_groups - - -def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): - _partition_spec = list() - for p in partition_spec: - if type(p) is tuple: - assert not any(type(x) is tuple - for x in p), 'Partition spec cannot contain nested tuples' - _partition_spec.append(_translate_named_partition_spec(mesh, p)) - elif (p is None) or (type(p) is int): - _partition_spec.append(p) - elif type(p) is str: - idx = mesh.get_axis_name_idx(p) - if idx is None: - raise ValueError(f"Axis name {p} is not defined in the given mesh") - _partition_spec.append(idx) - else: - raise ValueError( - f"Spec type {type(p)} is not supported in partition spec") - return tuple(_partition_spec) - - -@xr.requires_pjrt -def mark_sharding( - t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, - partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: - """ - Annotates the tensor provided with XLA partition spec. Internally, - it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. - Args: - t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec. - - mesh (Mesh): describes the logical XLA device topology and the underlying device IDs. - - partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or - `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str. - This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). - When a tuple is specified, the corresponding input tensor axis will be sharded along all - logical axes in the tuple. Note that the order the mesh axes are specified in the tuple - will impact the resulting sharding. - For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise. - >> input = torch.randn(8, 10) - >> mesh_shape = (4, 2) - >> partition_spec = (0, None) - - Examples - —------------------------------ - mesh_shape = (4, 2) - num_devices = xr.global_runtime_device_count() - device_ids = np.array(range(num_devices)) - mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) - - # 4-way data parallel - input = torch.randn(8, 32).to(xm.xla_device()) - xs.mark_sharding(input, mesh, (0, None)) - - # 2-way model parallel - linear = nn.Linear(32, 10).to(xm.xla_device()) - xs.mark_sharding(linear.weight, mesh, (None, 1)) - """ - num_devices = xr.global_runtime_device_count() - assert num_devices > 0, "This requires XLA supported device(s)." - assert mesh.size() == num_devices, \ - f"{mesh.mesh_shape} is not mappable over {num_devices} devices." - partition_spec = _translate_named_partition_spec(mesh, partition_spec) - # We only allow fully specified `partition_spec` to be applicable, as opposed - # to filling in the unspecified replicated dims. Fully specified `partiion_spec` - # should be of the same rank as `t`. This is to support partial replication - # where the group assignment may vary with different input ranks. - assert len(t.shape) == len(partition_spec), \ - f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." - - op_sharding = mesh.get_op_sharding(partition_spec) - - if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) - return t - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) - return XLAShardedTensor(t) - - -def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: - """Clear sharding annotation from the input tensor and return a `cpu` casted tensor.""" - torch_xla._XLAC._xla_clear_sharding(t) - if isinstance(t, XLAShardedTensor): - return t.global_tensor - return t - - -def wrap_if_sharded(x: Any) -> Any: - """ - If the input is a sharded tensor, return an XLAShardedTensor wrapping it. - Otherwise, returns the input. - """ - if (isinstance(x, torch.Tensor) and not isinstance(x, XLAShardedTensor) and - x.device.type == 'xla' and - torch_xla._XLAC._get_xla_sharding_type(x) is not None): - return XLAShardedTensor(x) - return x - - -@dataclass -class ShardingSpec: - mesh: Mesh - partition_spec: Tuple[Union[int, None]] - minibatch: Optional[bool] = False - - # Derived fields - _tile_assignment: List[int] = field(init=False) - _group_assignment: List[int] = field(init=False) - _replication_groups: List[int] = field(init=False) - _sharding_type: ShardingType = field(init=False) - - @xr.requires_pjrt - def __post_init__(self): - mesh = self.mesh - partition_spec = _translate_named_partition_spec(mesh, self.partition_spec) - tile_assignment = _get_tile_assignment(mesh, partition_spec) - self._tile_assignment = tile_assignment.tolist() - self._sharding_type = _get_sharding_type(partition_spec, - xr.global_runtime_device_count()) - replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} - self._group_assignment, self._replication_groups = _get_group_assignment( - self._sharding_type, tile_assignment, len(partition_spec), - replicate_dims) - - def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: - """ - Create an XlaShardingSpec for the given tensor. If the tensor is - incompatible with the ShardingSpec, returns None. - """ - if not self.can_apply(t): - return None - return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment, - self._group_assignment, - self._replication_groups, - int(self._sharding_type), - self.minibatch) - - def can_apply(self, t: torch.Tensor) -> bool: - """ - Test whether the ShardingSpec is compatible with the given torch.Tensor. - """ - return len(t.shape) == len(self.partition_spec) - - def apply(self, t: torch.Tensor): - # TODO(yeounoh) use virtual device interface when available. - assert (t.device == xm.xla_device()) - mark_sharding(t, self.mesh, self.partition_spec) - - -class XLAPatchedLinear(torch.autograd.Function): - """ - A patched version of `torch.nn.functional.linear` that uses einsum instead - of torch.matmul which will flatten the tensors to 2D and collide the sharded - dimensions. The torch.matmul default behavior makes it very hard for XLA compiler - to propagate the sharding annotation. - - TODO (alanwaketan): Let's patch it on the dispatcher level. - """ - - @staticmethod - def forward(ctx, input, weight, bias=None): - # bias is an optional argument - ctx.save_for_backward(input, weight, bias) - with torch.no_grad(): - product = torch.einsum('...n,mn->...m', input, weight) - if bias is None: - return product - return product + bias - - @staticmethod - def backward(ctx, grad_output): - input, weight, bias = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - - if ctx.needs_input_grad[0]: - grad_input = torch.einsum('...m,mn->...n', grad_output, weight) - if ctx.needs_input_grad[1]: - grad_weight = torch.einsum('...m,...n->mn', grad_output, input) - if bias is not None and ctx.needs_input_grad[2]: - grad_bias = torch.einsum('...m->m', grad_output) - - return grad_input, grad_weight, grad_bias - - -def xla_patched_nn_linear_forward(m, input): - return XLAPatchedLinear.apply(input, m.weight, m.bias) +from torch_xla.distributed.spmd.xla_sharding import * diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 3087f3c80f6..2d0c280fd2a 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -40,8 +40,8 @@ def _maybe_select_default_device(): os.environ[xenv.PJRT_DEVICE] = 'TPU' # TODO(wcromar): Detect GPU device elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0: - logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=GPU') - os.environ[xenv.PJRT_DEVICE] = 'GPU' + logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=CUDA') + os.environ[xenv.PJRT_DEVICE] = 'CUDA' else: logging.warning('Defaulting to PJRT_DEVICE=CPU') os.environ[xenv.PJRT_DEVICE] = 'CPU' @@ -107,6 +107,13 @@ def xla_device(n: Optional[int] = None, Returns: A `torch.device` representing an XLA device. """ + # TODO(xiowei replace gpu with cuda): Remove the warning message at r2.2 release. + pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str) + if pjrt_device.casefold() == 'gpu': + warnings.warn( + 'PJRT_DEVICE=GPU is being deprecate. Please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA.' + ) + if n is None: return torch.device(torch_xla._XLAC._xla_get_default_device()) @@ -242,3 +249,14 @@ def is_spmd(): """Returns if SPMD is set for execution.""" # TODO(yeounoh) replace this when we fully deprecate the flag. return xu.check_env_flag('XLA_USE_SPMD') + + +@requires_pjrt +def get_master_ip() -> str: + """Retrieve the master worker IP for the runtime. This calls into + backend-specific discovery APIs. + + Returns master worker's IP address as a string.""" + if device_type() == 'TPU': + return tpu.discover_master_worker_ip() + raise RuntimeError(f'IP discovery not supported for device: {device_type()}') diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 35b6b6cdf51..cd7085bebd1 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -73,10 +73,14 @@ def evaluate(self, method_name, args): res = pytree.tree_unflatten(res, out_spec) return res - def get_stablehlo_bytecode(self, method_name): + def get_stablehlo_bytecode(self, method_name=None): + if method_name is None: + method_name = self._default_method return self._name_to_stablehlo[method_name].bytecode - def get_stablehlo_text(self, method_name): + def get_stablehlo_text(self, method_name=None): + if method_name is None: + method_name = self._default_method return self._name_to_stablehlo[method_name].text def save(self, directory_path): @@ -135,6 +139,8 @@ class StableHLOFunctionMeta: # the arguments the user supplied, OR a parameter, OR a constant input_locations: List[InputLocation] + unused_inputs: List[Tuple[InputLocation, VariableSignature]] + # input_pytree_spec input_pytree_spec: Optional[str] = None output_pytree_spec: Optional[str] = None @@ -210,6 +216,16 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: new_kwargs['device'] = self._device return super().call_function(target, args, new_kwargs) + def run_node(self, n) -> Any: + if n.op == 'placeholder': + fake_t = n.meta['val'] + res = super().run_node(n) + for i, x in enumerate(fake_t.shape): + if not isinstance(x, int): + torch_xla._XLAC._xla_mark_dynamic(res, i) + return res + return super().run_node(n) + def _extract_input_args(exported_model, options): if options.override_tracing_arguments is not None: @@ -238,7 +254,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, options) -> StableHLOModelBundle: if options is None: options = StableHLOExportOptions() - + exported_model = exported_model.run_decompositions() input_args = _extract_input_args(exported_model, options) device = xm.xla_device() @@ -299,10 +315,16 @@ def _exported_program_to_stablehlo_bundle(exported_model, if isinstance(tensor, torch.Tensor) } + # there might be inputs that is part of input but not consumed by HLO graph + unused_input_positions = set(range(len(input_args))) + for hlo_input_pos, (tensor_id, tensor_value) in enumerate( zip(graph_input_tensor_ids, graph_input_xla_values)): if tensor_id in input_ids: # this is input - location = InputLocation.input_arg(position=input_ids[tensor_id]) + pos_id = input_ids[tensor_id] + location = InputLocation.input_arg(position=pos_id) + if pos_id in unused_input_positions: + unused_input_positions.remove(pos_id) elif tensor_id in tensor_id_to_state_name: location = InputLocation.parameter( name=tensor_id_to_state_name[tensor_id]) @@ -315,6 +337,21 @@ def _exported_program_to_stablehlo_bundle(exported_model, shape=list(tensor_value.shape), dtype=str(tensor_value.dtype).replace('torch.', ''))) + unused_inputs = [] + for i in unused_input_positions: + pos = InputLocation.input_arg(position=i) + arg = input_args[i] + if isinstance(arg, torch.Tensor): + signature = VariableSignature( + shape=list(arg.shape), dtype=str(arg.dtype).replace('torch.', '')) + else: + signature = VariableSignature( + shape=[], + dtype=str(type(arg)), + ) + + unused_inputs.append((pos, signature)) + output_signature = [ VariableSignature( shape=list(tensor.shape), @@ -330,6 +367,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, input_signature=input_signatures, output_signature=output_signature, input_locations=input_locations, + unused_inputs=unused_inputs, input_pytree_spec=pytree.treespec_dumps(exported_model.call_spec.in_spec), output_pytree_spec=pytree.treespec_dumps( exported_model.call_spec.out_spec), diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index 2c40a21b7b5..511b9e02e9b 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -1,3 +1,4 @@ +import itertools import sys import os from typing import List, Tuple, Any @@ -35,16 +36,20 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule): - return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], - stablehlo_program._bundle) +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, + bundle=None): + if bundle is None: + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], + stablehlo_program._bundle) + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], bundle) def _make_input_signatures( meta: stablehlo.StableHLOFunctionMeta) -> List[tf.TensorSpec]: input_pos_to_spec = { loc.position: spec - for loc, spec in zip(meta.input_locations, meta.input_signature) + for loc, spec in itertools.chain( + zip(meta.input_locations, meta.input_signature), meta.unused_inputs) if loc.type_ == stablehlo.VariableType.INPUT_ARG } for i in range(len(input_pos_to_spec)): @@ -53,6 +58,19 @@ def _make_input_signatures( shape=spec.shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}') +def _mangle_tf_root_scope_name(name): + # TF has more restricted constrain on the variable names at root scope. + # Root scope name constrain: [A-Za-z0-9.][A-Za-z0-9_.\\-/]* + # Non-root scope name constrain: [A-Za-z0-9_.\\-/]* + # https://github.com/tensorflow/tensorflow/blob/51b601fa6bb7e801c0b6ae73c25580e40a8b5745/tensorflow/python/framework/ops.py#L3301-L3302 + # The state_dict key doesn't have such constrain, + # the name need to be mangled when a root-scoped TF variable is created. + if name[0] in "._\\-/": + return 'k' + name + else: + return name + + def save_stablehlo_graph_as_tf( stablehlo_program: stablehlo.StableHLOGraphModule, path: os.PathLike, @@ -76,7 +94,8 @@ def save_stablehlo_graph_as_tf( bundle = copy.deepcopy(stablehlo_program._bundle) tfm = tf.Module() bundle.state_dict = { - k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items() + k: tf.Variable(v, trainable=False, name=_mangle_tf_root_scope_name(k)) + for k, v in bundle.state_dict.items() } bundle.additional_constants = [ tf.Variable(v, trainable=False) for v in bundle.additional_constants @@ -84,7 +103,8 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), + input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}