diff --git a/third_party/tsl/tsl/platform/cpu_info.cc b/third_party/tsl/tsl/platform/cpu_info.cc index 1de5eb8031623..36d506f6846bf 100644 --- a/third_party/tsl/tsl/platform/cpu_info.cc +++ b/third_party/tsl/tsl/platform/cpu_info.cc @@ -490,6 +490,14 @@ void InitCPUIDInfo() { } // namespace +bool IsX86CPU() { +#ifdef PLATFORM_IS_X86 + return true; +#else + return false; +#endif +} + bool TestCPUFeature(CPUFeature feature) { #ifdef PLATFORM_IS_X86 return CPUIDInfo::TestFeature(feature); diff --git a/third_party/tsl/tsl/platform/cpu_info.h b/third_party/tsl/tsl/platform/cpu_info.h index 68506b1d34ae8..98df637f98698 100644 --- a/third_party/tsl/tsl/platform/cpu_info.h +++ b/third_party/tsl/tsl/platform/cpu_info.h @@ -150,6 +150,9 @@ bool TestAarch64CPU(Aarch64CPU cpu); // Checks CPU registers to return hardware capabilities. bool TestCPUFeature(CPUFeature feature); +// Checks whether the current processor is x86. +bool IsX86CPU(); + // Returns CPU Vendor string (i.e. 'GenuineIntel', 'AuthenticAMD', etc.) std::string CPUVendorIDString(); diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index abedfc370dd83..4bdf0d2668ab2 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -87,6 +87,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_parallel_codegen_split_count(32); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); opts.set_xla_cpu_prefer_vector_width(256); + opts.set_xla_cpu_max_isa(""); opts.set_xla_cpu_enable_fast_math(false); // Disable forms of fast math that have caused users problems in the past. @@ -379,6 +380,15 @@ void MakeDebugOptionsFlags(std::vector* flag_list, }; }; + auto uppercase_string_setter_for = + [debug_options]( + void (DebugOptions::*member_setter)(const std::string& value)) { + return [debug_options, member_setter](const std::string& value) { + (debug_options->*member_setter)(absl::AsciiStrToUpper(value)); + return true; + }; + }; + auto float_setter_for = [debug_options](void (DebugOptions::*member_setter)(float)) { return [debug_options, member_setter](float value) { @@ -881,6 +891,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for(&DebugOptions::set_xla_cpu_prefer_vector_width), debug_options->xla_cpu_prefer_vector_width(), "Preferred vector with for the XLA:CPU LLVM backend.")); + flag_list->push_back( + tsl::Flag("xla_cpu_max_isa", + uppercase_string_setter_for(&DebugOptions::set_xla_cpu_max_isa), + debug_options->xla_cpu_max_isa(), + "Max ISA to use for the XLA:CPU LLVM backend.")); flag_list->push_back(tsl::Flag( "xla_gpu_crash_on_verification_failures", bool_setter_for( diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 301455b651e68..d3420cf0491b3 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -498,6 +498,7 @@ cc_library( "//xla:util", "//xla/service:custom_call_target_registry", "//xla/service:llvm_compiler", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", @@ -513,6 +514,7 @@ cc_library( "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:mlir_c_runner_utils", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", ] + xla_internal(["service/cpu:named_orc_jit_memory_mapper"]), ) diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index ee97d2afa0240..4dd0cbc4fe9a1 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -1331,6 +1331,16 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // parallel compilation at run time. size_t parallel_codegen_split_count = debug_options.xla_cpu_parallel_codegen_split_count(); + std::string max_cpu_isa = debug_options.xla_cpu_max_isa(); + if (VLOG_IS_ON(1)) { + if (tsl::port::IsX86CPU()) { + VLOG(1) << "`xla_cpu_max_isa` is set. Will not use features newer than: " + << max_cpu_isa; + } else { + VLOG(1) << "`xla_cpu_max_isa` is set to `" << max_cpu_isa + << "`. This flag is not supported on non-x86 CPUs yet."; + } + } auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), @@ -1341,7 +1351,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, post_optimization_ir_hook, CreateOrcJITPostCompilationHook(module.get(), &obj_files), - parallel_codegen_split_count); + parallel_codegen_split_count, max_cpu_isa); if (!jit) { return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } @@ -2033,15 +2043,26 @@ CpuExecutableAotCompilationResult::LoadExecutable( compiler->BufferSizeBytesFunction(), /*can_share_buffer=*/nullptr)); + const DebugOptions& debug_options = module->config().debug_options(); + std::string max_cpu_isa = debug_options.xla_cpu_max_isa(); + if (VLOG_IS_ON(1)) { + if (tsl::port::IsX86CPU()) { + VLOG(1) << "`xla_cpu_max_isa` is set. Will not use features newer than: " + << max_cpu_isa; + } else { + VLOG(1) << "`xla_cpu_max_isa` is set to `" << max_cpu_isa + << "`. This flag is not supported on non-x86 CPUs yet."; + } + } auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_llvm_disable_expensive_passes(), + debug_options.xla_llvm_disable_expensive_passes(), options::SlpVectorizerDisabled(module->config()), llvm_ir::GetCpuFastMathFlags(module->config()), /*pre_optimization_hook=*/nullptr, /*post_optimization_hook=*/nullptr, - /*post_codegen_hook=*/nullptr); + /*post_codegen_hook=*/nullptr, /*num_jit_dylibs=*/1, max_cpu_isa); if (!jit) { return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } diff --git a/xla/service/cpu/simple_orc_jit.cc b/xla/service/cpu/simple_orc_jit.cc index 0e9a2feb18609..5c0a8494633de 100644 --- a/xla/service/cpu/simple_orc_jit.cc +++ b/xla/service/cpu/simple_orc_jit.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include #include +#include #include #include // NOLINT #include #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -81,6 +83,7 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" #include "xla/service/llvm_compiler.h" #include "xla/util.h" +#include "tsl/platform/cpu_info.h" #include "tsl/platform/logging.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) @@ -97,15 +100,6 @@ extern "C" uint16_t __truncsfbf2(float); extern "C" uint16_t __truncdfbf2(double); namespace xla::cpu { - -std::vector DetectMachineAttributes() { - std::vector result; - for (const auto& [feature, enabled] : llvm::sys::getHostCPUFeatures()) { - result.push_back((enabled ? '+' : '-') + std::string(feature)); - } - return result; -} - namespace { class DefaultMemoryMapper final @@ -302,30 +296,152 @@ bool ContiguousSectionMemoryManager::finalizeMemory(std::string* err_msg) { return false; } +using tsl::port::CPUFeature; + +// Returns the earliest CPU generation that supports the instruction set. +llvm::StringRef CPUTargetFromMaxFeature(CPUFeature max_feature) { + switch (max_feature) { + case CPUFeature::SSE4_2: + return "nehalem"; + case CPUFeature::AVX: + return "sandybridge"; + case CPUFeature::AVX2: + return "haswell"; + case CPUFeature::AVX512F: + return "skylake-avx512"; + case CPUFeature::AVX512_VNNI: + return "cascadelake"; + case CPUFeature::AVX512_BF16: + return "cooperlake"; + case CPUFeature::AMX_BF16: + case CPUFeature::AMX_INT8: + return "sapphirerapids"; + case CPUFeature::AMX_FP16: + return "graniterapids"; + default: + LOG(FATAL) << "Unsupported max feature: " << max_feature; + } +} + } // namespace +std::optional ISAStringToFeature(const std::string feature_string) { + if (feature_string.empty()) return std::nullopt; + + // Non-exhaustive list of CPU features. (Only the ones we care about.) + // TODO(penporn): Handle ARM + static absl::flat_hash_map* x86 = [] { + return new absl::flat_hash_map( + {{"SSE4_2", CPUFeature::SSE4_2}, + {"AVX", CPUFeature::AVX}, + {"AVX2", CPUFeature::AVX2}, + {"AVX512", CPUFeature::AVX512F}, + {"AVX512_VNNI", CPUFeature::AVX512_VNNI}, + {"AVX512_BF16", CPUFeature::AVX512_BF16}, + {"AMX", CPUFeature::AMX_BF16}, // Includes AMX_INT8. + {"AMX_FP16", CPUFeature::AMX_FP16}}); + }(); + + // Assume that `feature_string` always contains all uppercase letters. + if (auto it = x86->find(feature_string); it != x86->end()) return it->second; + LOG(WARNING) << "Unknown CPU ISA: " << feature_string; + return std::nullopt; +} + +// Disable any feature that is newer than `max_feature`. +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +bool ShouldEnableCPUFeature(const llvm::StringRef feature, + const CPUFeature& max_feature) { + // TODO(penporn): Figure out where to put AVX10. + switch (max_feature) { + case CPUFeature::SSE4_2: + if (feature.starts_with("avx") || feature == "f16c" || + feature == "vpclmulqdq" || feature == "vaes") { + return false; + } + [[fallthrough]]; + case CPUFeature::AVX: + if (feature.starts_with("avx2") || feature.starts_with("fma")) { + return false; + } + [[fallthrough]]; + case CPUFeature::AVX2: + if (feature.starts_with("avx512") || feature == "evex512") return false; + [[fallthrough]]; + case CPUFeature::AVX512F: + if (feature == "avx512vnni") return false; + [[fallthrough]]; + case CPUFeature::AVX512_VNNI: + if (feature == "avx512bf16") return false; + [[fallthrough]]; + case CPUFeature::AVX512_BF16: + if (feature.starts_with("amx")) return false; + [[fallthrough]]; + case CPUFeature::AMX_INT8: + case CPUFeature::AMX_BF16: + if (feature == "amx-fp16") return false; + [[fallthrough]]; + default: + // Leave all other features enabled. + return true; + } +} + +std::vector DetectMachineAttributes( + std::optional max_feature, bool& features_filtered) { + std::vector result; + features_filtered = false; + // We only have x86 constraints. Skip the check if we are on non-x86 CPUs. + const bool no_feature_constraint = + !max_feature.has_value() || !tsl::port::IsX86CPU(); + for (const auto& [feature, enabled] : llvm::sys::getHostCPUFeatures()) { + bool should_enable = + enabled && (no_feature_constraint || + ShouldEnableCPUFeature(feature, *max_feature)); + result.push_back((should_enable ? '+' : '-') + std::string(feature)); + features_filtered |= (should_enable != enabled); + } + return result; +} + +std::vector DetectMachineAttributes() { + bool features_filtered = false; + return DetectMachineAttributes(std::nullopt, features_filtered); +} + /*static*/ std::unique_ptr SimpleOrcJIT::InferTargetMachineForJIT( - const llvm::TargetOptions& target_options, - llvm::CodeGenOptLevel opt_level) { - std::vector attrs = DetectMachineAttributes(); + const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level, + std::string max_cpu_isa) { + bool features_filtered = false; + std::optional max_feature = ISAStringToFeature(max_cpu_isa); + std::vector attrs = + DetectMachineAttributes(max_feature, features_filtered); llvm::SmallVector llvm_attrs(attrs.begin(), attrs.end()); + // If `max_feature` is newer than the host CPU, we should keep the host CPU + // name, e.g., we don't want to set the target CPU to Skylake when we are on + // a Broadwell host. + llvm::StringRef target_cpu = features_filtered + ? CPUTargetFromMaxFeature(*max_feature) + : llvm::sys::getHostCPUName(); std::unique_ptr target_machine( llvm::EngineBuilder() .setTargetOptions(target_options) .setOptLevel(opt_level) .selectTarget( /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/llvm::sys::getHostCPUName(), + /*MCPU=*/target_cpu, /*MAttrs=*/llvm_attrs)); CHECK(target_machine != nullptr); return target_machine; } static CompilerFunctor::TargetMachineBuilder CreateTargetMachineBuilder( - llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level) { - return [target_options, opt_level]() { - return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level); + llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level, + std::string max_cpu_isa) { + return [target_options, opt_level, max_cpu_isa]() { + return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level, + max_cpu_isa); }; } @@ -338,9 +454,9 @@ SimpleOrcJIT::SimpleOrcJIT( LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs) + size_t num_jit_dylibs, std::string max_cpu_isa) : target_machine_builder_( - CreateTargetMachineBuilder(target_options, opt_level)), + CreateTargetMachineBuilder(target_options, opt_level, max_cpu_isa)), target_machine_(target_machine_builder_()), target_triple_(target_machine_->getTargetTriple()), data_layout_(target_machine_->createDataLayout()), @@ -426,7 +542,7 @@ llvm::Expected> SimpleOrcJIT::Create( LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs) { + size_t num_jit_dylibs, std::string max_cpu_isa) { auto SSP = std::make_shared(); auto target_process_control = llvm::orc::SelfExecutorProcessControl::Create(std::move(SSP)); @@ -441,7 +557,7 @@ llvm::Expected> SimpleOrcJIT::Create( target_options, opt_level, optimize_for_size, disable_expensive_passes, disable_slp_vectorizer, fast_math_flags, std::move(pre_optimization_hook), std::move(post_optimization_hook), std::move(post_codegen_hook), - num_jit_dylibs); + num_jit_dylibs, std::move(max_cpu_isa)); } llvm::orc::ExecutorSymbolDef SimpleOrcJIT::ResolveRuntimeSymbol( diff --git a/xla/service/cpu/simple_orc_jit.h b/xla/service/cpu/simple_orc_jit.h index 9adec42216cd5..52f7c9204849f 100644 --- a/xla/service/cpu/simple_orc_jit.h +++ b/xla/service/cpu/simple_orc_jit.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -47,6 +48,7 @@ limitations under the License. #include "llvm/TargetParser/Triple.h" #include "xla/service/cpu/compiler_functor.h" #include "xla/service/llvm_compiler.h" +#include "tsl/platform/cpu_info.h" namespace xla::cpu { @@ -79,7 +81,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs = 1); + size_t num_jit_dylibs = 1, std::string max_cpu_isa = ""); static llvm::Expected> Create( const llvm::TargetOptions& target_options, @@ -90,7 +92,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs = 1); + size_t num_jit_dylibs = 1, std::string max_cpu_isa = ""); ~SimpleOrcJIT() override; @@ -117,7 +119,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { // the current machine. static std::unique_ptr InferTargetMachineForJIT( const llvm::TargetOptions& target_options, - llvm::CodeGenOptLevel opt_level); + llvm::CodeGenOptLevel opt_level, std::string max_cpu_isa = ""); int64_t SizeOfGeneratedCodeInBytes() const { return size_of_generated_code_in_bytes_; @@ -167,6 +169,17 @@ class SimpleOrcJIT : public llvm::JITEventListener { llvm::JITEventListener* perf_jit_event_listener_; }; +std::optional ISAStringToFeature( + std::string feature_string); + +bool ShouldEnableCPUFeature(llvm::StringRef feature, + const tsl::port::CPUFeature& max_feature); + +std::vector DetectMachineAttributes( + std::optional max_feature, bool& features_filtered); + +// TODO(penporn): PJRT's CPU client also calls this function. We should +// make it get the same filtered attributes according to the `max_isa` setting. std::vector DetectMachineAttributes(); } // namespace xla::cpu diff --git a/xla/service/cpu/tests/BUILD b/xla/service/cpu/tests/BUILD index 6b7e37b10f123..994bf3b3274f8 100644 --- a/xla/service/cpu/tests/BUILD +++ b/xla/service/cpu/tests/BUILD @@ -330,6 +330,7 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:simple_orc_jit", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", @@ -337,6 +338,7 @@ xla_cc_test( "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep + "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], diff --git a/xla/service/cpu/tests/cpu_vectorization_test.cc b/xla/service/cpu/tests/cpu_vectorization_test.cc index ec29e43b3aff9..0093635e23875 100644 --- a/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -15,21 +15,25 @@ limitations under the License. #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "llvm-c/Target.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/cpu/cpu_compiler.h" +#include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/cpu_info.h" #include "tsl/platform/test.h" namespace xla { @@ -140,6 +144,139 @@ INSTANTIATE_TEST_SUITE_P(CpuVectorizationTestInstantiation, ::testing::ValuesIn(CpuVectorizationTestCases), CpuVectorizationTest::Name); +struct MaxIsaTestSpec { + std::string max_isa; + std::string feature; + bool should_enable; +}; + +class MaxIsaTest : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static std::string Name( + const ::testing::TestParamInfo& info) { + // Test names cannot contain '-'. Replace it with '_'. + std::string feature = info.param.feature; + absl::c_replace_if( + feature, [](char c) { return c != '_' && !absl::ascii_isalnum(c); }, + '_'); + return absl::StrCat(info.param.max_isa, "_feature_", feature); + } +}; + +TEST_P(MaxIsaTest, ShouldEnableFeature) { + HloComputation::Builder builder(TestName()); + MaxIsaTestSpec spec = GetParam(); + + auto max_feature = ISAStringToFeature(spec.max_isa); + bool should_enable = ShouldEnableCPUFeature(spec.feature, *max_feature); + EXPECT_EQ(should_enable, spec.should_enable); +} + +std::vector GetMaxIsaTestCases() { + return std::vector({ + MaxIsaTestSpec{"AVX2", "avx", true}, + MaxIsaTestSpec{"AVX2", "avx2", true}, + MaxIsaTestSpec{"AVX2", "avx512f", false}, + MaxIsaTestSpec{"AVX2", "avx512vnni", false}, + MaxIsaTestSpec{"AVX2", "evex512", false}, + MaxIsaTestSpec{"AVX512", "avx512f", true}, + MaxIsaTestSpec{"AVX512", "avx512vnni", false}, + MaxIsaTestSpec{"AVX512", "amx-bf16", false}, + }); +} + +INSTANTIATE_TEST_SUITE_P(MaxIsaTestInstantiation, MaxIsaTest, + ::testing::ValuesIn(GetMaxIsaTestCases()), + MaxIsaTest::Name); + +struct JitVectorizationTestSpec { + HloOpcode opcode; + std::string max_isa; + std::string check_template; + int num_vector_elements; +}; + +class JitVectorizationTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static std::string Name( + const ::testing::TestParamInfo& info) { + std::string op_name(HloOpcodeString(info.param.opcode)); + op_name[0] = toupper(op_name[0]); + return absl::StrCat(op_name, "_max_", info.param.max_isa); + } + + private: + DebugOptions GetDebugOptionsForTest() override { + JitVectorizationTestSpec spec = GetParam(); + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_max_isa(spec.max_isa); + // For AVX512, we have to override the default `prefer_vector_width=256` + // setting. Otherwise, LLVM won't generate AVX512. + // TODO(penporn): Change the setting for actual AVX512 codegen too. + if (spec.max_isa == "AVX512") { + debug_options.set_xla_cpu_prefer_vector_width(512); + } + return debug_options; + } +}; + +TEST_P(JitVectorizationTest, JitUpToIsa) { + if (!tsl::port::IsX86CPU()) { + GTEST_SKIP() << "This feature only works for x86 CPUs."; + } + HloComputation::Builder builder(TestName()); + JitVectorizationTestSpec spec = GetParam(); + + // If the CPU doesn't have the `max_isa` feature, e.g., `max_isa=AVX512` but + // we are running on an AVX2 machine, update the `check_lines` accordingly. + using tsl::port::CPUFeature; + auto feature = ISAStringToFeature(spec.max_isa); + if (!tsl::port::TestCPUFeature(*feature)) { + if (tsl::port::TestCPUFeature(CPUFeature::AVX)) { + spec.num_vector_elements = 8; + } else { + spec.num_vector_elements = 4; + } + } + std::string check_lines = absl::StrReplaceAll( + spec.check_template, {{"%d", absl::StrCat(spec.num_vector_elements)}}); + + // Build HLO module. + auto shape = ShapeUtil::MakeShape(F32, {1024}); + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, spec.opcode, a, b)); + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyIr(std::move(hlo_module), check_lines, + /*match_optimized_ir=*/true); +} + +std::vector GetJitVectorizationTestCases() { + return std::vector({ + JitVectorizationTestSpec{HloOpcode::kMultiply, "SSE4_2", + R"(CHECK: fmul <%d x float>)", 4}, + JitVectorizationTestSpec{HloOpcode::kMultiply, "AVX2", + R"(CHECK: fmul <%d x float>)", 8}, + JitVectorizationTestSpec{HloOpcode::kMultiply, "AVX512", + R"(CHECK: fmul <%d x float>)", 16}, + }); +} + +INSTANTIATE_TEST_SUITE_P(JitVectorizationTestInstantiation, + JitVectorizationTest, + ::testing::ValuesIn(GetJitVectorizationTestCases()), + JitVectorizationTest::Name); + } // namespace } // namespace cpu } // namespace xla diff --git a/xla/xla.proto b/xla/xla.proto index c8900a65753b3..7c45764f7197d 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -107,6 +107,11 @@ message DebugOptions { // value is `256` (AVX2 on x86 platforms). int32 xla_cpu_prefer_vector_width = 308; + // When set, XLA:CPU will only generate code up to the specified ISA. + // (It will not use newer ISAs.) Using the string format allows us to extend + // the flag for more flexible control if necessary. + string xla_cpu_max_isa = 333; + // go/keep-sorted end //--------------------------------------------------------------------------// @@ -989,7 +994,7 @@ message DebugOptions { // loop by a factor of two if a collective op is present. bool xla_gpu_enable_heuristic_pass_configuration = 332; - // Next id: 333 + // Next id: 334 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.