[Cherry-Pick] Incorporate compilation environment in hash (#6050)
jonb377 committed Dec 8, 2023
1 parent aaccf54 commit b576d85
Showing 9 changed files with 247 additions and 0 deletions.
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
@@ -121,6 +121,7 @@ ptxla_cc_library(
":layout_manager",
":shape_builder",
":shape_helper",
":version",
"//torch_xla/csrc/runtime",
"//torch_xla/csrc/runtime:stablehlo_helper",
"//torch_xla/csrc/runtime:xla_util",
22 changes: 22 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -81,6 +81,7 @@ cc_library(
deps = [
":computation_client",
":debug_macros",
":env_hash",
":env_vars",
":operation_manager",
":profiler",
@@ -143,6 +144,27 @@ cc_library(
hdrs = ["env_vars.h"],
)

cc_library(
name = "env_hash",
srcs = ["env_hash.cc"],
hdrs = ["env_hash.h"],
deps = [
":sys_util",
"@torch//:headers",
],
)

cc_test(
name = "env_hash_test",
size = "small",
srcs = ["env_hash_test.cc"],
deps = [
":env_hash",
"@torch//:libtorch_cpu", # For torch::lazy::hash
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "metrics_analysis",
srcs = ["metrics_analysis.cc"],
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
@@ -281,6 +281,9 @@ class ComputationClient {
virtual std::vector<ComputationPtr> Compile(
std::vector<CompileInstance> instances) = 0;

// Returns a hash of the current compilation environment.
virtual torch::lazy::hash_t HashCompilationEnv() = 0;

// Executes computation with arguments and returns the result.
// The passed device must match the common device of the arguments Data.
// If options.explode_tuple is true, the output tuple will be decomposed into
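
For context, the sketch below (not part of this commit) shows how a caller can fold the new virtual method into an existing hash. It assumes the torch_xla build environment; the helper name WithCompilationEnv is illustrative, and the runtime.h include path is assumed to declare GetComputationClient().

#include <torch/csrc/lazy/core/hash.h>

#include "torch_xla/csrc/runtime/runtime.h"  // assumed to declare GetComputationClient()

namespace torch_xla {

// Hypothetical helper: extend an existing graph hash with the compilation
// environment so cached executables are keyed on that environment as well.
torch::lazy::hash_t WithCompilationEnv(torch::lazy::hash_t graph_hash) {
  return torch::lazy::HashCombine(
      graph_hash, runtime::GetComputationClient()->HashCompilationEnv());
}

}  // namespace torch_xla
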
102 changes: 102 additions & 0 deletions torch_xla/csrc/runtime/env_hash.cc
@@ -0,0 +1,102 @@
#include "torch_xla/csrc/runtime/env_hash.h"

#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "torch_xla/csrc/runtime/sys_util.h"

namespace torch_xla {
namespace runtime {
namespace hash {

namespace {
static const std::string XLA_FLAG_PREFIX = "--xla";

// Taken from JAX:
// https://github.com/google/jax/blob/8ee5811/jax/_src/cache_key.py#L325-L346
static const std::unordered_set<std::string> FlagsToExclude = {
"--xla_dump_compress_protos",
"--xla_dump_module_metadata",
"--xla_dump_max_hlo_modules",
"--xla_dump_include_timestamp",
"--xla_dump_hlo_pass_re",
"--xla_dump_hlo_module_re",
"--xla_dump_hlo_snapshots",
"--xla_dump_fusion_visualization",
"--xla_dump_hlo_as_url",
"--xla_dump_hlo_as_proto",
"--xla_dump_hlo_as_text",
"--xla_dump_hlo_as_long_text",
"--xla_dump_hlo_as_html",
"--xla_dump_hlo_as_dot",
"--xla_dump_to",
"--xla_force_host_platform_device_count",
"--xla_dump_disable_metadata",
"--xla_dump_hlo_pipeline_re",
"--xla_tpu_sdc_checker_streamz_metric",
"--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
};

torch::lazy::hash_t hash_xla_flags(std::string env_var_name) {
std::stringstream xla_flag_env(
sys_util::GetEnvString(env_var_name.c_str(), ""));
std::string current_flag;
std::vector<std::string> xla_flags;
torch::lazy::hash_t hash = 0;
// Parse the space-delimited flags one at a time.
while (std::getline(xla_flag_env, current_flag, ' ')) {
if (current_flag.rfind(XLA_FLAG_PREFIX, 0) != 0) {
continue;
}
// XLA flags require key and value to be separated by '='.
std::size_t eq_pos = current_flag.find('=');
std::string flag_key;
if (eq_pos == std::string::npos) {
flag_key = current_flag;
} else {
flag_key = current_flag.substr(0, eq_pos);
}
if (FlagsToExclude.find(flag_key) != FlagsToExclude.end()) {
continue;
}
xla_flags.push_back(current_flag);
}
// Ensure the flags are sorted so the input order doesn't impact the hash.
std::sort(xla_flags.begin(), xla_flags.end());
for (auto& flag : xla_flags) {
hash =
torch::lazy::HashCombine(hash, torch::lazy::StringHash(flag.c_str()));
}
return hash;
}

torch::lazy::hash_t hash_xla_env_vars(std::vector<std::string> flag_vars,
std::vector<std::string> raw_vars) {
torch::lazy::hash_t hash = 0;
// Parse the flag_vars for XLA flags.
for (auto& env_var_name : flag_vars) {
hash = torch::lazy::HashCombine(hash, hash_xla_flags(env_var_name));
}

// Include the raw string value of each raw_vars entry.
for (auto& env_var_name : raw_vars) {
std::string raw_val = sys_util::GetEnvString(env_var_name.c_str(), "");
hash = torch::lazy::HashCombine(hash,
torch::lazy::StringHash(raw_val.c_str()));
}
return hash;
}
} // namespace

torch::lazy::hash_t HashXlaEnvVars() {
// Both XLA_FLAGS and LIBTPU_INIT_ARGS contain XLA flags which impact
// the compilation result.
static std::vector<std::string> flag_vars = {"XLA_FLAGS", "LIBTPU_INIT_ARGS"};
static std::vector<std::string> raw_vars = {"TPU_MEGACORE"};
return hash_xla_env_vars(flag_vars, raw_vars);
}

} // namespace hash
} // namespace runtime
} // namespace torch_xla
15 changes: 15 additions & 0 deletions torch_xla/csrc/runtime/env_hash.h
@@ -0,0 +1,15 @@
#include <torch/csrc/lazy/core/hash.h>

namespace torch_xla {
namespace runtime {
namespace hash {

// Take a hash of XLA flags which impact the compilation result.
// TODO(jonbolin): We should move away from manually hashing the env vars and
// instead hash the compilation environment directly when the functionality is
// supported in the runtime.
torch::lazy::hash_t HashXlaEnvVars();

} // namespace hash
} // namespace runtime
} // namespace torch_xla
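
A minimal standalone usage sketch (not part of this commit), mirroring the unit test below; the flag name --xla_foo_bar is illustrative, and setenv assumes a POSIX environment.

#include <cstdlib>
#include <iostream>

#include "torch_xla/csrc/runtime/env_hash.h"

int main() {
  namespace hash = torch_xla::runtime::hash;
  torch::lazy::hash_t base = hash::HashXlaEnvVars();

  // Excluded dump flags do not affect the hash.
  setenv("XLA_FLAGS", "--xla_dump_to=/tmp/dump", /*overwrite=*/1);
  std::cout << (base != hash::HashXlaEnvVars()) << std::endl;  // expected: 0

  // Compilation-affecting flags do.
  setenv("XLA_FLAGS", "--xla_foo_bar=1", /*overwrite=*/1);
  std::cout << (base != hash::HashXlaEnvVars()) << std::endl;  // expected: 1
  return 0;
}
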
52 changes: 52 additions & 0 deletions torch_xla/csrc/runtime/env_hash_test.cc
@@ -0,0 +1,52 @@
#include "torch_xla/csrc/runtime/env_hash.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdlib>

namespace torch_xla {
namespace runtime {
namespace hash {

TEST(HashTest, CompilationEnvHashTest) {
for (const char* flag_var : {"XLA_FLAGS", "LIBTPU_INIT_ARGS"}) {
torch::lazy::hash_t base_hash = HashXlaEnvVars();

// Add an ignored XLA flag to the environment
setenv(flag_var, "--xla_dump_to=/foo/bar", /*overwrite=*/true);
EXPECT_TRUE(base_hash == HashXlaEnvVars());

// Add some non-ignored XLA flag to the environment
setenv(flag_var, "--xla_foo_bar=1 --xla_bar_baz=0", /*overwrite=*/true);
torch::lazy::hash_t nonignored_xla_flag = HashXlaEnvVars();
EXPECT_TRUE(base_hash != nonignored_xla_flag);

// Add an ignored XLA flag in addition to the non-ignored
setenv(flag_var, "--xla_foo_bar=1 --xla_bar_baz=0 --xla_dump_to=/foo/bar",
/*overwrite=*/true);
torch::lazy::hash_t mixed_xla_flag = HashXlaEnvVars();
EXPECT_TRUE(nonignored_xla_flag == mixed_xla_flag);

// Reordering the XLA flags should not impact the hash
setenv(flag_var, "--xla_bar_baz=0 --xla_dump_to=/foo/bar --xla_foo_bar=1",
/*overwrite=*/true);
torch::lazy::hash_t mixed_reordered_xla_flag = HashXlaEnvVars();
EXPECT_TRUE(mixed_xla_flag == mixed_reordered_xla_flag);

// Changing the XLA flag value should impact the hash
setenv(flag_var, "--xla_bar_baz=1 --xla_dump_to=/foo/bar --xla_foo_bar=1",
/*overwrite=*/true);
torch::lazy::hash_t new_value_xla_flag = HashXlaEnvVars();
EXPECT_TRUE(mixed_reordered_xla_flag != new_value_xla_flag);
}

// Modifying the value of TPU_MEGACORE should impact the hash
torch::lazy::hash_t base_hash = HashXlaEnvVars();
setenv("TPU_MEGACORE", "megacore", /*overwrite=*/true);
EXPECT_TRUE(base_hash != HashXlaEnvVars());
}

} // namespace hash
} // namespace runtime
} // namespace torch_xla
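
Given the cc_test target added to torch_xla/csrc/runtime/BUILD above, this test can presumably be run with the repository's usual Bazel workflow, e.g. bazel test //torch_xla/csrc/runtime:env_hash_test.
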
42 changes: 42 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -11,6 +11,7 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_hash.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/operation_manager.h"
#include "torch_xla/csrc/runtime/profiler.h"
@@ -84,6 +85,40 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
return allocator_config;
}

torch::lazy::hash_t hash_comp_env(
std::shared_ptr<xla::PjRtClient> client,
std::vector<xla::PjRtDevice*>& ordered_devices) {
torch::lazy::hash_t hash = hash::HashXlaEnvVars();
// Whether or not SPMD mode is active should influence the hash.
hash = torch::lazy::HashCombine(hash, UseVirtualDevice());
auto topology_desc = client->GetTopologyDescription();
if (topology_desc.ok()) {
// Some backends support a topology description which provides a better
// view of the specific compilation environment.
auto serialized = topology_desc.value()->Serialize();
if (serialized.ok()) {
return torch::lazy::HashCombine(
hash,
torch::lazy::DataHash(serialized->data(), serialized->length()));
}
// If serialization fails, fallthrough to the manual approach.
}
std::string platform_name(client->platform_name());
std::string platform_version(client->platform_version());
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(platform_name.c_str()));
// platform_version incorporates libtpu version and hardware type.
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(platform_version.c_str()));
// Include global devices in the hash, ensuring order is consistent.
for (auto& device : ordered_devices) {
std::string device_str(device->ToString());
hash = torch::lazy::HashCombine(
hash, torch::lazy::StringHash(device_str.c_str()));
}
return hash;
}

} // namespace

std::string PjRtComputationClient::PjRtDeviceToString(
@@ -563,6 +598,13 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
return computations;
}

torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() {
// TODO(jonbolin): Incorporate CompileOptions into the hash. These are
// deterministically generated at the moment, so they don't need to be
// included. It will require a small refactor, so punting on this for now.
return comp_env_hash_;
}

std::vector<ComputationClient::DataPtr>
PjRtComputationClient::ExecuteComputation(
const ComputationClient::Computation& computation,
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -72,6 +72,8 @@ class PjRtComputationClient : public ComputationClient {

std::vector<std::string> GetAllDevices() const override;

torch::lazy::hash_t HashCompilationEnv() override;

int GetProcessIndex() const override { return client_->process_index(); };

int GetNumProcesses() const override;
@@ -114,6 +116,7 @@ class PjRtComputationClient : public ComputationClient {
OperationManager operation_manager_;
tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
torch::lazy::hash_t comp_env_hash_;

xla::PjRtDevice* StringToPjRtDevice(const std::string& device);

7 changes: 7 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -53,6 +53,7 @@
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/thread_pool.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/version.h"
#include "torch_xla/csrc/xla_backend_impl.h"
#include "torch_xla/csrc/xla_sharding_util.h"
#include "tsl/platform/errors.h"
@@ -536,6 +537,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
// The force_ltc_data controls aliasing compilation, so effectively the same
// graph with on/off force_ltc_data should not match, hash wise.
coll.hash = torch::lazy::MHash(config.force_ltc_data);
// Ensure the compilation environment and git revision are reflected in the
// hash.
coll.hash = torch::lazy::HashCombine(
coll.hash, runtime::GetComputationClient()->HashCompilationEnv());
coll.hash =
torch::lazy::HashCombine(coll.hash, torch::lazy::StringHash(XLA_GITREV));
coll.config = config;
coll.device = *unique_device;
coll.indices.reserve(tensors.size());
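
Net effect: the graph hash now reflects both the compilation environment (via HashCompilationEnv) and the torch_xla git revision (via XLA_GITREV), so two runs that differ only in, say, XLA_FLAGS produce different hashes, and any cache keyed on this hash will not reuse an executable compiled under a different environment.
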
