backport Enable eager mode for PyTorch/XLA (#7611)
JackCaoG authored Jul 9, 2024
1 parent 12e0d81 commit 7714e78
Showing 15 changed files with 139 additions and 22 deletions.
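
At a glance, the feature is driven by a single new toggle, torch_xla.experimental.eager_mode(). The sketch below is assembled from the example and tests added in this commit; the tensor shapes and values are illustrative only, and the torch_xla.core.xla_model import is the usual alias rather than part of this change.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Switch the default execution mode: each op is compiled and executed
# immediately instead of being traced into a lazy graph.
torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
t1 = torch.randn(5, 9, device=device)  # compiled and run eagerly
t2 = t1 + 100                          # likewise, no graph sync needed

# mark_step() is expected to be a no-op in eager mode (see test_metrics.py below).
xm.mark_step()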
4 changes: 4 additions & 0 deletions docs/source/index.rst
@@ -73,6 +73,10 @@ spmd
.. autoclass:: HybridMesh
.. autoclass:: ShardingSpec

experimental
----------------------------------
.. automodule:: torch_xla.experimental
.. autofunction:: eager_mode

debug
----------------------------------
12 changes: 12 additions & 0 deletions examples/eager/train_decoder_only_eager.py
@@ -0,0 +1,12 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import torch_xla

if __name__ == '__main__':
torch_xla.experimental.eager_mode(True)
base = TrainDecoderOnlyBase()
base.start_training()
12 changes: 12 additions & 0 deletions test/debug_tool/test_pt_xla_debug.py
@@ -29,6 +29,18 @@ def setUpClass(cls):
assert False, "This test should be run with PT_XLA_DEBUG_FILE"
open(cls.debug_file_name, 'w').close()

def test_eager_mark_step(self):
torch_xla.experimental.eager_mode(True)
device = xm.xla_device()
t1 = torch.randn(5, 9, device=device)
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
# We expect PT_XLA_DEBUG not to output anything under eager mode
self.assertEqual(len(lines), 0)
torch_xla.experimental.eager_mode(False)
open(self.debug_file_name, 'w').close()

def test_user_mark_step(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
19 changes: 19 additions & 0 deletions test/test_metrics.py
@@ -52,6 +52,25 @@ def test_tracing_time_metrics(self):
self.assertIn('LazyTracing', met.metric_names())
self.assertGreater(met.metric_data('LazyTracing')[0], 1)

def test_eager_metrics(self):
torch_xla.experimental.eager_mode(True)
xla_device = xm.xla_device()
met.clear_all()
t1 = torch.tensor(156, device=xla_device)
t2 = t1 + 100
xm.wait_device_ops()
self.assertIn('EagerOpCompileTime', met.metric_names())
# one for constant, one for add
self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
self.assertIn('EagerOpExecuteTime', met.metric_names())
# one for constant, one for add
self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
# mark_step should be a no-op
xm.mark_step()
self.assertNotIn('CompileTime', met.metric_names())
self.assertNotIn('ExecuteTime', met.metric_names())
torch_xla.experimental.eager_mode(False)

def test_short_metrics_report_default_list(self):
xla_device = xm.xla_device()
t1 = torch.tensor(1456, device=xla_device)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -36,3 +36,4 @@ python3 examples/data_parallel/train_resnet_xla_ddp.py
python3 examples/fsdp/train_decoder_only_fsdp_v2.py
python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
python3 examples/train_resnet_amp.py
python3 examples/eager/train_decoder_only_eager.py
13 changes: 13 additions & 0 deletions torch_xla/csrc/debug_util.cpp
@@ -267,6 +267,12 @@ void DebugUtil::analyze_graph_execution_python_frame(
return;
}

// don't output analysis for eager mode execution/compilation
if (XLAGraphExecutor::Get()->UseEagerMode() &&
source != GraphAnalysisSource::DynamoExecution) {
return;
}

if (pt_xla_debug_level <= 1 && source != GraphAnalysisSource::Compilation) {
// for debug level <=1, only output compilation analysis in this function.
return;
@@ -385,6 +391,13 @@ void DebugUtil::post_compilation_analysis(
if (pt_xla_debug_level <= 0 || !is_master_process) {
return;
}

// don't output analysis for eager mode execution/compilation.
// TODO(JackCaoG): enable this for eager+dynamo
if (XLAGraphExecutor::Get()->UseEagerMode()) {
return;
}

static const std::string debug_output_prefix = "Post Compilation Analysis: ";
std::stringstream ss;
ss << "\n"
5 changes: 5 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2403,6 +2403,11 @@ void InitXlaModuleBindings(py::module m) {
});
m.def("_get_xla_enable_device_data_cache",
[]() { return FLAGS_torch_lazy_enable_device_data_cache; });
m.def("_set_use_eager_mode", [](bool use_eager_mode) {
XLAGraphExecutor::Get()->SetUseEagerMode(use_eager_mode);
});
m.def("_get_use_eager_mode",
[]() { return XLAGraphExecutor::Get()->UseEagerMode(); });
m.def("_replace_xla_tensor",
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
12 changes: 12 additions & 0 deletions torch_xla/csrc/runtime/computation_client.cc
@@ -77,12 +77,24 @@ metrics::Metric* ComputationClient::CompileMetric() {
return metric;
}

metrics::Metric* ComputationClient::EagerCompileMetric() {
static metrics::Metric* metric =
new metrics::Metric("EagerOpCompileTime", metrics::MetricFnTime);
return metric;
}

metrics::Metric* ComputationClient::ExecuteMetric() {
static metrics::Metric* metric =
new metrics::Metric("ExecuteTime", metrics::MetricFnTime);
return metric;
}

metrics::Metric* ComputationClient::EagerExecuteMetric() {
static metrics::Metric* metric =
new metrics::Metric("EagerOpExecuteTime", metrics::MetricFnTime);
return metric;
}

metrics::Metric* ComputationClient::ExecuteReplicatedMetric() {
static metrics::Metric* metric =
new metrics::Metric("ExecuteReplicatedTime", metrics::MetricFnTime);
25 changes: 14 additions & 11 deletions torch_xla/csrc/runtime/computation_client.h
@@ -46,6 +46,7 @@ class XlaCoordinator;
// ComputationClient.
struct ClientExecuteOptions {
bool explode_tuple{true};
bool eager_mode{false};
};

class ComputationClient {
@@ -213,16 +214,14 @@ class ComputationClient {
// of torch::lazy::Computation?
struct CompileInstance {
CompileInstance() = default;
CompileInstance(xla::XlaComputation computation,
std::string compilation_device,
std::vector<std::string> devices,
const xla::Shape* output_shape,
bool parameter_is_tupled_arguments = false,
bool is_sharded = false,
bool allow_spmd_sharding_propagation_to_output = true,
bool use_auto_spmd_partitioning = false,
std::vector<int64_t> auto_spmd_mesh_shape = {},
std::vector<int64_t> auto_spmd_mesh_ids = {})
CompileInstance(
xla::XlaComputation computation, std::string compilation_device,
std::vector<std::string> devices, const xla::Shape* output_shape,
bool parameter_is_tupled_arguments = false, bool is_sharded = false,
bool allow_spmd_sharding_propagation_to_output = true,
bool use_auto_spmd_partitioning = false,
std::vector<int64_t> auto_spmd_mesh_shape = {},
std::vector<int64_t> auto_spmd_mesh_ids = {}, bool eager_mode = false)
: computation(std::move(computation)),
compilation_device(std::move(compilation_device)),
devices(std::move(devices)),
@@ -233,7 +232,8 @@
allow_spmd_sharding_propagation_to_output),
use_auto_spmd_partitioning(use_auto_spmd_partitioning),
auto_spmd_mesh_shape(auto_spmd_mesh_shape),
auto_spmd_mesh_ids(auto_spmd_mesh_ids) {}
auto_spmd_mesh_ids(auto_spmd_mesh_ids),
eager_mode(eager_mode) {}

xla::XlaComputation computation;
std::string compilation_device;
@@ -245,6 +245,7 @@
bool use_auto_spmd_partitioning;
std::vector<int64_t> auto_spmd_mesh_shape;
std::vector<int64_t> auto_spmd_mesh_ids;
bool eager_mode;
};

struct ExecuteComputationOptions : public ClientExecuteOptions {};
@@ -430,7 +431,9 @@ class ComputationClient {
static metrics::Metric* TransferToDeviceTransformMetric();
static metrics::Metric* TransferFromDeviceMetric();
static metrics::Metric* CompileMetric();
static metrics::Metric* EagerCompileMetric();
static metrics::Metric* ExecuteMetric();
static metrics::Metric* EagerExecuteMetric();
static metrics::Metric* ExecuteReplicatedMetric();
static metrics::Metric* ExecuteParallelMetric();
static metrics::Metric* ExecuteChainedMetric();
12 changes: 10 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -531,7 +531,11 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(

std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
std::vector<ComputationClient::CompileInstance> instances) {
metrics::TimedSection timed(CompileMetric());
auto metrics_fn = CompileMetric;
if (instances[0].eager_mode) {
metrics_fn = EagerCompileMetric;
}
metrics::TimedSection timed(metrics_fn());
tsl::profiler::TraceMe activity("PjRtComputationClient::Compile",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::ComputationPtr> computations;
@@ -695,7 +699,11 @@ PjRtComputationClient::ExecuteComputation(
// Shared ownership of the timed section ensures that it will only get logged
// once both `ExecuteComputation` and the async work in `ExecuteSharded` are
// complete; a copy is held from the lambda that releases it when done.
auto timed = std::make_shared<metrics::TimedSection>(ExecuteMetric());
auto metrics_fn = ExecuteMetric;
if (options.eager_mode) {
metrics_fn = EagerExecuteMetric;
}
auto timed = std::make_shared<metrics::TimedSection>(metrics_fn());
tsl::profiler::TraceMe activity("PjRtComputationClient::ExecuteComputation",
tsl::profiler::TraceMeLevel::kInfo);
TF_VLOG(1) << "Executing PjRt computation on " << device;
7 changes: 4 additions & 3 deletions torch_xla/csrc/tensor.cpp
@@ -83,10 +83,11 @@ XLATensorPtr XLATensor::Create(
std::optional<at::ScalarType> logical_element_type) {
XLATensorPtr xtensor = c10::make_intrusive<XLATensor>(
XLATensor(std::move(ir_value), device, logical_element_type));
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
if (UseEagerDebugMode()) {
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
graph_executor->RegisterTensor(xtensor->data());
if (UseEagerDebugMode() || graph_executor->UseEagerMode()) {
std::vector<XLATensorPtr> xtensors({xtensor});
XLAGraphExecutor::Get()->ApplyEagerSync(xtensors);
graph_executor->ApplyEagerSync(xtensors);
}
return xtensor;
}
18 changes: 12 additions & 6 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -79,7 +79,7 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) {

XLAGraphExecutor::ComputationCache* CreateComputationCache() {
static const size_t kMaxCacheSize =
runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 2048);
static const bool readonlyPersistentCache =
runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false);
static std::string persistentCacheDir =
@@ -814,6 +814,7 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(

std::vector<torch::lazy::BackendDataPtr> results;
if (async->cached_computation->is_sharded) {
// TODO(JackCaoG): handle eager mode
std::vector<std::string> devices =
runtime::GetComputationClient()->GetLocalDevices();
runtime::ComputationClient::ExecuteReplicatedOptions execute_options;
@@ -1088,7 +1089,8 @@
std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
coll, std::move(parameters_data), std::move(tensors_data),
std::move(cached_computation));
auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs]() {
auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs,
use_eager_mode = UseEagerMode()]() {
try {
std::vector<torch::lazy::BackendDataPtr> results;
// Execute replicated if the compiled computation is partitioned.
@@ -1117,9 +1119,13 @@
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " ...";
results = torch::lazy::getBackend()->ExecuteComputation(
async->cached_computation->computation, async->parameters_data,
async->device);
std::vector<runtime::ComputationClient::DataPtr> outputs =
runtime::GetComputationClient()->ExecuteComputation(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), async->device.toString(),
{/*explode_tuple=*/true,
/*eager_mode=*/use_eager_mode});
results = WrapXlaData(outputs);
TORCH_LAZY_COUNTER("ExecuteComputation", 1);
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
@@ -1372,7 +1378,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
runtime::GetComputationClient()->GetCompilationDevices(
coll.device.toString(), devices),
&shape, should_wrap_parameter, is_sharded});

instances.front().eager_mode = UseEagerMode();
if (use_autosharding) {
TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode.";
7 changes: 7 additions & 0 deletions torch_xla/csrc/xla_graph_executor.h
@@ -187,6 +187,12 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
void ClearPendingIrs(std::vector<XLATensorPtr> tensors,
const torch::lazy::BackendDevice& device);

void SetUseEagerMode(bool use_eager_mode) {
use_eager_mode_ = use_eager_mode;
}

bool UseEagerMode() { return use_eager_mode_; }

private:
// This is just to group results from compile(). Since our computation is
// different, we don't reuse the upstream CompilationResult.
@@ -361,6 +367,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const SyncTensorsConfig& config, bool warm_up_cache_only = false);

ComputationCache* computation_cache_;
bool use_eager_mode_ = false;
};

} // namespace torch_xla
5 changes: 5 additions & 0 deletions torch_xla/experimental/__init__.py
@@ -0,0 +1,5 @@
from .eager import eager_mode

__all__ = [
"eager_mode",
]
9 changes: 9 additions & 0 deletions torch_xla/experimental/eager.py
@@ -0,0 +1,9 @@
import torch_xla


def eager_mode(enable: bool):
"""Configure torch_xla's default executation mode.
Under eager mode only functions that was `torch_xla.compile`d will be
traced and compiled. Other torch ops will be executed eagerly.
"""
torch_xla._XLAC._set_use_eager_mode(enable)
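
As a usage note, the eager toggle pairs with the new EagerOpCompileTime and EagerOpExecuteTime metrics added in computation_client.cc above; the following sketch mirrors test_eager_metrics from test/test_metrics.py, with illustrative values.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
t = torch.tensor(156, device=device) + 100  # constant + add, both run eagerly
xm.wait_device_ops()

# Eager executions are reported under dedicated metrics instead of the
# usual CompileTime/ExecuteTime counters.
print(met.metric_data('EagerOpCompileTime'))
print(met.metric_data('EagerOpExecuteTime'))

torch_xla.experimental.eager_mode(False)  # back to the default lazy mode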
