Optimize execution for ops that have multiple outputs in eager mode #7680

Merged 2 commits on Jul 16, 2024
36 changes: 36 additions & 0 deletions test/eager/test_eager.py
@@ -5,6 +5,9 @@
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm

import torch.nn as nn


class Eager(unittest.TestCase):
@@ -14,6 +17,7 @@ def setUpClass(cls):
torch_xla.experimental.eager_mode(True)

def test_eager_basic(self):
xm.wait_device_ops()
met.clear_all()
self.assertTrue(torch_xla.experimental.is_eager_mode())
device = torch_xla.device()
@@ -89,6 +93,38 @@ def test_eager_set_random_seed(self):
t2 = torch.randn(12, 13, device=device)
self.assertTrue(torch.allclose(t1.cpu(), t2.cpu()))

def test_batch_norm_execute_once(self):
xm.wait_device_ops()
device = torch_xla.device()
m = nn.BatchNorm2d(16).to(device)
m.train()
input = torch.randn(8, 16, 8, 32).to(device)
met.clear_all()
output = m(input)
self.assertIn('xla::native_batch_norm', met.counter_names())
# native_batch_norm has 5 outputs; we should execute them in one graph.
# However, in eager mode there will be 2 more graphs to broadcast 0 to the
# `u8[0]` and `f32[8,16,8,32]` outputs, and one more graph to generate the
# `s64[]` value.
self.assertLessEqual(met.metric_data('EagerOpExecuteTime')[0], 4)
# Make sure running_mean becomes an XLA_Data.
self.assertIn('Data Shape: f32[16]',
torch_xla._XLAC._get_xla_tensor_debug_info(m.running_mean))

def test_svd_execute_once(self):
device = torch_xla.device()
a = torch.randn(5, 3).to(device)
xm.wait_device_ops()
met.clear_all()
u, s, v = torch.svd(a)
self.assertIn('xla::_linalg_svd', met.counter_names())
# svd has 3 outputs; we should execute them in one graph. However, in eager
# mode there will be 1 more graph to create an empty f32[5,3] and 2 more
# graphs to transpose the results.
self.assertLessEqual(met.metric_data('EagerOpExecuteTime')[0], 4)
self.assertIn('Data Shape: f32[5,3]',
torch_xla._XLAC._get_xla_tensor_debug_info(u))


if __name__ == '__main__':
test = unittest.main()
24 changes: 16 additions & 8 deletions torch_xla/csrc/tensor.cpp
@@ -80,12 +80,14 @@ XLATensorPtr XLATensor::Create(

XLATensorPtr XLATensor::Create(
torch::lazy::Value ir_value, const torch::lazy::BackendDevice& device,
std::optional<at::ScalarType> logical_element_type) {
std::optional<at::ScalarType> logical_element_type,
bool delay_eager_executation) {
XLATensorPtr xtensor = c10::make_intrusive<XLATensor>(
XLATensor(std::move(ir_value), device, logical_element_type));
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
graph_executor->RegisterTensor(xtensor->data());
if (UseEagerDebugMode() || graph_executor->UseEagerMode()) {
if ((UseEagerDebugMode() || graph_executor->UseEagerMode()) &&
!delay_eager_executation) {
std::vector<XLATensorPtr> xtensors({xtensor});
graph_executor->ApplyEagerSync(xtensors);
}
@@ -620,26 +622,32 @@ torch::lazy::Value XLATensor::MaybeCastIrValue(
return ir_value;
}

XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value) const {
XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value,
bool delay_eager_executation) const {
ir_value = MaybeCastIrValue(std::move(ir_value), GetDevice(),
/*logical_element_type=*/std::nullopt);
return Create(std::move(ir_value), GetDevice(), dtype_optional());
return Create(std::move(ir_value), GetDevice(), dtype_optional(),
delay_eager_executation);
}

XLATensorPtr XLATensor::CreateFrom(
torch::lazy::Value ir_value,
std::optional<at::ScalarType> logical_element_type_opt) const {
std::optional<at::ScalarType> logical_element_type_opt,
bool delay_eager_executation) const {
ir_value = MaybeCastIrValue(std::move(ir_value), GetDevice(),
logical_element_type_opt);
return Create(std::move(ir_value), GetDevice(), logical_element_type_opt);
return Create(std::move(ir_value), GetDevice(), logical_element_type_opt,
delay_eager_executation);
}

XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value,
const torch::lazy::BackendDevice& device,
at::ScalarType logical_element_type) const {
at::ScalarType logical_element_type,
bool delay_eager_executation) const {
ir_value =
MaybeCastIrValue(std::move(ir_value), device, logical_element_type);
return Create(std::move(ir_value), device, logical_element_type);
return Create(std::move(ir_value), device, logical_element_type,
delay_eager_executation);
}

void XLATensor::ApplyPendingGraph() {
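The new `delay_eager_executation` flag lets a multi-output op defer the per-tensor eager sync that `XLATensor::Create` would otherwise trigger for each output. Below is a minimal sketch (not code from this PR) of the intended call pattern in a lowering: `LowerMultiOutputOp`, `input`, `node`, and `num_outputs` are hypothetical placeholders, while `CreateFrom`, `XLAGraphExecutor::Get`, `UseEagerMode`, and `ApplyEagerSync` are the APIs shown in this diff.

// A hypothetical helper illustrating the pattern (not part of the PR).
std::vector<XLATensorPtr> LowerMultiOutputOp(const XLATensorPtr& input,
                                             const torch::lazy::NodePtr& node,
                                             size_t num_outputs) {
  std::vector<XLATensorPtr> outputs;
  for (size_t i = 0; i < num_outputs; ++i) {
    // delay_eager_executation=true skips the per-tensor eager sync inside
    // XLATensor::Create, so no graph is launched for individual outputs.
    outputs.push_back(input->CreateFrom(torch::lazy::Value(node, i),
                                        /*delay_eager_executation=*/true));
  }
  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
  if (graph_executor->UseEagerMode()) {
    // One ApplyEagerSync covers every output, so the op executes in a
    // single graph instead of once per output.
    graph_executor->ApplyEagerSync(outputs);
  }
  return outputs;
}

Because the flag defaults to false, existing single-output call sites keep their eager-sync behavior unchanged.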
12 changes: 8 additions & 4 deletions torch_xla/csrc/tensor.h
@@ -156,19 +156,23 @@ class XLATensor : public torch::lazy::LazyTensor {
std::optional<at::ScalarType> logical_element_type = std::nullopt);
static XLATensorPtr Create(
torch::lazy::Value ir_value, const torch::lazy::BackendDevice& device,
std::optional<at::ScalarType> logical_element_type = std::nullopt);
std::optional<at::ScalarType> logical_element_type = std::nullopt,
bool delay_eager_executation = false);
static XLATensorPtr Create(std::shared_ptr<Data> data);

// Create a new XLA tensor with the same metadata as the input tensor (with
// possible overrides), and the new IR value.
XLATensorPtr CreateFrom(torch::lazy::Value ir_value) const;
XLATensorPtr CreateFrom(torch::lazy::Value ir_value,
bool delay_eager_executation = false) const;
XLATensorPtr CreateFrom(
torch::lazy::Value ir_value,
std::optional<at::ScalarType> logical_element_type_opt) const;
std::optional<at::ScalarType> logical_element_type_opt,
bool delay_eager_executation = false) const;
// TODO: We should remove this one once MaybeCastIrValue is no longer needed.
XLATensorPtr CreateFrom(torch::lazy::Value ir_value,
const torch::lazy::BackendDevice& device,
at::ScalarType logical_element_type) const;
at::ScalarType logical_element_type,
bool delay_eager_executation = false) const;

// The default ctor previously created a null LazyTensor (one with no 'data'
// obj). Creating a null XLATensor is no longer possible, since the same can