Skip to content

Commit

Permalink
Use ShardyCallInliner in XLA GPU pipeline.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679665714
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Sep 27, 2024
1 parent 22004eb commit 0131cbf
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,7 @@ cc_library(
"//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
"//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd/shardy:shardy_call_inliner",
"//xla/service/spmd:collective_permute_motion",
"//xla/service:algebraic_simplifier",
"//xla/service:all_gather_broadcast_reorder",
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ limitations under the License.
#include "xla/service/slice_sinker.h"
#include "xla/service/slow_operation_alarm.h"
#include "xla/service/sort_simplifier.h"
#include "xla/service/spmd/shardy/shardy_call_inliner.h"
#include "xla/service/stable_sort_expander.h"
#include "xla/service/stochastic_convert_decomposer.h"
#include "xla/service/sub_byte_normalization.h"
Expand Down Expand Up @@ -551,7 +552,7 @@ absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) {
// passes.
pre_spmd_pipeline.AddPass<CuDnnCustomCallConverter>();
pre_spmd_pipeline.AddPass<ConvertMemoryPlacementToInternalAnnotations>();
pre_spmd_pipeline.AddPass<CallInliner>();
pre_spmd_pipeline.AddPass<ShardyCallInliner>();
pre_spmd_pipeline.AddPass<ZeroSizedHloElimination>();
pre_spmd_pipeline.AddPass<ConditionalCanonicalizer>();

Expand Down Expand Up @@ -708,7 +709,7 @@ absl::Status RunOptimizationPasses(
pipeline.AddPass<DynamicIndexSplitter>();

// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();
pipeline.AddPass<ShardyCallInliner>();

pipeline.AddPass<StochasticConvertDecomposer>();

Expand Down Expand Up @@ -1585,7 +1586,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
options.key_value_store,
gpu_target_config.device_description.runtime_version()));
// Inline back the calls which have better performance with cuBLAS.
pipeline.AddPass<CallInliner>();
pipeline.AddPass<ShardyCallInliner>();
// TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated
// here for possibly better cuBLAS performance.
AddGemmRewriterPasses(pipeline, debug_options, gpu_version,
Expand Down
4 changes: 4 additions & 0 deletions xla/service/spmd/shardy/shardy_call_inliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ namespace xla {
// sure we inline all functions except for the `shmap_body` functions when
// using Shardy. When Shardy is disabled, this pass behaves the same as
// CallInliner.
//
// TODO(bartchr): Move the logic in here into the regular XLA `CallInliner`.
// Shardy is now proven out so we should have the parent `CallInliner` handle
// this.
class ShardyCallInliner : public CallInliner {
public:
using CallInliner::CallInliner;
Expand Down

0 comments on commit 0131cbf

Please sign in to comment.