Skip to content

Commit

Permalink
[xla:cpu] Add a flag to limit the CPU features that LLVM will codegen.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676110440
  • Loading branch information
penpornk authored and Google-ML-Automation committed Sep 27, 2024
1 parent 0983168 commit 3ccaaab
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 27 deletions.
8 changes: 8 additions & 0 deletions third_party/tsl/tsl/platform/cpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ void InitCPUIDInfo() {

} // namespace

bool IsX86CPU() {
#ifdef PLATFORM_IS_X86
return true;
#else
return false;
#endif
}

bool TestCPUFeature(CPUFeature feature) {
#ifdef PLATFORM_IS_X86
return CPUIDInfo::TestFeature(feature);
Expand Down
3 changes: 3 additions & 0 deletions third_party/tsl/tsl/platform/cpu_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ bool TestAarch64CPU(Aarch64CPU cpu);
// Checks CPU registers to return hardware capabilities.
bool TestCPUFeature(CPUFeature feature);

// Checks whether the current processor is x86.
bool IsX86CPU();

// Returns CPU Vendor string (i.e. 'GenuineIntel', 'AuthenticAMD', etc.)
std::string CPUVendorIDString();

Expand Down
15 changes: 15 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_parallel_codegen_split_count(32);
opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false);
opts.set_xla_cpu_prefer_vector_width(256);
opts.set_xla_cpu_max_isa("");

opts.set_xla_cpu_enable_fast_math(false);
// Disable forms of fast math that have caused users problems in the past.
Expand Down Expand Up @@ -379,6 +380,15 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
};
};

auto uppercase_string_setter_for =
[debug_options](
void (DebugOptions::*member_setter)(const std::string& value)) {
return [debug_options, member_setter](const std::string& value) {
(debug_options->*member_setter)(absl::AsciiStrToUpper(value));
return true;
};
};

auto float_setter_for =
[debug_options](void (DebugOptions::*member_setter)(float)) {
return [debug_options, member_setter](float value) {
Expand Down Expand Up @@ -881,6 +891,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_cpu_prefer_vector_width),
debug_options->xla_cpu_prefer_vector_width(),
"Preferred vector with for the XLA:CPU LLVM backend."));
flag_list->push_back(
tsl::Flag("xla_cpu_max_isa",
uppercase_string_setter_for(&DebugOptions::set_xla_cpu_max_isa),
debug_options->xla_cpu_max_isa(),
"Max ISA to use for the XLA:CPU LLVM backend."));
flag_list->push_back(tsl::Flag(
"xla_gpu_crash_on_verification_failures",
bool_setter_for(
Expand Down
2 changes: 2 additions & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ cc_library(
"//xla:util",
"//xla/service:custom_call_target_registry",
"//xla/service:llvm_compiler",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/memory",
Expand All @@ -513,6 +514,7 @@ cc_library(
"@llvm-project//llvm:TargetParser",
"@llvm-project//mlir:mlir_c_runner_utils",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:platform_port",
] + xla_internal(["service/cpu:named_orc_jit_memory_mapper"]),
)

Expand Down
27 changes: 24 additions & 3 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,16 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
// parallel compilation at run time.
size_t parallel_codegen_split_count =
debug_options.xla_cpu_parallel_codegen_split_count();
std::string max_cpu_isa = debug_options.xla_cpu_max_isa();
if (VLOG_IS_ON(1)) {
if (tsl::port::IsX86CPU()) {
VLOG(1) << "`xla_cpu_max_isa` is set. Will not use features newer than: "
<< max_cpu_isa;
} else {
VLOG(1) << "`xla_cpu_max_isa` is set to `" << max_cpu_isa
<< "`. This flag is not supported on non-x86 CPUs yet.";
}
}

auto jit = SimpleOrcJIT::Create(
CompilerTargetOptions(module->config()),
Expand All @@ -1341,7 +1351,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
post_optimization_ir_hook,
CreateOrcJITPostCompilationHook(module.get(), &obj_files),
parallel_codegen_split_count);
parallel_codegen_split_count, max_cpu_isa);
if (!jit) {
return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError()));
}
Expand Down Expand Up @@ -2033,15 +2043,26 @@ CpuExecutableAotCompilationResult::LoadExecutable(
compiler->BufferSizeBytesFunction(),
/*can_share_buffer=*/nullptr));

const DebugOptions& debug_options = module->config().debug_options();
std::string max_cpu_isa = debug_options.xla_cpu_max_isa();
if (VLOG_IS_ON(1)) {
if (tsl::port::IsX86CPU()) {
VLOG(1) << "`xla_cpu_max_isa` is set. Will not use features newer than: "
<< max_cpu_isa;
} else {
VLOG(1) << "`xla_cpu_max_isa` is set to `" << max_cpu_isa
<< "`. This flag is not supported on non-x86 CPUs yet.";
}
}
auto jit = SimpleOrcJIT::Create(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
debug_options.xla_llvm_disable_expensive_passes(),
options::SlpVectorizerDisabled(module->config()),
llvm_ir::GetCpuFastMathFlags(module->config()),
/*pre_optimization_hook=*/nullptr, /*post_optimization_hook=*/nullptr,
/*post_codegen_hook=*/nullptr);
/*post_codegen_hook=*/nullptr, /*num_jit_dylibs=*/1, max_cpu_isa);
if (!jit) {
return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError()));
}
Expand Down
156 changes: 136 additions & 20 deletions xla/service/cpu/simple_orc_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ limitations under the License.
#include <cstdlib>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <system_error> // NOLINT
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -81,6 +83,7 @@ limitations under the License.
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/llvm_compiler.h"
#include "xla/util.h"
#include "tsl/platform/cpu_info.h"
#include "tsl/platform/logging.h"

#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3)
Expand All @@ -97,15 +100,6 @@ extern "C" uint16_t __truncsfbf2(float);
extern "C" uint16_t __truncdfbf2(double);

namespace xla::cpu {

std::vector<std::string> DetectMachineAttributes() {
std::vector<std::string> result;
for (const auto& [feature, enabled] : llvm::sys::getHostCPUFeatures()) {
result.push_back((enabled ? '+' : '-') + std::string(feature));
}
return result;
}

namespace {

class DefaultMemoryMapper final
Expand Down Expand Up @@ -302,30 +296,152 @@ bool ContiguousSectionMemoryManager::finalizeMemory(std::string* err_msg) {
return false;
}

using tsl::port::CPUFeature;

// Returns the earliest CPU generation that supports the instruction set.
llvm::StringRef CPUTargetFromMaxFeature(CPUFeature max_feature) {
switch (max_feature) {
case CPUFeature::SSE4_2:
return "nehalem";
case CPUFeature::AVX:
return "sandybridge";
case CPUFeature::AVX2:
return "haswell";
case CPUFeature::AVX512F:
return "skylake-avx512";
case CPUFeature::AVX512_VNNI:
return "cascadelake";
case CPUFeature::AVX512_BF16:
return "cooperlake";
case CPUFeature::AMX_BF16:
case CPUFeature::AMX_INT8:
return "sapphirerapids";
case CPUFeature::AMX_FP16:
return "graniterapids";
default:
LOG(FATAL) << "Unsupported max feature: " << max_feature;
}
}

} // namespace

std::optional<CPUFeature> ISAStringToFeature(const std::string feature_string) {
if (feature_string.empty()) return std::nullopt;

// Non-exhaustive list of CPU features. (Only the ones we care about.)
// TODO(penporn): Handle ARM
static absl::flat_hash_map<std::string, CPUFeature>* x86 = [] {
return new absl::flat_hash_map<std::string, CPUFeature>(
{{"SSE4_2", CPUFeature::SSE4_2},
{"AVX", CPUFeature::AVX},
{"AVX2", CPUFeature::AVX2},
{"AVX512", CPUFeature::AVX512F},
{"AVX512_VNNI", CPUFeature::AVX512_VNNI},
{"AVX512_BF16", CPUFeature::AVX512_BF16},
{"AMX", CPUFeature::AMX_BF16}, // Includes AMX_INT8.
{"AMX_FP16", CPUFeature::AMX_FP16}});
}();

// Assume that `feature_string` always contains all uppercase letters.
if (auto it = x86->find(feature_string); it != x86->end()) return it->second;
LOG(WARNING) << "Unknown CPU ISA: " << feature_string;
return std::nullopt;
}

// Disable any feature that is newer than `max_feature`.
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
bool ShouldEnableCPUFeature(const llvm::StringRef feature,
const CPUFeature& max_feature) {
// TODO(penporn): Figure out where to put AVX10.
switch (max_feature) {
case CPUFeature::SSE4_2:
if (feature.starts_with("avx") || feature == "f16c" ||
feature == "vpclmulqdq" || feature == "vaes") {
return false;
}
[[fallthrough]];
case CPUFeature::AVX:
if (feature.starts_with("avx2") || feature.starts_with("fma")) {
return false;
}
[[fallthrough]];
case CPUFeature::AVX2:
if (feature.starts_with("avx512") || feature == "evex512") return false;
[[fallthrough]];
case CPUFeature::AVX512F:
if (feature == "avx512vnni") return false;
[[fallthrough]];
case CPUFeature::AVX512_VNNI:
if (feature == "avx512bf16") return false;
[[fallthrough]];
case CPUFeature::AVX512_BF16:
if (feature.starts_with("amx")) return false;
[[fallthrough]];
case CPUFeature::AMX_INT8:
case CPUFeature::AMX_BF16:
if (feature == "amx-fp16") return false;
[[fallthrough]];
default:
// Leave all other features enabled.
return true;
}
}

std::vector<std::string> DetectMachineAttributes(
std::optional<CPUFeature> max_feature, bool& features_filtered) {
std::vector<std::string> result;
features_filtered = false;
// We only have x86 constraints. Skip the check if we are on non-x86 CPUs.
const bool no_feature_constraint =
!max_feature.has_value() || !tsl::port::IsX86CPU();
for (const auto& [feature, enabled] : llvm::sys::getHostCPUFeatures()) {
bool should_enable =
enabled && (no_feature_constraint ||
ShouldEnableCPUFeature(feature, *max_feature));
result.push_back((should_enable ? '+' : '-') + std::string(feature));
features_filtered |= (should_enable != enabled);
}
return result;
}

std::vector<std::string> DetectMachineAttributes() {
bool features_filtered = false;
return DetectMachineAttributes(std::nullopt, features_filtered);
}

/*static*/ std::unique_ptr<llvm::TargetMachine>
SimpleOrcJIT::InferTargetMachineForJIT(
const llvm::TargetOptions& target_options,
llvm::CodeGenOptLevel opt_level) {
std::vector<std::string> attrs = DetectMachineAttributes();
const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level,
std::string max_cpu_isa) {
bool features_filtered = false;
std::optional<CPUFeature> max_feature = ISAStringToFeature(max_cpu_isa);
std::vector<std::string> attrs =
DetectMachineAttributes(max_feature, features_filtered);
llvm::SmallVector<std::string, 0> llvm_attrs(attrs.begin(), attrs.end());
// If `max_feature` is newer than the host CPU, we should keep the host CPU
// name, e.g., we don't want to set the target CPU to Skylake when we are on
// a Broadwell host.
llvm::StringRef target_cpu = features_filtered
? CPUTargetFromMaxFeature(*max_feature)
: llvm::sys::getHostCPUName();
std::unique_ptr<llvm::TargetMachine> target_machine(
llvm::EngineBuilder()
.setTargetOptions(target_options)
.setOptLevel(opt_level)
.selectTarget(
/*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
/*MCPU=*/llvm::sys::getHostCPUName(),
/*MCPU=*/target_cpu,
/*MAttrs=*/llvm_attrs));
CHECK(target_machine != nullptr);
return target_machine;
}

static CompilerFunctor::TargetMachineBuilder CreateTargetMachineBuilder(
llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level) {
return [target_options, opt_level]() {
return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level);
llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level,
std::string max_cpu_isa) {
return [target_options, opt_level, max_cpu_isa]() {
return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level,
max_cpu_isa);
};
}

Expand All @@ -338,9 +454,9 @@ SimpleOrcJIT::SimpleOrcJIT(
LLVMCompiler::ModuleHook pre_optimization_hook,
LLVMCompiler::ModuleHook post_optimization_hook,
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook,
size_t num_jit_dylibs)
size_t num_jit_dylibs, std::string max_cpu_isa)
: target_machine_builder_(
CreateTargetMachineBuilder(target_options, opt_level)),
CreateTargetMachineBuilder(target_options, opt_level, max_cpu_isa)),
target_machine_(target_machine_builder_()),
target_triple_(target_machine_->getTargetTriple()),
data_layout_(target_machine_->createDataLayout()),
Expand Down Expand Up @@ -426,7 +542,7 @@ llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
LLVMCompiler::ModuleHook pre_optimization_hook,
LLVMCompiler::ModuleHook post_optimization_hook,
absl::AnyInvocable<void(const llvm::object::ObjectFile&)> post_codegen_hook,
size_t num_jit_dylibs) {
size_t num_jit_dylibs, std::string max_cpu_isa) {
auto SSP = std::make_shared<llvm::orc::SymbolStringPool>();
auto target_process_control =
llvm::orc::SelfExecutorProcessControl::Create(std::move(SSP));
Expand All @@ -441,7 +557,7 @@ llvm::Expected<std::unique_ptr<SimpleOrcJIT>> SimpleOrcJIT::Create(
target_options, opt_level, optimize_for_size, disable_expensive_passes,
disable_slp_vectorizer, fast_math_flags, std::move(pre_optimization_hook),
std::move(post_optimization_hook), std::move(post_codegen_hook),
num_jit_dylibs);
num_jit_dylibs, std::move(max_cpu_isa));
}

llvm::orc::ExecutorSymbolDef SimpleOrcJIT::ResolveRuntimeSymbol(
Expand Down
Loading

0 comments on commit 3ccaaab

Please sign in to comment.