From b576d85f079632bc453bf1309175f082b49b5be5 Mon Sep 17 00:00:00 2001
From: jonb377
Date: Fri, 8 Dec 2023 10:46:52 -0800
Subject: [PATCH] [Cherry-Pick] Incorporate compilation environment in hash
 (#6050)

---
 torch_xla/csrc/BUILD                        |   1 +
 torch_xla/csrc/runtime/BUILD                |  22 ++++
 torch_xla/csrc/runtime/computation_client.h |   3 +
 torch_xla/csrc/runtime/env_hash.cc          | 102 ++++++++++++++++++
 torch_xla/csrc/runtime/env_hash.h           |  15 +++
 torch_xla/csrc/runtime/env_hash_test.cc     |  52 +++++++++
 .../csrc/runtime/pjrt_computation_client.cc |  42 ++++++++
 .../csrc/runtime/pjrt_computation_client.h  |   3 +
 torch_xla/csrc/xla_graph_executor.cpp       |   7 ++
 9 files changed, 247 insertions(+)
 create mode 100644 torch_xla/csrc/runtime/env_hash.cc
 create mode 100644 torch_xla/csrc/runtime/env_hash.h
 create mode 100644 torch_xla/csrc/runtime/env_hash_test.cc

diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index 1474fa9df98..f5bc1147863 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -121,6 +121,7 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
+        ":version",
         "//torch_xla/csrc/runtime",
         "//torch_xla/csrc/runtime:stablehlo_helper",
         "//torch_xla/csrc/runtime:xla_util",
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index f14b757f265..9d58adaa944 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -81,6 +81,7 @@ cc_library(
     deps = [
         ":computation_client",
         ":debug_macros",
+        ":env_hash",
        ":env_vars",
         ":operation_manager",
         ":profiler",
@@ -143,6 +144,27 @@ cc_library(
     hdrs = ["env_vars.h"],
 )
 
+cc_library(
+    name = "env_hash",
+    srcs = ["env_hash.cc"],
+    hdrs = ["env_hash.h"],
+    deps = [
+        ":sys_util",
+        "@torch//:headers",
+    ],
+)
+
+cc_test(
+    name = "env_hash_test",
+    size = "small",
+    srcs = ["env_hash_test.cc"],
+    deps = [
+        ":env_hash",
+        "@torch//:libtorch_cpu",  # For torch::lazy::hash
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "metrics_analysis",
     srcs = ["metrics_analysis.cc"],
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 9d0b239f212..b97e90ffc2e 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -281,6 +281,9 @@ class ComputationClient {
   virtual std::vector<ComputationPtr> Compile(
       std::vector<CompileInstance> instances) = 0;
 
+  // Returns a hash of the current compilation environment.
+  virtual torch::lazy::hash_t HashCompilationEnv() = 0;
+
   // Executes computation with arguments and returns the result.
   // The passed device must match the common device of the arguments Data.
   // If options.explode_tuple is true, the output tuple will be decomposed into
diff --git a/torch_xla/csrc/runtime/env_hash.cc b/torch_xla/csrc/runtime/env_hash.cc
new file mode 100644
index 00000000000..f0bb6990e11
--- /dev/null
+++ b/torch_xla/csrc/runtime/env_hash.cc
@@ -0,0 +1,102 @@
+#include "torch_xla/csrc/runtime/env_hash.h"
+
+#include <algorithm>
+#include <sstream>
+#include <unordered_set>
+
+#include "torch_xla/csrc/runtime/sys_util.h"
+
+namespace torch_xla {
+namespace runtime {
+namespace hash {
+
+namespace {
+static const std::string XLA_FLAG_PREFIX = "--xla";
+
+// Taken from JAX:
+// https://github.com/google/jax/blob/8ee5811/jax/_src/cache_key.py#L325-L346
+static const std::unordered_set<std::string> FlagsToExclude = {
+    "--xla_dump_compress_protos",
+    "--xla_dump_module_metadata",
+    "--xla_dump_max_hlo_modules",
+    "--xla_dump_include_timestamp",
+    "--xla_dump_hlo_pass_re",
+    "--xla_dump_hlo_module_re",
+    "--xla_dump_hlo_snapshots",
+    "--xla_dump_fusion_visualization",
+    "--xla_dump_hlo_as_url",
+    "--xla_dump_hlo_as_proto",
+    "--xla_dump_hlo_as_text",
+    "--xla_dump_hlo_as_long_text",
+    "--xla_dump_hlo_as_html",
+    "--xla_dump_hlo_as_dot",
+    "--xla_dump_to",
+    "--xla_force_host_platform_device_count",
+    "--xla_dump_disable_metadata",
+    "--xla_dump_hlo_pipeline_re",
+    "--xla_tpu_sdc_checker_streamz_metric",
+    "--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
+};
+
+torch::lazy::hash_t hash_xla_flags(std::string env_var_name) {
+  std::stringstream xla_flag_env(
+      sys_util::GetEnvString(env_var_name.c_str(), ""));
+  std::string current_flag;
+  std::vector<std::string> xla_flags;
+  torch::lazy::hash_t hash = 0;
+  // Parse the space-delimited flags one at a time.
+  while (std::getline(xla_flag_env, current_flag, ' ')) {
+    if (current_flag.rfind(XLA_FLAG_PREFIX, 0) != 0) {
+      continue;
+    }
+    // XLA flags require key and value to be separated by '='.
+    std::size_t eq_pos = current_flag.find('=');
+    std::string flag_key;
+    if (eq_pos == std::string::npos) {
+      flag_key = current_flag;
+    } else {
+      flag_key = current_flag.substr(0, eq_pos);
+    }
+    if (FlagsToExclude.find(flag_key) != FlagsToExclude.end()) {
+      continue;
+    }
+    xla_flags.push_back(current_flag);
+  }
+  // Ensure the flags are sorted so the input order doesn't impact the hash.
+  std::sort(xla_flags.begin(), xla_flags.end());
+  for (auto& flag : xla_flags) {
+    hash =
+        torch::lazy::HashCombine(hash, torch::lazy::StringHash(flag.c_str()));
+  }
+  return hash;
+}
+
+torch::lazy::hash_t hash_xla_env_vars(std::vector<std::string> flag_vars,
+                                      std::vector<std::string> raw_vars) {
+  torch::lazy::hash_t hash = 0;
+  // Parse the flag_vars for XLA flags.
+  for (auto& env_var_name : flag_vars) {
+    hash = torch::lazy::HashCombine(hash, hash_xla_flags(env_var_name));
+  }
+
+  // Include the raw value of each of the raw_vars.
+  for (auto& env_var_name : raw_vars) {
+    std::string raw_val = sys_util::GetEnvString(env_var_name.c_str(), "");
+    hash = torch::lazy::HashCombine(hash,
+                                    torch::lazy::StringHash(raw_val.c_str()));
+  }
+  return hash;
+}
+}  // namespace
+
+torch::lazy::hash_t HashXlaEnvVars() {
+  // Both XLA_FLAGS and LIBTPU_INIT_ARGS contain XLA flags which impact
+  // the compilation result.
+  static std::vector<std::string> flag_vars = {"XLA_FLAGS", "LIBTPU_INIT_ARGS"};
+  static std::vector<std::string> raw_vars = {"TPU_MEGACORE"};
+  return hash_xla_env_vars(flag_vars, raw_vars);
+}
+
+}  // namespace hash
+}  // namespace runtime
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/env_hash.h b/torch_xla/csrc/runtime/env_hash.h
new file mode 100644
index 00000000000..1d735da9678
--- /dev/null
+++ b/torch_xla/csrc/runtime/env_hash.h
@@ -0,0 +1,15 @@
+#include <torch/csrc/lazy/core/hash.h>
+
+namespace torch_xla {
+namespace runtime {
+namespace hash {
+
+// Take a hash of the XLA flags which impact the compilation result.
+// TODO(jonbolin): We should move away from manually hashing the env vars and
+// instead hash the compilation environment directly when the functionality is
+// supported in the runtime.
+torch::lazy::hash_t HashXlaEnvVars();
+
+}  // namespace hash
+}  // namespace runtime
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/env_hash_test.cc b/torch_xla/csrc/runtime/env_hash_test.cc
new file mode 100644
index 00000000000..68212c72012
--- /dev/null
+++ b/torch_xla/csrc/runtime/env_hash_test.cc
@@ -0,0 +1,52 @@
+#include "torch_xla/csrc/runtime/env_hash.h"
+
+#include <gtest/gtest.h>
+#include <stdlib.h>
+
+#include <torch/csrc/lazy/core/hash.h>
+
+namespace torch_xla {
+namespace runtime {
+namespace hash {
+
+TEST(HashTest, CompilationEnvHashTest) {
+  for (const char* flag_var : {"XLA_FLAGS", "LIBTPU_INIT_ARGS"}) {
+    torch::lazy::hash_t base_hash = HashXlaEnvVars();
+
+    // Add an ignored XLA flag to the environment.
+    setenv(flag_var, "--xla_dump_to=/foo/bar", /*overwrite=*/true);
+    EXPECT_TRUE(base_hash == HashXlaEnvVars());
+
+    // Add some non-ignored XLA flags to the environment.
+    setenv(flag_var, "--xla_foo_bar=1 --xla_bar_baz=0", /*overwrite=*/true);
+    torch::lazy::hash_t nonignored_xla_flag = HashXlaEnvVars();
+    EXPECT_TRUE(base_hash != nonignored_xla_flag);
+
+    // Add an ignored XLA flag in addition to the non-ignored flags.
+    setenv(flag_var, "--xla_foo_bar=1 --xla_bar_baz=0 --xla_dump_to=/foo/bar",
+           /*overwrite=*/true);
+    torch::lazy::hash_t mixed_xla_flag = HashXlaEnvVars();
+    EXPECT_TRUE(nonignored_xla_flag == mixed_xla_flag);
+
+    // Reordering the XLA flags should not impact the hash.
+    setenv(flag_var, "--xla_bar_baz=0 --xla_dump_to=/foo/bar --xla_foo_bar=1",
+           /*overwrite=*/true);
+    torch::lazy::hash_t mixed_reordered_xla_flag = HashXlaEnvVars();
+    EXPECT_TRUE(mixed_xla_flag == mixed_reordered_xla_flag);
+
+    // Changing an XLA flag value should impact the hash.
+    setenv(flag_var, "--xla_bar_baz=1 --xla_dump_to=/foo/bar --xla_foo_bar=1",
+           /*overwrite=*/true);
+    torch::lazy::hash_t new_value_xla_flag = HashXlaEnvVars();
+    EXPECT_TRUE(mixed_reordered_xla_flag != new_value_xla_flag);
+  }
+
+  // Modifying the value of TPU_MEGACORE should impact the hash.
+  torch::lazy::hash_t base_hash = HashXlaEnvVars();
+  setenv("TPU_MEGACORE", "megacore", /*overwrite=*/true);
+  EXPECT_TRUE(base_hash != HashXlaEnvVars());
+}
+
+}  // namespace hash
+}  // namespace runtime
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index bf3bccd210e..a53ad32a943 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -11,6 +11,7 @@
 #include "pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/env_hash.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/operation_manager.h"
 #include "torch_xla/csrc/runtime/profiler.h"
@@ -84,6 +85,40 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   return allocator_config;
 }
 
+torch::lazy::hash_t hash_comp_env(
+    std::shared_ptr<xla::PjRtClient> client,
+    std::vector<xla::PjRtDevice*>& ordered_devices) {
+  torch::lazy::hash_t hash = hash::HashXlaEnvVars();
+  // Whether or not SPMD mode is active should influence the hash.
+  hash = torch::lazy::HashCombine(hash, UseVirtualDevice());
+  auto topology_desc = client->GetTopologyDescription();
+  if (topology_desc.ok()) {
+    // Some backends support a topology description which provides a better
+    // view of the specific compilation environment.
+    auto serialized = topology_desc.value()->Serialize();
+    if (serialized.ok()) {
+      return torch::lazy::HashCombine(
+          hash,
+          torch::lazy::DataHash(serialized->data(), serialized->length()));
+    }
+    // If serialization fails, fall through to the manual approach.
+  }
+  std::string platform_name(client->platform_name());
+  std::string platform_version(client->platform_version());
+  hash = torch::lazy::HashCombine(
+      hash, torch::lazy::StringHash(platform_name.c_str()));
+  // platform_version incorporates the libtpu version and hardware type.
+  hash = torch::lazy::HashCombine(
+      hash, torch::lazy::StringHash(platform_version.c_str()));
+  // Include the global devices in the hash, ensuring the order is consistent.
+  for (auto& device : ordered_devices) {
+    std::string device_str(device->ToString());
+    hash = torch::lazy::HashCombine(
+        hash, torch::lazy::StringHash(device_str.c_str()));
+  }
+  return hash;
+}
+
 }  // namespace
 
 std::string PjRtComputationClient::PjRtDeviceToString(
@@ -563,6 +598,13 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
   return computations;
 }
 
+torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
+  // TODO(jonbolin): Incorporate CompileOptions into the hash. These are
+  // deterministically generated at the moment, so they don't need to be
+  // included. It will require a small refactor, so punting on this for now.
+  return comp_env_hash_;
+}
+
 std::vector<ComputationClient::DataPtr>
 PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 93760ee6b05..53b16a79ca5 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -72,6 +72,8 @@ class PjRtComputationClient : public ComputationClient {
 
   std::vector<std::string> GetAllDevices() const override;
 
+  torch::lazy::hash_t HashCompilationEnv() override;
+
   int GetProcessIndex() const override { return client_->process_index(); };
 
   int GetNumProcesses() const override;
@@ -114,6 +116,7 @@ class PjRtComputationClient : public ComputationClient {
   OperationManager operation_manager_;
   tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
       tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
+  torch::lazy::hash_t comp_env_hash_;
 
   xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 236d36e7e57..c19f776ea22 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -53,6 +53,7 @@
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/thread_pool.h"
 #include "torch_xla/csrc/torch_util.h"
+#include "torch_xla/csrc/version.h"
 #include "torch_xla/csrc/xla_backend_impl.h"
 #include "torch_xla/csrc/xla_sharding_util.h"
 #include "tsl/platform/errors.h"
@@ -536,6 +537,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
   // The force_ltc_data controls aliasing compilation, so effectively the same
   // graph with on/off force_ltc_data should not match, hash wise.
   coll.hash = torch::lazy::MHash(config.force_ltc_data);
+  // Ensure the compilation environment and git revision are reflected in the
+  // hash.
+  coll.hash = torch::lazy::HashCombine(
+      coll.hash, runtime::GetComputationClient()->HashCompilationEnv());
+  coll.hash =
+      torch::lazy::HashCombine(coll.hash, torch::lazy::StringHash(XLA_GITREV));
   coll.config = config;
   coll.device = *unique_device;
   coll.indices.reserve(tensors.size());
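
For reference, the standalone sketch below (not part of the patch) illustrates the hashing scheme env_hash.cc applies to flag-style environment variables: flags are filtered against an exclusion list keyed on the text before '=', sorted so their ordering cannot change the result, and folded into a single value. All names here are illustrative, and std::hash plus a Boost-style combiner stand in for torch::lazy::StringHash and torch::lazy::HashCombine.

// illustrative_flag_hash.cc -- standalone sketch, not torch_xla code.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

// Boost-style hash combiner; torch::lazy::HashCombine plays this role in the
// actual implementation.
uint64_t CombineHash(uint64_t seed, uint64_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

uint64_t HashFlagString(const std::string& flags,
                        const std::unordered_set<std::string>& excluded) {
  std::vector<std::string> kept;
  std::stringstream stream(flags);
  std::string flag;
  while (std::getline(stream, flag, ' ')) {
    if (flag.rfind("--xla", 0) != 0) continue;         // not an XLA flag
    std::string key = flag.substr(0, flag.find('='));  // text before '='
    if (excluded.count(key)) continue;                 // ignored flag
    kept.push_back(flag);
  }
  std::sort(kept.begin(), kept.end());  // order-insensitive result
  uint64_t hash = 0;
  for (const std::string& f : kept) {
    hash = CombineHash(hash, std::hash<std::string>{}(f));
  }
  return hash;
}

int main() {
  const std::unordered_set<std::string> excluded = {"--xla_dump_to"};
  // Reordering flags or adding an excluded flag leaves the hash unchanged,
  // while changing a flag value changes it.
  uint64_t a = HashFlagString("--xla_foo=1 --xla_bar=0", excluded);
  uint64_t b =
      HashFlagString("--xla_bar=0 --xla_dump_to=/tmp --xla_foo=1", excluded);
  uint64_t c = HashFlagString("--xla_foo=1 --xla_bar=1", excluded);
  std::cout << (a == b) << " " << (a != c) << "\n";  // prints "1 1"
  return 0;
}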