diff --git a/WORKSPACE b/WORKSPACE
index 155cc74731b..faed0ceb57b 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
 #    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 #    and update the sha256 with the result.
 
-xla_hash = 'b604c8d87df842002a7a8de79a434026329fbcb2'
+xla_hash = 'bf2dc9fe056bd7140e5f29a2ae6db15a26dd5443'
 
 http_archive(
     name = "xla",
diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars
index c2617739f45..2e8db7cc7b3 100644
--- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars
+++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars
@@ -33,6 +33,64 @@ nightly_builds = [
 
 # Built on push to specific tag.
 versioned_builds = [
+  # Remove libtpu from PyPI builds
+  {
+    git_tag = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1"
+    pytorch_git_rev = "v2.4.0-rc1"
+    accelerator = "tpu"
+    python_version = "3.8"
+    bundle_libtpu = "0"
+  },
+  {
+    git_tag = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1"
+    pytorch_git_rev = "v2.4.0-rc1"
+    accelerator = "tpu"
+    python_version = "3.9"
+    bundle_libtpu = "0"
+  },
+  {
+    git_tag = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1"
+    pytorch_git_rev = "v2.4.0-rc1"
+    accelerator = "tpu"
+    python_version = "3.10"
+    bundle_libtpu = "0"
+  },
+  {
+    git_tag = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1"
+    pytorch_git_rev = "v2.4.0-rc1"
+    accelerator = "tpu"
+    python_version = "3.11"
+    bundle_libtpu = "0"
+  },
+  # Bundle libtpu for Kaggle
+  {
+    git_tag = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1+libtpu"
+    pytorch_git_rev = "v2.4.0-rc1"
+    accelerator = "tpu"
+    python_version = "3.10"
+    bundle_libtpu = "1"
+  },
+  {
+    git_tag = "v2.4.0-rc1"
+    pytorch_git_rev = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1"
+    accelerator = "cuda"
+    cuda_version = "12.1"
+    python_version = "3.8"
+  },
+  {
+    git_tag = "v2.4.0-rc1"
+    pytorch_git_rev = "v2.4.0-rc1"
+    package_version = "2.4.0-rc1"
+    accelerator = "cuda"
+    cuda_version = "12.1"
+    python_version = "3.10"
+  },
   # Remove libtpu from PyPI builds
   {
     git_tag = "v2.3.0"
diff --git a/setup.py b/setup.py
index 7b6ecb4af6f..098a5eb20cb 100644
--- a/setup.py
+++ b/setup.py
@@ -64,10 +64,10 @@
 
 base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_date = '20240605'
+_date = '20240612'
 _libtpu_version = f'0.1.dev{_date}'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
-_jax_version = f'0.4.29.dev{_date}'
+_jax_version = f'0.4.30.dev{_date}'
 
 
 def _get_build_mode():
diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py
index 849b45c7c98..4f869961f09 100644
--- a/test/spmd/test_sharding_strategies.py
+++ b/test/spmd/test_sharding_strategies.py
@@ -67,9 +67,9 @@
 
 num_devices = xr.global_runtime_device_count()
 
-assert np.product(dcn_parallelism) * np.product(
+assert np.prod(dcn_parallelism) * np.prod(
     ici_parallelism) == num_devices, f"Number of devices {num_devices} \
-  does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}"
+  does not match the product of the parallelism {np.prod(dcn_parallelism) * np.prod(ici_parallelism)}"
 
 # Use HybridMesh to optimize multislice topology
 mesh = xs.HybridMesh(
diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index a035a3f11bd..593dba0769b 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -172,6 +172,34 @@ def test_resharding_different_device_mesh(self):
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner())
 
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed to change mesh")
+  def test_resharding_transpose_device_mesh(self):
+    dim = self.n_devices // 2
+    model1 = self._get_sharded_model(mesh_shape=(dim, self.n_devices // dim))
+    model2 = self._get_sharded_model(mesh_shape=(self.n_devices // dim, dim))
+    self._save_and_restore(
+        model1,
+        model2,
+        save_planner=SPMDSavePlanner(),
+        load_planner=SPMDLoadPlanner())
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed to change mesh")
+  def test_padded_tensor(self):
+    # Use a linear layer with shape not divisible by the number of devices.
+    model1 = torch.nn.Linear(127, 63).to('xla')
+    model2 = torch.nn.Linear(127, 63).to('xla')
+    mesh = xs.Mesh(range(self.n_devices), (self.n_devices,))
+    # Transpose the sharding to induce resharding in the restore path
+    xs.mark_sharding(model1.weight, mesh, (0, None))
+    xs.mark_sharding(model2.weight, mesh, (None, 0))
+    self._save_and_restore(
+        model1,
+        model2,
+        save_planner=SPMDSavePlanner(),
+        load_planner=SPMDLoadPlanner())
+
   @unittest.skipUnless('CHKPT_PATH' in os.environ,
                        'CHKPT_PATH must be set for multihost checkpoint')
   def test_multihost_checkpoint(self):
diff --git a/test/test_operations.py b/test/test_operations.py
index 938817a6fd2..6fb0b79d78d 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2099,6 +2099,8 @@ def test(f, xshape, ishapes):
     for xshape, i0shape, i1shape in cases[f2]:
       test(f2, xshape, (i0shape, i1shape))
 
+  @unittest.skipIf(
+      True, "skip since https://github.com/pytorch/xla/pull/7130 is reverted")
   def test_inplace_mul_scalar_different_dtype(self):
     # This tests whether the returned output data-type agrees on PyTorch
     # and XLA sides.
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 5bfb7b8991b..983bd46d679 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -77,6 +77,7 @@ def _setup_default_env():
     os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
     # This is used for ML Framework Telemetry.
     os.environ.setdefault('TPU_ML_PLATFORM_VERSION', __version__)
+    os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
 
     if tpu.version() == 4:
       os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
@@ -212,7 +213,8 @@ def _init_xla_lazy_backend():
 from .experimental import plugins
 from ._internal import neuron, xpu  # Additional built-in plugins
 
-if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
+if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS',
+             '0' if _XLAC._has_cuda_support() else '1') == '1':
   plugins.use_dynamic_plugins()
   plugins.register_installed_plugins()
 
diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index 241bf469d4a..8a42665012a 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -17,6 +17,7 @@
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 from torch_xla.experimental import plugins
+from torch_xla.version import __version__
 
 _GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1'
 _ACCELERATOR_TYPE_TO_HOST_BOUNDS = {
@@ -342,10 +343,16 @@ def configure_multiprocess(self, local_rank, local_world_size):
     return configure_topology(local_rank, local_world_size)
 
   def physical_chip_count(self):
-    return num_available_chips()
+    # HACK: We may reduce the number of processes we spawn depending on TPU
+    # topology settings
+    return num_local_processes()
 
   def client_create_options(self):
     return {
         'max_inflight_computations':
-            xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4)
+            xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4),
+        'ml_framework_name':
+            'PyTorch/XLA',
+        'ml_framework_version':
+            __version__
     }
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index dc30734756d..3459b8935e8 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2221,14 +2221,11 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
                                    const at::Tensor& other) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
-                              std::optional<at::ScalarType>);
-  return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
-      .add_input(self)
-      .add_input(other)
-      .cast_inputs_to_common_dtype()
-      .use_opmathtype_for_compute()
-      .run();
+  return DoBinaryOp(self, other,
+                    [&](const XLATensorPtr& xself, const XLATensorPtr& xother,
+                        at::ScalarType dtype) {
+                      return tensor_methods::mul(xself, xother, dtype);
+                    });
 }
 
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a6e3196c5d2..3fba13773b3 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2423,6 +2423,13 @@ void InitXlaModuleBindings(py::module m) {
           return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
                                /*is_tpu=*/true);
         });
+  m.def("_has_cuda_support", []() {
+#ifdef GOOGLE_CUDA
+    return true;
+#else
+    return false;
+#endif
+  });
   m.def("_xla_gpu_custom_call",
         [](const std::vector<at::Tensor>& inputs, const std::string& payload,
            const std::vector<std::vector<int64_t>>& output_shapes,
diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index 52b06d89cb4..e92dcf7dd44 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -82,6 +82,9 @@ InitializePjRt(const std::string& device_type) {
   if (plugin) {
     TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
plugin " << device_type; + // Init the absl logging to avoid the log spam. + absl::InitializeLog(); + std::shared_ptr kv_store = nullptr; if (plugin->requires_xla_coordinator()) { int local_process_rank = sys_util::GetEnvInt( diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 916fb56f7c9..73e41e01fba 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -268,7 +268,7 @@ def _create_device_mesh_for_nd_torus( indices = itertools.combinations( range(len(assignable_physical_mesh)), num_axes) for c_axes, c_indices in zip(axes, indices): - if np.product(c_axes) == logical_axis_size: + if np.prod(c_axes) == logical_axis_size: assignment[logical_axis_index] = c_indices # Zero the assigned physical axes. assignable_physical_mesh = [ diff --git a/torch_xla/experimental/distributed_checkpoint/planners.py b/torch_xla/experimental/distributed_checkpoint/planners.py index c417872c2f2..32fe987a97d 100644 --- a/torch_xla/experimental/distributed_checkpoint/planners.py +++ b/torch_xla/experimental/distributed_checkpoint/planners.py @@ -282,11 +282,6 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): lengths and offsets into the global tensor. """ offsets = read_item.dest_offsets - index = read_item.dest_index - if index.fqn in self.sharded_state_dict: - # Update offsets to index into the shard rather than the global tensor - shard = self._local_shards[index.fqn][index.index] - offsets = torch.Size(d - i.start for d, i in zip(offsets, shard.indices)) return narrow_tensor_by_index(tensor, offsets, read_item.lengths) def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: