From 9703109184a84ce9a4ed0431693402ded7a2f4de Mon Sep 17 00:00:00 2001
From: "Jin Young (Daniel) Sohn"
Date: Thu, 3 Sep 2020 11:41:04 -0700
Subject: [PATCH] Final cherry-pick for r1.6 release (#2479)

Co-authored-by: Davide Libenzi
Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: JackCaoG
---
 docs/source/index.rst                        |  1 +
 test/pytorch_test_base.py                    |  2 ++
 test/test_operations.py                      | 14 ++++++++++++++
 third_party/tensorflow                       |  2 +-
 torch_xla/__init__.py                        |  9 +++++----
 torch_xla/core/xla_model.py                  |  3 ---
 torch_xla/csrc/reduction.cpp                 | 14 +++++++++++---
 torch_xla/distributed/xla_multiprocessing.py |  8 ++++++--
 8 files changed, 40 insertions(+), 13 deletions(-)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 545aee85872..072084665a7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -16,6 +16,7 @@ xla_model
 .. autofunction:: xrt_world_size
 .. autofunction:: all_reduce
 .. autofunction:: all_gather
+.. autofunction:: all_to_all
 .. autofunction:: add_step_closure
 .. autofunction:: wait_device_ops
 .. autofunction:: optimizer_step
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index 58d75e7c92e..da024c3a30b 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -169,6 +169,8 @@
         'test_masked_select_mem_overlap',  # doesn't raise
         'test_scatter_mem_overlap',  # doesn't raise
         'test_index_mem_overlap',  # doesn't raise
+        'test_topk_nonfinite_xla_float32',  # TFXLA update HLO changed for 1.6
+        'test_topk_nonfinite_xla_float64',  # TFXLA update HLO changed for 1.6
     },
     'TestViewOpsXLA': {
         'test_contiguous_nonview',
diff --git a/test/test_operations.py b/test/test_operations.py
index 706f134799f..9a7e47fbea0 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -663,6 +663,20 @@ def test_get_xla_tensor(self):
     self.assertEqual(tx, sx.data.cpu())
 
 
+class TestBinaryCrossEntropyLimitValue(XlaTestCase):
+
+  def test_cross_entropy_loss(self):
+
+    def test_fn(pred, target):
+      lossfn = nn.BCELoss()
+      return lossfn(pred, target)
+
+    pred = torch.tensor(1.0)
+    target = torch.tensor(1.0)
+    for offset in [1, 0, 1e-8, 1e-7]:
+      self.runAtenTest([pred - offset, target], test_fn)
+
+
 class TestDynamicShape(XlaTestCase):
 
   def test_nonzero_shape(self):
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 44067f0783c..21133c9daff 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 44067f0783c56ad092f6ef5ea1034e6926559d86
+Subproject commit 21133c9daffe5fd991d45359f97bf0be642ecd8b
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 7f4cb6da04d..2b7a311cf47 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -4,7 +4,7 @@
 import socket
 import time
 
-from .version import __version__ as version
+from .version import __version__
 
 
 def _maybe_select_tpu_version():
@@ -40,9 +40,10 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):
     import cloud_tpu_client
     client = cloud_tpu_client.Client(tpu_name)
-    client.configure_tpu_version(f'pytorch-{version}', restart_type='ifNeeded')
+    client.configure_tpu_version(
+        f'pytorch-{__version__}', restart_type='ifNeeded')
     # client.wait_for_healthy() API doesn't work as we don't have TPU API access
-    _wait_for_open(version)
+    _wait_for_open(__version__)
   except ImportError:
     logging.warning((
         'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
@@ -50,7 +51,7 @@
   except Exception:
     # This path is hit, when we get throttled by the version changer
     # when we import torch_xla from xmp.spawn-ed processes.
-    _wait_for_open(version, log=False)
+    _wait_for_open(__version__, log=False)
 
 
 def _setup_grpc():
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 13dc18d2729..e2ae87bc9d4 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -478,9 +478,6 @@ def all_to_all(value,
                groups=None):
   """Performs an XLA `AllToAll()` operation on the input tensor.
 
-  WARNING: This function is not very reliable, may produce wrong results under
-  certain inputs. Use it at your own risk.
-
   See: https://www.tensorflow.org/xla/operation_semantics#alltoall
 
   Args:
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index e339f177808..20e1330d90d 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -127,6 +127,7 @@ xla::XlaOp CreateProduct(xla::XlaOp input,
 xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
                                    const absl::optional<xla::XlaOp>& weight,
                                    ReductionMode reduction) {
+  static const float kLogBound = -100;
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp xweight;
   if (weight) {
@@ -137,8 +138,11 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
         XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
   }
   xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
-  xla::XlaOp result = -xweight * (target * xla::Log(input) +
-                                  (one - target) * xla::Log(one - input));
+  xla::XlaOp log_bound = XlaHelpers::ScalarValue(
+      kLogBound, input_shape.element_type(), input.builder());
+  xla::XlaOp result =
+      -xweight * (target * xla::Max(xla::Log(input), log_bound) +
+                  (one - target) * xla::Max(xla::Log(one - input), log_bound));
   if (reduction == ReductionMode::kNone) {
     return result;
   }
@@ -154,6 +158,7 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
 xla::XlaOp BuildBinaryCrossEntropyBackward(
     xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target,
     const absl::optional<xla::XlaOp>& weight, ReductionMode reduction) {
+  static const float kEpsilon = 1e-12;
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp xweight;
   if (weight) {
@@ -164,7 +169,10 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
         XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
   }
   xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
-  xla::XlaOp result = xweight * (input - target) / input / (one - input);
+  xla::XlaOp epsilon = XlaHelpers::ScalarValue(
+      kEpsilon, input_shape.element_type(), input.builder());
+  xla::XlaOp result =
+      xweight * (input - target) / xla::Max(input * (one - input), epsilon);
   if (reduction == ReductionMode::kNone) {
     return result * grad_output;
   }
diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py
index ba9709a5b0d..abf6c346bb1 100644
--- a/torch_xla/distributed/xla_multiprocessing.py
+++ b/torch_xla/distributed/xla_multiprocessing.py
@@ -226,9 +226,13 @@ def _start_fn(index, pf_cfg, fn, args):
   # Calling _setup_replication() will trigger XLA library initialization, so the
   # environment must be fully setup before doing so.
   _setup_replication()
+  fn(gindex, *args)
+
+
+def _mp_start_fn(index, pf_cfg, fn, args):
   exit_code = 0
   try:
-    fn(gindex, *args)
+    _start_fn(index, pf_cfg, fn, args)
   except Exception as e:
     print(
         'Exception in device={}: {}'.format(_get_multiprocessing_device(),
@@ -288,7 +292,7 @@ def spawn(fn,
     _start_fn(0, pf_cfg, fn, args)
   else:
     return torch.multiprocessing.start_processes(
-        _start_fn,
+        _mp_start_fn,
         args=(pf_cfg, fn, args),
         nprocs=pf_cfg.num_devices,
         join=join,
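
---

Note on the docs/source/index.rst hunk: `all_to_all` is published in the docs now that the
reliability WARNING has been dropped from its docstring. A minimal usage sketch, assuming a
replica function run under xmp.spawn and an illustrative tensor shape (neither is part of this
patch; the keyword names follow the `all_to_all` signature shown in the xla_model.py hunk):

import torch
import torch_xla.core.xla_model as xm


def _worker(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # One slice per replica along dim 0; AllToAll splits the input along
  # split_dimension and concatenates the pieces received from the other
  # replicas along concat_dimension.
  value = torch.full((world_size, 128), float(index), device=device)
  return xm.all_to_all(
      value, split_dimension=0, concat_dimension=0, split_count=world_size)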
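Note on the torch_xla/csrc/reduction.cpp hunks: without the clamps, BCE produces NaN at exactly
the limit values the new test exercises, since `target * log(input)` evaluates to `0 * -inf`
when `input` is 0 or 1, and the backward pass divides by `input * (1 - input)`, which is 0 at
those points. A minimal PyTorch sketch of the clamped math (illustration only, with the weight
factor omitted; the constants mirror kLogBound and kEpsilon above, but this is not the XLA
lowering itself):

import torch

LOG_BOUND = -100.0  # mirrors kLogBound in the reduction.cpp hunk
EPSILON = 1e-12  # mirrors kEpsilon in the reduction.cpp hunk


def bce_forward(pred, target):
  # Bound each log term below at -100 so 0 * log(0) yields 0 instead of NaN.
  log_p = torch.clamp(torch.log(pred), min=LOG_BOUND)
  log_1m = torch.clamp(torch.log(1 - pred), min=LOG_BOUND)
  return -(target * log_p + (1 - target) * log_1m)


def bce_backward(grad_output, pred, target):
  # Bound the denominator below at 1e-12 so the gradient stays finite when
  # pred is exactly 0 or 1.
  denom = torch.clamp(pred * (1 - pred), min=EPSILON)
  return grad_output * (pred - target) / denom


pred = target = torch.tensor(1.0)
print(bce_forward(pred, target))  # finite zero loss; the unclamped math gives nan
print(bce_backward(torch.tensor(1.0), pred, target))  # tensor(0.), not nan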
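Note on the xla_multiprocessing.py hunks: the try/except moves out of `_start_fn` into the new
`_mp_start_fn` wrapper, so `spawn(..., nprocs=1)`, which calls `_start_fn(0, pf_cfg, fn, args)`
directly, now propagates exceptions to the caller, while workers launched through
`torch.multiprocessing.start_processes` still report the failing device before exiting. A
generic sketch of the same wrapper pattern (names and the exit code are illustrative, not the
module's exact code):

import sys
import traceback


def run_direct(fn, *args):
  # Direct path: let exceptions propagate to the caller.
  fn(*args)


def run_in_worker(fn, *args):
  # Spawned path: report the error and terminate the worker with a nonzero
  # exit code instead of raising across the process boundary.
  try:
    run_direct(fn, *args)
  except Exception:
    traceback.print_exc(file=sys.stderr)
    sys.exit(1)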