Skip the execution if all Pending IRs are device data (#6642)
JackCaoG authored Mar 12, 2024
1 parent 8441943 commit 7096220
Showing 9 changed files with 90 additions and 28 deletions.
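
The behavior this commit targets, in a nutshell: when every pending IR node is plain device data (e.g. tensors that were just transferred to the device), `xm.mark_step()` no longer builds and runs a trivial "return the parameters" graph. Below is a minimal sketch of that effect, modeled on the new `DynamoLTCInteractionTest` and the updated `test_non_view`; it assumes `met.metric_data` returns `None` for a metric that was never recorded.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
met.clear_all()

# After the transfer, t1 and t2 only carry DeviceData IR nodes; there is no
# real computation pending.
t1 = torch.randn(4, 2, 2).to(device)
t2 = torch.randn(4, 2, 2).to(device)

xm.mark_step()        # with this change, recognized as a no-op
xm.wait_device_ops()

# No graph was compiled or executed for the mark_step above; previously the
# sync produced (and ran) an HLO that simply returned its two parameters.
assert met.metric_data('ExecuteTime') is None
```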
41 changes: 35 additions & 6 deletions test/dynamo/test_dynamo.py
@@ -61,6 +61,36 @@ def test_random_op_different_result_each_run(self):
self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))


+class DynamoLTCInteractionTest(unittest.TestCase):
+
+  def index_copy_inplace(self, cache, update_indices, xk):
+    cache.index_copy_(0, update_indices, xk)
+
+  def test_mark_step_after_dynamo(self):
+    cache_len = 512
+    kv_heads = 8
+    head_dim = 128
+    running = 16
+
+    device = xm.xla_device()
+    cache = torch.rand((cache_len, kv_heads, head_dim)).to(device)
+    update_indices = torch.randint(
+        0, cache_len, (running,), dtype=torch.long).to(device)
+    xk = torch.rand((running, kv_heads, head_dim)).to(device)
+
+    dynamo_index_copy_inplace = torch.compile(
+        self.index_copy_inplace, backend="openxla", fullgraph=True)
+    met.clear_all()
+    for i in range(10):
+      dynamo_index_copy_inplace(cache, update_indices, xk)
+    xm.wait_device_ops()
+    current_execute_time = met.metric_data('ExecuteTime')[0]
+    # This mark_step should be a no-op and should not trigger additional execution.
+    xm.mark_step()
+    xm.wait_device_ops()
+    self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])


class DynamoInferenceBasicTest(unittest.TestCase):

@classmethod
@@ -256,8 +286,8 @@ def fn_fallback(t):
xla_dynamo_res = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
# 2 compilations are caused by `t_xla` init and a no-op graph.
-    self.assertEqual(met.metric_data('CompileTime')[0], 2)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 2)
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

# Second tracing
met.clear_all()
@@ -295,8 +325,8 @@ def fn_fallback(t):
cpu_res = fn_fallback(t)
xla_dynamo_res = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 7)
+    self.assertEqual(met.metric_data('CompileTime')[0], 2)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 5)

# Second tracing
met.clear_all()
@@ -410,8 +440,7 @@ def test_resnet18(self):
# Graph 1: forward
# Graph 2: backward
# Graph 3: sync input for backward
-    # Graph 4: sync input for backward (TODO(JackCaoG) understand why there are two graphs)
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
# We execute 3 graphs per step.
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3)
# one for each forward and one for each backward
12 changes: 6 additions & 6 deletions test/dynamo/test_dynamo_aliasing.py
@@ -42,7 +42,7 @@ def dummy_mul(self, input):
def test_manual_buffer_donation(self):
device = xm.xla_device()
input = torch.randn(5, 5).to(device)
-    input_cloned = torch.clone(input)
+    input_cloned = input.cpu().to(device)
dummy_inplace_mul_compiled = torch.compile(
self.dummy_inplace_mul, backend='openxla')

@@ -57,7 +57,7 @@ def test_manual_buffer_donation(self):
def test_manual_buffer_donation_for_non_inplce_op(self):
device = xm.xla_device()
input = torch.randn(5, 5).to(device)
-    input_cloned = torch.clone(input)
+    input_cloned = input.cpu().to(device)
dummy_mul_compiled = torch.compile(self.dummy_mul, backend='openxla')

met.clear_all()
@@ -83,7 +83,7 @@ def dummy_inplace(input):

device = xm.xla_device()
input = torch.randn(5, 5).to(device)
-    input_cloned = torch.clone(input)
+    input_cloned = input.cpu().to(device)
dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla')
xm.mark_step()
met.clear_all()
@@ -111,7 +111,7 @@ def dummy_add(self, input):
def test_manual_buffer_donation(self):
device = xm.xla_device()
input = torch.randn(5, 5).to(device)
-    input_cloned = torch.clone(input)
+    input_cloned = input.cpu().to(device)
dummy_inplace_add_compiled = torch.compile(
self.dummy_inplace_add, backend='openxla')

@@ -128,7 +128,7 @@ def test_manual_buffer_donation(self):
def test_manual_buffer_donation_for_non_inplce_op(self):
device = xm.xla_device()
input = torch.randn(5, 5).to(device)
-    input_cloned = torch.clone(input)
+    input_cloned = input.cpu().to(device)
dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla')

met.clear_all()
@@ -153,7 +153,7 @@ def dummy_inplace(input):

device = xm.xla_device()
input = torch.randn(5, 5).to(device)
-    input_cloned = torch.clone(input)
+    input_cloned = input.cpu().to(device)
dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla')
xm.mark_step()
met.clear_all()
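
Why these tests stop using `torch.clone`: with buffer donation, `input`'s device buffer may be handed to the XLA executable, and a lazy clone can still be backed by that same buffer. Round-tripping through the CPU guarantees an independent copy to compare against. A rough sketch of the pattern, with a hypothetical `dummy_inplace_mul`-style helper standing in for the ones defined in this file:

```python
import torch
import torch_xla.core.xla_model as xm

def dummy_inplace_mul(t):
  # Stand-in for the in-place helpers used by these tests.
  t.mul_(2)

device = xm.xla_device()
input = torch.randn(5, 5).to(device)

# A CPU round-trip materializes a brand-new device buffer, so this reference
# copy stays valid even if `input`'s buffer is donated to the executable.
input_cloned = input.cpu().to(device)

compiled = torch.compile(dummy_inplace_mul, backend='openxla')
compiled(input)

assert torch.allclose(input_cloned * 2, input)
```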
9 changes: 7 additions & 2 deletions test/spmd/test_xla_sharding.py
@@ -617,6 +617,11 @@ def test_inplace_add_with_sharding(self):
'%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
hlo)

+  # Avoid calling xr.addressable_device_count here, otherwise it will init
+  # the test in non-SPMD mode.
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
def test_shard_hashing(self):
xt1 = torch.ones(2, 2).to(xm.xla_device())
xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -630,8 +635,8 @@ def test_shard_hashing(self):
self.assertTrue(torch.allclose(xt1 + 0, xt2 + 0))

# Check that hashes are different for the sharded and non-sharded tensors
-    hash1 = torch_xla._XLAC._get_graph_hash([xt1])
-    hash2 = torch_xla._XLAC._get_graph_hash([xt2])
+    hash1 = torch_xla._XLAC._get_graph_hash([xt1 + 0])
+    hash2 = torch_xla._XLAC._get_graph_hash([xt2 + 0])
self.assertNotEqual(hash1, hash2)

def test_transfer_sharded_data_to_host(self):
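
Why the hashes are now taken over `xt1 + 0` / `xt2 + 0`: a bare tensor whose pending IR is just device data is no longer part of the sync set, so its graph hash would not reflect the sharding annotation. Adding a trivial op forces a real pending IR that carries the sharding. A sketch of the idea under assumptions: the mesh setup below is illustrative rather than what `test_shard_hashing` uses, and on a single device the hashes would coincide (hence the new skip above).

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
xt1 = torch.ones(2, 2).to(xm.xla_device())
xt2 = torch.ones(2, 2).to(xm.xla_device())

# Shard xt1 only; xt2 stays unsharded.
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n_devices)), (1, n_devices), ('x', 'y'))
xs.mark_sharding(xt1, mesh, (0, 1))

# `xt1 + 0` / `xt2 + 0` create pending (non-DeviceData) IR nodes, so the graph
# hashes pick up the sharding difference; the bare tensors would be skipped.
hash1 = torch_xla._XLAC._get_graph_hash([xt1 + 0])
hash2 = torch_xla._XLAC._get_graph_hash([xt2 + 0])
assert hash1 != hash2
```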
25 changes: 15 additions & 10 deletions test/test_input_output_aliases.py
@@ -12,26 +12,31 @@ class InputOutputAliasesTest(unittest.TestCase):

def test_non_view(self):
xla_device = xm.xla_device()
-    # This is a special case where we want to sync t1's and t2's
-    # value since they will have device_data ir instead of XLAData.
-    # HLO looks like
-    # ENTRY %IrToHlo.4 (p0.1: f32[4,2,2], p1.2: f32[4,2,2]) -> (f32[4,2,2], f32[4,2,2]) {
-    #   %p0.1 = f32[4,2,2]{2,1,0} parameter(0)
-    #   %p1.2 = f32[4,2,2]{2,1,0} parameter(1)
-    #   ROOT %tuple.3 = (f32[4,2,2]{2,1,0}, f32[4,2,2]{2,1,0}) tuple(f32[4,2,2]{2,1,0} %p0.1, f32[4,2,2]{2,1,0} %p1.2)
-    # }
t1 = torch.randn(4, 2, 2).to(xla_device)
t2 = torch.randn(4, 2, 2).to(xla_device)
xm.mark_step()
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
met.clear_all()

# check in place op aliasing.
t3 = t1 + t2
t1 *= 2.0
t2 += 2.0
xm.mark_step()

self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 4.0)
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)

+  def test_aliasing_with_cloned(self):
+    xla_device = xm.xla_device()
+    met.clear_all()
+    t1 = torch.randn(4, 2, 2).to(xla_device)
+    # t1_cloned shares the same storage as t1
+    t1_cloned = torch.clone(t1)
+    t1 += 1
+    xm.mark_step()
+    # t1's storage will be aliased with the output; we need to make sure
+    # t1_cloned got a new buffer and is still valid.
+    torch.allclose(t1 - 1, t1_cloned)
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)


if __name__ == '__main__':
1 change: 1 addition & 0 deletions test/test_metrics.py
@@ -72,6 +72,7 @@ def test_short_metrics_report_custom_list(self):
xla_device = xm.xla_device()
t1 = torch.tensor(100, device=xla_device)
t2 = t1 * 2
+    t1 += 2
xm.mark_step()
t2_cpu = t2.cpu()
short_report = met.short_metrics_report(
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -297,6 +297,7 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
data()->view = nullptr;
data()->tensor_data = c10::nullopt;
}
+  data()->is_cloned = false;
}

void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
@@ -319,6 +320,7 @@ void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
XLAGraphExecutor::Get()->ApplyEagerSync(xtensors);
}
+  data()->is_cloned = false;
}

void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
@@ -335,6 +337,7 @@ void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
<< (ir_value ? ir_value->ToString() : "empty node");
data()->ir_value = std::move(ir_value);
data()->generation += 1;
+  data()->is_cloned = false;
}

torch::lazy::Value XLATensor::GetIrValue() const {
1 change: 1 addition & 0 deletions torch_xla/csrc/tensor.h
@@ -146,6 +146,7 @@ class XLATensor : public torch::lazy::LazyTensor {
// with unique_id, and then only get updated during the in-place
// op funtionalize pass to point to the input.
int64_t alias_id{0};
+    bool is_cloned = false;
};

static XLATensorPtr Create(const at::Tensor& tensor,
1 change: 1 addition & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1004,6 +1004,7 @@ XLATensorPtr clone(const XLATensorPtr& input) {
if (input->sharding_spec() != nullptr) {
cloned->SetShardingSpec(*input->sharding_spec());
}
+  cloned->data()->is_cloned = true;
return cloned;
}

25 changes: 21 additions & 4 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -623,10 +623,6 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
torch::lazy::Value ir_value = tensors[i]->CurrentIrValue();
if (ir_value) {
if (ShouldSyncIrValue(ir_value)) {
-        // Add only tensors which need to be synced.
-        coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash());
-        coll.indices.push_back(i);
-
// `sharding_spec()` checks sharding equality. If IR node has no
// sharding, then sync XLATensor sharding to the IR node. XLATensor's
// sharding takes the precedence as the source of the truth.
@@ -635,6 +631,26 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
dynamic_cast<XlaNode*>(ir_value.node.get())
->SetSharding(sharding->sharding, ir_value.index);
}
+        auto device_data = torch_xla::DeviceData::Cast(ir_value.node.get());
+        // If the current tensor is cloned from another tensor, we want to
+        // assign a new XlaData to it after the current execution. The cloned
+        // tensor might share its storage with the original tensor, but the
+        // original tensor might alias its storage with the output. It is
+        // safer to allocate a new buffer for the cloned tensor.
+        if (device_data != nullptr && !tensors[i]->data()->is_cloned) {
+          // The current IR is a DeviceData node, so we don't need to include
+          // it as a result of the computation. Call `GetXlaData` to extract
+          // the XlaData from the DeviceData node and reset the IR. We also
+          // want to update the XlaData's tensor ID to make it match the
+          // current XLATensor.
+          tensors[i]->GetXlaData()->SetInfo(
+              std::make_shared<LazyGraphExecutor::DeviceDataInfo>(
+                  tensors[i]->GetUniqueId(), /*=read_only=*/false));
+        } else {
+          // Add only tensors which need to be synced.
+          coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash());
+          coll.indices.push_back(i);
+        }
}
} else if (config.force_ltc_data) {
// The tensor only has at::Tensor data. We need to queue it for a
@@ -965,6 +981,7 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::SetTensorData(
tensor->data()->handle = handle;
tensor->data()->view = nullptr;
tensor->data()->tensor_data = c10::nullopt;
+      tensor->data()->is_cloned = false;
}
tensors_data.emplace_back(std::move(handle));
}
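
Putting the executor change together, from the Python side: a pending IR that is only a DeviceData node is pulled out of the sync set (its handle is extracted directly and re-tagged with the tensor's ID), except when the tensor is a clone, which stays in the set so that it receives a buffer that no longer aliases the original's. A small sketch of the resulting behavior, mirroring the new `test_aliasing_with_cloned` above (tensor names are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

a = torch.randn(4, 2, 2).to(device)  # gets a real pending computation below
b = torch.clone(a)                   # DeviceData IR, flagged is_cloned
c = torch.randn(4, 2, 2).to(device)  # DeviceData IR, not cloned
a += 1                               # in-place op; a's input buffer may be
                                     # aliased with the computation's output

xm.mark_step()

# c is skipped entirely (its data is already on the device), b stays in the
# sync set and receives a fresh buffer that no longer aliases a's, and the
# executed graph only needs to compute the in-place add on a.
assert torch.allclose(a - 1, b)
```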
