[r2.3] backport PR#6719 #6754

Status: Closed (wants to merge 13 commits)
2 changes: 1 addition & 1 deletion .bazelversion
@@ -1 +1 @@
5.3.0
6.5.0
11 changes: 6 additions & 5 deletions .circleci/common.sh
@@ -144,10 +144,13 @@ function run_torch_xla_python_tests() {
PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=$num_devices test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18

# single-host-SPMD
XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 16 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18
# TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 8 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18

PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
# TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
# TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# Syncfree SGD optimizer tests
if [ -d ./torch_xla/amp/syncfree ]; then
@@ -185,7 +188,6 @@ function run_torch_xla_cpp_tests() {
fi

if [ "$USE_COVERAGE" != "0" ]; then
# TODO(yeounoh) shard the coverage testing
if [ -x "$(command -v nvidia-smi)" ]; then
PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L""
cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov1.dat
@@ -232,7 +234,6 @@ function run_torch_xla_tests() {
export PYTORCH_TESTING_DEVICE_ONLY_FOR="xla"
export CXX_ABI=$(python -c "import torch;print(int(torch._C._GLIBCXX_USE_CXX11_ABI))")

# TODO(yeounoh) test coverage workflow is not parallelized.
if [[ -z "$RUN_BENCHMARK_TESTS" && -z "$RUN_CPP_TESTS1" && -z "$RUN_CPP_TESTS2" && -z "$RUN_PYTHON_TESTS" ]]; then
run_torch_xla_python_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE
run_torch_xla_cpp_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE
5 changes: 4 additions & 1 deletion .github/workflows/tpu_ci.yml
@@ -16,6 +16,10 @@ jobs:
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
python3 setup.py install --user
- name: Install torchvision
run: |
cd pytorch/
pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
- name: Checkout PyTorch/XLA Repo
uses: actions/checkout@v4
with:
@@ -32,7 +36,6 @@ jobs:
PJRT_DEVICE: TPU
# Jax is needed for pallas tests.
run: |
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu --user
pip install fsspec
pip install rich
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
2 changes: 0 additions & 2 deletions .kokoro/common.sh
@@ -36,7 +36,6 @@ function install_environments() {
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" >> /etc/apt/sources.list.d/google-cloud-sdk.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -

# TODO(yeounoh) fix `GoogleCredentials` import error
apt-get update
apt-get -y install google-cloud-cli
pip install --upgrade google-api-python-client
@@ -95,7 +94,6 @@ function run_torch_xla_tests() {
pushd $XLA_DIR
echo "Running integration tests..."
# Run integration tests with CPU
# TODO(yeounoh) use custom GCP project for TPU
XLA_CUDA=0 USE_COVERAGE=0 XLA_SKIP_TORCH_OP_TESTS=1 XLA_SKIP_XRT_TESTS=1 CONTINUE_ON_ERROR=1 ./test/run_tests.sh
popd
}
1 change: 1 addition & 0 deletions .torch_pin
@@ -0,0 +1 @@
release/2.3
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -51,9 +51,9 @@ http_archive(
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:quant_dequant_converter.diff",
],
strip_prefix = "xla-419a3d736bdfe6891bc40e9ab5c78b466c8e0dc6",
strip_prefix = "xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd",
urls = [
"https://github.com/openxla/xla/archive/419a3d736bdfe6891bc40e9ab5c78b466c8e0dc6.tar.gz",
"https://github.com/openxla/xla/archive/18cbd2019898d3a7b563aeb73683f0c5a6ce14fd.tar.gz",
],
)

2 changes: 1 addition & 1 deletion benchmarks/result_analyzer.py
@@ -180,7 +180,7 @@ def extract_metrics_csv(self, file, metric_df):
"repeat": dataline["repeat"],
"iterations_per_run": dataline["iterations_per_run"],
"error_message": None,
"outputs_file": dataline["outputs_file"],
"outputs_file": dataline["experiment"].get("outputs_file", ""),
}

if "error" in dataline["metrics"] and not self._args.hide_errors:
29 changes: 29 additions & 0 deletions docs/spmd.md
@@ -492,4 +492,33 @@ generated_table = visualize_sharding(sharding, use_color=False)

You can run these examples on a single TPU/GPU/CPU host and adapt them for multi-host setups. You can also change the sharding style to `tiled`, `partial_replication`, or `replicated`.
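For reference, here is a minimal sketch of switching between these sharding styles with `mark_sharding`; the tensor shape, mesh shape, and axis names below are illustrative assumptions rather than part of the examples above:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# 1-D logical mesh over all global devices; axis names are arbitrary labels.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(8, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('data', None))   # tiled along the first dimension
# xs.mark_sharding(t, mesh, (None, None))   # fully replicated instead
```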

### Auto-Sharding
We are introducing a new PyTorch/XLA SPMD feature called `auto-sharding` ([RFC](https://github.com/pytorch/xla/issues/6322)). This is an experimental feature in `r2.3` and `nightly` that supports `XLA:TPU` on a single TPU VM host.

PyTorch/XLA auto-sharding can be enabled by one of the following:
- Setting envvar `XLA_SPMD_AUTO=1`
- Calling the SPMD API in the beginning of your code:
```python
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```
- Calling `torch.distributed._tensor.distribute_module` with `auto_policy` and an `xla` device mesh:
```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, the model should be loaded onto the XLA device via distribute_module.
model = MyModule()  # an nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
```

Optionally, one can set the following options/env-vars to control the behavior of
the XLA-based auto-sharding pass (a minimal example follows this list):
- `XLA_AUTO_USE_GROUP_SHARDING`: group resharding of the parameters. Set by default.
- `XLA_AUTO_SPMD_MESH`: logical mesh shape to be used for auto-sharding. For example,
`XLA_AUTO_SPMD_MESH=2,2` corresponds to a 2-by-2 mesh with 4 global devices. If unset,
a default device mesh shape of `num_devices,1` will be used.
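As an illustration, a minimal sketch of enabling auto-sharding purely through these environment variables (assuming a single host with four devices; `MyModule` and the input shape are placeholders carried over from the example above):

```python
import os

# Set before torch_xla initializes; a 2-by-2 logical mesh assumes 4 global devices.
os.environ["XLA_SPMD_AUTO"] = "1"
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MyModule().to(device)  # sharding decisions are left to the auto-sharding pass
out = model(torch.randn(16, 128).to(device))
xm.mark_step()  # materialize the traced graph
```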
3 changes: 3 additions & 0 deletions infra/ansible/e2e_tests.Dockerfile
@@ -23,6 +23,9 @@ RUN ansible-playbook -vvv playbook.yaml -e "stage=release" -e "${ansible_vars}"
# Copy test sources.
RUN mkdir -p /src/pytorch/xla
COPY --from=build /src/pytorch/xla/test /src/pytorch/xla/test
# Copy ci_commit_pins from upstream
RUN mkdir -p /src/pytorch/.github
COPY --from=build /src/pytorch/.github/ci_commit_pins /src/pytorch/.github/ci_commit_pins

# Copy and install wheels.
WORKDIR /tmp/wheels
128 changes: 56 additions & 72 deletions openxla_patches/quant_dequant_converter.diff
@@ -1,31 +1,36 @@
// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to
// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize.
// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO.
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
index cd6e427c3..9deecdcb2 100644
index a75baa9dad..a09fca4898 100644
--- a/xla/translate/hlo_to_mhlo/BUILD
+++ b/xla/translate/hlo_to_mhlo/BUILD
@@ -68,6 +68,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
@@ -40,6 +40,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SparseTensorDialect",
"@tsl//tsl/platform:statusor",
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 5e5e32652..d246383bd 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -669,6 +669,71 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

+Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/xla/translate/hlo_to_mhlo/custom_call_importer.cc
index 9250747210..9cba4c992a 100644
--- a/xla/translate/hlo_to_mhlo/custom_call_importer.cc
+++ b/xla/translate/hlo_to_mhlo/custom_call_importer.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/AsmParser/AsmParser.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
@@ -93,6 +94,66 @@ absl::StatusOr<mlir::Operation*> ImportRealDynamicSliceOp(

} // namespace

+mlir::Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ std::vector<double> scales;
+ std::vector<int64_t> zero_points;
+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0;
+ Type storage_type, expressed_type;
+ mlir::Type storage_type, expressed_type;
+
+ auto scales_attr = backend_config.get("scale");
+ if (scales_attr) {
@@ -61,15 +66,11 @@ index 5e5e32652..d246383bd 100644
+ auto storage_type_attr = backend_config.get("storage_type");
+ if (storage_type_attr) {
+ storage_type = storage_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto expressed_type_attr = backend_config.get("expressed_type");
+ if (expressed_type_attr) {
+ expressed_type = expressed_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto is_signed = storage_type.cast<mlir::IntegerType>().isSignless();
@@ -85,54 +86,37 @@ index 5e5e32652..d246383bd 100644
+ }
+}
+
absl::StatusOr<mlir::Operation*> ImportCustomCallAsOp(
const HloCustomCallInstruction* instruction, mlir::Location loc,
mlir::Type result_type, mlir::ValueRange operands,
@@ -112,6 +173,30 @@ absl::StatusOr<mlir::Operation*> ImportCustomCallAsOp(
return ImportRealDynamicSliceOp(backend_config_str, loc, result_type,
operands, builder);
}
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportCustomCallAsOp(
const HloInstruction* instruction, mlir::Location loc,
const Type result_type, mlir::ValueRange operands,
@@ -992,6 +1057,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
"Couldn't parse backend config into a dictionary attribute");

attributes.push_back(builder_->getNamedAttr("backend_config", attr));
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_quantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+ auto attr = mlir::parseAttribute(backend_config_str, builder->getContext())
+ .dyn_cast<mlir::DictionaryAttr>();
+ if (!attr) {
+ return Internal(
+ "Couldn't parse backend config into a dictionary attribute");
+ }
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call_target == "mhlo.uniform_quantize") {
+ return builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<mlir::RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_dequantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(
+ loc, result_type, operands) .getOperation();
+ }
}
} else {
attributes.push_back(builder_->getNamedAttr(
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index f9edd1272..23a747fb1 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module,
module.getContext()->loadDialect<mlir::arith::ArithDialect>();
module.getContext()->loadDialect<mlir::func::FuncDialect>();
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>();
+ if (custom_call_target == "mhlo.uniform_dequantize") {
+ return builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(loc, result_type, operands)
+ .getOperation();
+ }
return InvalidArgument("Unsupported MHLO op custom_call %s",
custom_call_target);
}

namespace {
2 changes: 1 addition & 1 deletion plugins/cuda/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "torch_xla_cuda_plugin"
version = "0.0.1"
version = "2.3.0"
authors = [
{name = "PyTorch/XLA Dev Team", email = "pytorch-xla@googlegroups.com"},
]
4 changes: 2 additions & 2 deletions setup.py
@@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240213'
_date = '20240305'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
_jax_version = f'0.4.25.dev{_date}'
_jax_version = f'0.4.26.dev{_date}'


def _get_build_mode():
29 changes: 24 additions & 5 deletions test/cpp/test_xla_sharding.cpp
@@ -309,17 +309,25 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
}

TEST_F(XLAShardingTest, CreateTensorsData) {
std::vector<at::Tensor> tensors(2);
if (torch_xla::runtime::sys_util::GetEnvString(
torch_xla::runtime::env::kEnvPjRtDevice, "") == "") {
GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
}

std::vector<at::Tensor> tensors(3);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
std::fill_n(tensors.begin(), tensors.size(), tensor);
std::vector<std::string> devices(2);
std::vector<std::string> devices(3);
std::fill_n(devices.begin(), devices.size(),
bridge::GetDefaultDevice()->toString());
std::vector<XLATensor::ShardingSpecPtr> shardings = {
nullptr, std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), tensor_shape)};
nullptr,
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), tensor_shape),
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Unknown().ToProto(), tensor_shape)};
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, shardings, devices);

@@ -337,7 +345,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(tensors_data[0], shards[0], at::kFloat));

// Returns multiple input shards, replicated
// Returns multiple input shards, explicitly replicated
auto sharded_xla_data =
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
tensors_data[1]);
@@ -347,6 +355,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));

// Returns multiple input shards, implicitly replicated
sharded_xla_data =
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
tensors_data[2]);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
sharded_xla_data);
EXPECT_EQ(shards.size(), n_devices);
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));
}
}

2 changes: 2 additions & 0 deletions test/run_tests.sh
@@ -226,6 +226,8 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_dtensor_integration2.py"
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
1 change: 1 addition & 0 deletions test/spmd/args_parse.py
@@ -39,6 +39,7 @@ def parse_common_options(datadir=None,
parser.add_argument('--async_closures', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--auto_spmd', action='store_true')
if opts:
for name, aopts in opts:
parser.add_argument(name, **aopts)