[backport] Use np.prod instead of np.product (#7301) #7309

Closed
wants to merge 8 commits
2 changes: 1 addition & 1 deletion WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = 'b604c8d87df842002a7a8de79a434026329fbcb2'
xla_hash = 'bf2dc9fe056bd7140e5f29a2ae6db15a26dd5443'

http_archive(
name = "xla",
58 changes: 58 additions & 0 deletions infra/tpu-pytorch-releases/artifacts.auto.tfvars
@@ -33,6 +33,64 @@ nightly_builds = [

# Built on push to specific tag.
versioned_builds = [
# Remove libtpu from PyPI builds
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.8"
bundle_libtpu = "0"
},
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.9"
bundle_libtpu = "0"
},
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.10"
bundle_libtpu = "0"
},
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.11"
bundle_libtpu = "0"
},
# Bundle libtpu for Kaggle
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1+libtpu"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.10"
bundle_libtpu = "1"
},
{
git_tag = "v2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
accelerator = "cuda"
cuda_version = "12.1"
python_version = "3.8"
},
{
git_tag = "v2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
accelerator = "cuda"
cuda_version = "12.1"
python_version = "3.10"
},
# Remove libtpu from PyPI builds
{
git_tag = "v2.3.0"
4 changes: 2 additions & 2 deletions setup.py
@@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240605'
_date = '20240612'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
_jax_version = f'0.4.29.dev{_date}'
_jax_version = f'0.4.30.dev{_date}'


def _get_build_mode():
4 changes: 2 additions & 2 deletions test/spmd/test_sharding_strategies.py
@@ -67,9 +67,9 @@

num_devices = xr.global_runtime_device_count()

assert np.product(dcn_parallelism) * np.product(
assert np.prod(dcn_parallelism) * np.prod(
ici_parallelism) == num_devices, f"Number of devices {num_devices} \
does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}"
does not match the product of the parallelism {np.prod(dcn_parallelism) * np.prod(ici_parallelism)}"

# Use HybridMesh to optimize multislice topology
mesh = xs.HybridMesh(
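Note on the change above (not part of the diff): np.product was only an alias of np.prod; it was deprecated in NumPy 1.25 and removed in NumPy 2.0, which is why the backport renames every call site. A minimal sketch of the equivalent check, with illustrative parallelism values:

import numpy as np

# Illustrative values; the real test derives these from the runtime device count.
dcn_parallelism = [1, 1, 1]   # data-center-network (multislice) mesh axes
ici_parallelism = [2, 2, 1]   # within-slice (ICI) mesh axes
num_devices = 4

# np.prod multiplies the elements of each sequence; np.product did the same thing.
assert np.prod(dcn_parallelism) * np.prod(ici_parallelism) == num_devices
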
28 changes: 28 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -172,6 +172,34 @@ def test_resharding_different_device_mesh(self):
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to change mesh")
def test_resharding_transpose_device_mesh(self):
dim = self.n_devices // 2
model1 = self._get_sharded_model(mesh_shape=(dim, self.n_devices // dim))
model2 = self._get_sharded_model(mesh_shape=(self.n_devices // dim, dim))
self._save_and_restore(
model1,
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to change mesh")
def test_padded_tensor(self):
# Use a linear layer with shape not divisible by the number of devices.
model1 = torch.nn.Linear(127, 63).to('xla')
model2 = torch.nn.Linear(127, 63).to('xla')
mesh = xs.Mesh(range(self.n_devices), (self.n_devices,))
# Transpose the sharding to induce resharding in the restore path
xs.mark_sharding(model1.weight, mesh, (0, None))
xs.mark_sharding(model2.weight, mesh, (None, 0))
self._save_and_restore(
model1,
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipUnless('CHKPT_PATH' in os.environ,
'CHKPT_PATH must be set for multihost checkpoint')
def test_multihost_checkpoint(self):
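A rough usage sketch (not from the PR) of what the two partition specs in test_padded_tensor mean, assuming a single host whose device count does not divide 63 or 127: (0, None) shards the weight's rows across the 1-D mesh and (None, 0) shards its columns, so loading one layout into the other forces the restore path to reshard padded tensors.

import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # SPMD mode must be enabled before mark_sharding
n_devices = xr.global_runtime_device_count()  # e.g. 4 on a v4-8 host
mesh = xs.Mesh(range(n_devices), (n_devices,))

w = torch.nn.Linear(127, 63).to('xla').weight   # weight shape is (63, 127)
xs.mark_sharding(w, mesh, (0, None))    # rows sharded: 63 is padded up to a multiple of n_devices
# xs.mark_sharding(w, mesh, (None, 0))  # columns sharded instead: 127 padded likewise
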
2 changes: 2 additions & 0 deletions test/test_operations.py
@@ -2099,6 +2099,8 @@ def test(f, xshape, ishapes):
for xshape, i0shape, i1shape in cases[f2]:
test(f2, xshape, (i0shape, i1shape))

@unittest.skipIf(
True, "skip since https://github.com/pytorch/xla/pull/7130 is reverted")
def test_inplace_mul_scalar_different_dtype(self):
# This tests whether the returned output data-type agrees on PyTorch
# and XLA sides.
4 changes: 3 additions & 1 deletion torch_xla/__init__.py
@@ -77,6 +77,7 @@ def _setup_default_env():
os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
# This is used for ML Framework Telemetry.
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', __version__)
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')

if tpu.version() == 4:
os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
@@ -212,7 +213,8 @@ def _init_xla_lazy_backend():
from .experimental import plugins
from ._internal import neuron, xpu # Additional built-in plugins

if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS',
'0' if _XLAC._has_cuda_support() else '1') == '1':
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()

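Sketch of the new default above (the helper name is hypothetical; the real code inlines the check and calls the _XLAC._has_cuda_support() binding added later in this PR): installed PJRT plugins are now auto-registered by default on non-CUDA builds, while CUDA builds keep the old opt-in behavior, and XLA_REGISTER_INSTALLED_PLUGINS still overrides either way.

import os

def should_register_plugins(has_cuda_support: bool) -> bool:
    # CUDA wheels default to '0' (opt-in); all other wheels now default to '1'.
    default = '0' if has_cuda_support else '1'
    return os.getenv('XLA_REGISTER_INSTALLED_PLUGINS', default) == '1'

# With the env var unset:
#   should_register_plugins(False) -> True   (TPU/CPU wheels register installed plugins)
#   should_register_plugins(True)  -> False  (CUDA wheels keep the previous behavior)
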
11 changes: 9 additions & 2 deletions torch_xla/_internal/tpu.py
@@ -17,6 +17,7 @@
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
from torch_xla.experimental import plugins
from torch_xla.version import __version__

_GCE_METADATA_ROOT_URL = 'http://metadata.google.internal/computeMetadata/v1'
_ACCELERATOR_TYPE_TO_HOST_BOUNDS = {
@@ -342,10 +343,16 @@ def configure_multiprocess(self, local_rank, local_world_size):
return configure_topology(local_rank, local_world_size)

def physical_chip_count(self):
return num_available_chips()
# HACK: We may reduce the number of processes we spawn depending on TPU
# topology settings
return num_local_processes()

def client_create_options(self):
return {
'max_inflight_computations':
xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4)
xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 4),
'ml_framework_name':
'PyTorch/XLA',
'ml_framework_version':
__version__
}
13 changes: 5 additions & 8 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2221,14 +2221,11 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
const at::Tensor& other) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
std::optional<at::ScalarType>);
return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
.add_input(self)
.add_input(other)
.cast_inputs_to_common_dtype()
.use_opmathtype_for_compute()
.run();
return DoBinaryOp(self, other,
[&](const XLATensorPtr& xself, const XLATensorPtr& xother,
at::ScalarType dtype) {
return tensor_methods::mul(xself, xother, dtype);
});
}

at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
7 changes: 7 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2423,6 +2423,13 @@ void InitXlaModuleBindings(py::module m) {
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/true);
});
m.def("_has_cuda_support", []() {
#ifdef GOOGLE_CUDA
return true;
#else
return false;
#endif
});
m.def("_xla_gpu_custom_call",
[](const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/pjrt_registry.cc
@@ -82,6 +82,9 @@ InitializePjRt(const std::string& device_type) {
if (plugin) {
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;

// Init the absl logging to avoid the log spam.
absl::InitializeLog();

std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
if (plugin->requires_xla_coordinator()) {
int local_process_rank = sys_util::GetEnvInt(
2 changes: 1 addition & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
@@ -268,7 +268,7 @@ def _create_device_mesh_for_nd_torus(
indices = itertools.combinations(
range(len(assignable_physical_mesh)), num_axes)
for c_axes, c_indices in zip(axes, indices):
if np.product(c_axes) == logical_axis_size:
if np.prod(c_axes) == logical_axis_size:
assignment[logical_axis_index] = c_indices
# Zero the assigned physical axes.
assignable_physical_mesh = [
5 changes: 0 additions & 5 deletions torch_xla/experimental/distributed_checkpoint/planners.py
@@ -282,11 +282,6 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
lengths and offsets into the global tensor.
"""
offsets = read_item.dest_offsets
index = read_item.dest_index
if index.fqn in self.sharded_state_dict:
# Update offsets to index into the shard rather than the global tensor
shard = self._local_shards[index.fqn][index.index]
offsets = torch.Size(d - i.start for d, i in zip(offsets, shard.indices))
return narrow_tensor_by_index(tensor, offsets, read_item.lengths)

def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: