usability #7728

Draft: wants to merge 5 commits into base: master
3 changes: 2 additions & 1 deletion test/test_operations.py
@@ -52,6 +52,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf
import torch_xla.debug.profiler as xp
import torch_xla._internal.utils as _utils
import unittest
import test_utils

@@ -210,7 +211,7 @@ def check_fn(v):
return False
return True

xla_data = xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
xla_data = _utils.ToXlaTensorArena(convert_fn, select_fn).transform(data)
self.assertTrue(check_fn(xla_data))


3 changes: 2 additions & 1 deletion test/test_zero1.py
@@ -4,6 +4,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
import torch_xla._internal.utils as _utils
from torch_xla import runtime as xr
from copy import deepcopy

@@ -28,7 +29,7 @@ def convert_fn(tensors):
def select_fn(v):
return type(v) == torch.Tensor and xm.is_xla_tensor(v)

return xm.ToXlaTensorArena(convert_fn, select_fn).transform(s)
return _utils.ToXlaTensorArena(convert_fn, select_fn).transform(s)


class XlaZeRO1Test(test_utils.XlaTestCase):
85 changes: 85 additions & 0 deletions torch_xla/_internal/utils.py
@@ -1,8 +1,93 @@
import logging
import os
import re
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm


def parse_xla_device(device: str):
m = re.match(r'([A-Z]+):(\d+)$', device)
if m:
return (m.group(1), int(m.group(2)))


def reduce_gradients(optimizer, groups=None, pin_layout=True):
"""Reduces all the gradients handled by an optimizer.

Args:
optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance
containing the gradients to be reduced.
groups (list, optional): A list of list, representing the replica groups for
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
pin_layout (bool, optional): whether to pin the layout when reducing gradients.
See `xm.all_reduce` for details.
"""
count = xm.xrt_world_size()
if count > 1:
    gradients = xm._fetch_gradients(optimizer)
bucket_cap_mb = int(os.getenv('ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB', 0))
# Reverse the gradients list so that we start allreduce from the last layer
# onwards. This allows allreduce to trigger as soon as the bucket fills up and
# overlap with backward pass.
if bucket_cap_mb > 0:
gradients = reversed(gradients)
xm.all_reduce_bucketized_gradients(
gradients,
scale=1.0 / count,
groups=groups,
pin_layout=pin_layout,
bucket_cap_mb=bucket_cap_mb)
else:
xm.all_reduce(
          xm.REDUCE_SUM,
gradients,
scale=1.0 / count,
groups=groups,
pin_layout=pin_layout)


class ToXlaTensorArena(object):

def __init__(self, convert_fn, select_fn):
self._convert_fn = convert_fn
self._select_fn = select_fn
self._tensors = []

def _add(self, tensor):
self._tensors.append(tensor)

def _convert(self):
self._index = 0
if self._tensors:
self._converted_tensors = self._convert_fn(self._tensors)
else:
self._converted_tensors = []

def _get_converted_tensor(self):
assert self._index < len(self._converted_tensors)
new_tensor = self._converted_tensors[self._index]
self._index += 1
return new_tensor

def _collect_tensors(self, inputs):

def collect_fn(value):
self._add(value)

xu.for_each_instance(inputs, lambda x: self._select_fn(x), collect_fn)

def _replace_tensors(self, inputs):

def convert_fn(value):
return self._get_converted_tensor()

return xu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x),
convert_fn)

def transform(self, inputs):
self._tensors = []
self._collect_tensors(inputs)
self._convert()
return self._replace_tensors(inputs)
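
For context (not part of this diff): the relocated ToXlaTensorArena batches tensor conversion over an arbitrarily nested structure. A minimal usage sketch under the new import path, assuming an XLA device is available; the select/convert functions below are illustrative only:

import torch
import torch_xla.core.xla_model as xm
import torch_xla._internal.utils as _utils

device = xm.xla_device()
data = {'w': torch.randn(2, 2).to(device), 'meta': 'unchanged'}

def select_fn(v):
  # Pick out only the XLA tensors from the nested structure.
  return type(v) == torch.Tensor and xm.is_xla_tensor(v)

def convert_fn(tensors):
  # Convert every selected tensor in one batch (here: back to CPU).
  return [t.cpu() for t in tensors]

cpu_data = _utils.ToXlaTensorArena(convert_fn, select_fn).transform(data)
# cpu_data mirrors data; only the selected tensors were replaced by their CPU copies.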
91 changes: 5 additions & 86 deletions torch_xla/core/xla_model.py
@@ -44,6 +44,8 @@

from . import xla_model as this_module
parse_xla_device = deprecated(this_module, _utils.parse_xla_device)
reduce_gradients = deprecated(this_module, _utils.reduce_gradients)
ToXlaTensorArena = deprecated(this_module, _utils.ToXlaTensorArena)


def _init_world_size_ordinal():
@@ -352,51 +354,6 @@ def global_rate(self):
return count / delta if delta > 0 else 0.0


class ToXlaTensorArena(object):

def __init__(self, convert_fn, select_fn):
self._convert_fn = convert_fn
self._select_fn = select_fn
self._tensors = []

def _add(self, tensor):
self._tensors.append(tensor)

def _convert(self):
self._index = 0
if self._tensors:
self._converted_tensors = self._convert_fn(self._tensors)
else:
self._converted_tensors = []

def _get_converted_tensor(self):
assert self._index < len(self._converted_tensors)
new_tensor = self._converted_tensors[self._index]
self._index += 1
return new_tensor

def _collect_tensors(self, inputs):

def collect_fn(value):
self._add(value)

xu.for_each_instance(inputs, lambda x: self._select_fn(x), collect_fn)

def _replace_tensors(self, inputs):

def convert_fn(value):
return self._get_converted_tensor()

return xu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x),
convert_fn)

def transform(self, inputs):
self._tensors = []
self._collect_tensors(inputs)
self._convert()
return self._replace_tensors(inputs)


def check_view_sharing(obj):
tensors = set()
aliases = dict()
@@ -1158,44 +1115,6 @@ def all_reduce_bucketized_gradients(gradients,
pin_layout=pin_layout)


def reduce_gradients(optimizer, groups=None, pin_layout=True):
"""Reduces all the gradients handled by an optimizer.

Args:
optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance
containing the gradients to be reduced.
groups (list, optional): A list of list, representing the replica groups for
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
pin_layout (bool, optional): whether to pin the layout when reducing gradients.
See `xm.all_reduce` for details.
"""
count = xrt_world_size()
if count > 1:
gradients = _fetch_gradients(optimizer)
bucket_cap_mb = int(os.getenv('ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB', 0))
# Reverse the gradients list so that we start allreduce from the last layer
# onwards. This allows allreduce to trigger as soon as the bucket fills up and
# overlap with backward pass.
if bucket_cap_mb > 0:
gradients = reversed(gradients)
all_reduce_bucketized_gradients(
gradients,
scale=1.0 / count,
groups=groups,
pin_layout=pin_layout,
bucket_cap_mb=bucket_cap_mb)
else:
all_reduce(
REDUCE_SUM,
gradients,
scale=1.0 / count,
groups=groups,
pin_layout=pin_layout)


def optimizer_step(optimizer,
barrier=False,
optimizer_args={},
@@ -1225,7 +1144,7 @@ def optimizer_step(optimizer,
Returns:
The same value returned by the `optimizer.step()` call.
"""
reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
_utils.reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
loss = optimizer.step(**optimizer_args)
if barrier:
mark_step()
@@ -1281,7 +1200,7 @@ def convert_fn(tensors):
def select_fn(v):
return type(v) == torch.Tensor and is_xla_tensor(v)

return ToXlaTensorArena(convert_fn, select_fn).transform(data)
return _utils.ToXlaTensorArena(convert_fn, select_fn).transform(data)


def send_cpu_data_to_device(datas, device, input_sharding=None):
@@ -1300,7 +1219,7 @@ def select_fn(v):

if type(datas) is torch.Tensor:
datas = [datas]
return ToXlaTensorArena(convert_fn, select_fn).transform(datas)
return _utils.ToXlaTensorArena(convert_fn, select_fn).transform(datas)


def xla_rendezvous(payload: bytes = b'',
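For context (not part of this diff): the helpers removed from xla_model.py remain importable through the deprecated(...) shims added above, while optimizer_step() now delegates gradient reduction to _utils.reduce_gradients(). A minimal single-step training sketch, assuming an XLA device is available (model, shapes, and learning rate are illustrative):

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla._internal.utils as _utils

# Old and new spellings resolve to the same implementation; the xm.* form goes
# through the deprecation shim and is expected to emit a warning.
assert xm.parse_xla_device('TPU:0') == _utils.parse_xla_device('TPU:0') == ('TPU', 0)

device = xm.xla_device()
model = nn.Linear(8, 2).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 8).to(device)).sum()
loss.backward()
xm.optimizer_step(optimizer)  # internally calls _utils.reduce_gradients(optimizer, ...)
xm.mark_step()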
4 changes: 2 additions & 2 deletions torch_xla/utils/serialization.py
@@ -40,7 +40,7 @@ def select_fn(v):
if os.path.isdir(path):
shutil.rmtree(path)
os.mkdir(path)
return xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
return _utils.ToXlaTensorArena(convert_fn, select_fn).transform(data)


def save(data, path, master_only=True, global_master=False):
@@ -97,4 +97,4 @@ def convert_fn(tensors):
def select_fn(v):
return type(v) == TensorReference

return xm.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data)
return _utils.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data)
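
For context (not part of this diff): torch_xla.utils.serialization swaps XLA tensors for on-disk references when saving and rebuilds them when loading, using the same arena that now lives in _internal.utils. A minimal round-trip sketch, assuming serialization.load() is the counterpart of the save() shown above and that /tmp/xla_ckpt is a writable path:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.serialization as xser

device = xm.xla_device()
state = {'weight': torch.randn(4, 4).to(device)}

xser.save(state, '/tmp/xla_ckpt')     # master_only=True by default
restored = xser.load('/tmp/xla_ckpt')
# restored mirrors the structure of state, with tensor data read back from disk.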