Skip to content

Commit

Permalink
Allow vectorization in DynamicUpdateSlice in-place emitter.
Browse files Browse the repository at this point in the history
We can use the same conditions for when to allow vectorization as for loop
fusion.

PiperOrigin-RevId: 678591064
  • Loading branch information
akuegel authored and Google-ML-Automation committed Sep 25, 2024
1 parent e0b9573 commit 46c95ae
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 4 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
Expand Down
5 changes: 3 additions & 2 deletions xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ LaunchDimensions MlirInPlaceDynamicUpdateSliceFusion::launch_dimensions()
const {
const auto& update_shape =
dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
return CalculateLaunchDimensions(update_shape, analysis_.device_info());
return CalculateLaunchDimensions(update_shape, analysis_.device_info(),
config_);
}

std::optional<IndexingMap>
Expand All @@ -84,7 +85,7 @@ MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing(
// It is guaranteed that all DUS ops have the same output shape at this point.
const auto& update_shape =
dus_ops_.front().GetOperand(kDUSUpdateIndex).shape();
return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1,
return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor,
update_shape, indexing_context);
}

Expand Down
7 changes: 5 additions & 2 deletions xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/fusions/mlir/computation_partitioner.h"
#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h"
#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
Expand All @@ -47,8 +48,9 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase {
explicit MlirInPlaceDynamicUpdateSliceFusion(
const HloFusionAnalysis& analysis)
: analysis_(analysis),
dus_ops_(
GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {}
dus_ops_(GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())),
config_(ComputeLoopFusionConfig(
analysis, dus_ops_[0].instruction().operand(1)->shape())) {}

LaunchDimensions launch_dimensions() const override;

Expand Down Expand Up @@ -77,6 +79,7 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase {
private:
const HloFusionAnalysis& analysis_;
std::vector<HloInstructionAdaptor> dus_ops_;
LaunchDimensionsConfig config_;
};

} // namespace gpu
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \
// RUN: -xla-gpu-test-transform-loops | FileCheck %s
// RUN: test_correctness %s --bijection_inputs=dus:1

// Negative vectorization case for a dynamic-update-slice fusion: the update
// operand is only f32[1,1,40]. After the optimize and loop-transform test
// passes, the emitted IR must not contain any 4-element f32 vector transfer
// read or write (see the FileCheck directives at the end of this file).
// NOTE(review): presumably the update is too small for the loop-fusion
// launch config to pick an unroll factor above 1 -- confirm against
// ComputeLoopFusionConfig.
// NOTE(review): the correctness run presumably verifies that the indexing of
// operand 1 (the update) of `dus` is a bijection -- confirm against the
// test_correctness tool.

dus {
%input = f32[40,40,300] parameter(0)
%update = f32[1,1,40] parameter(1)
%idx = s32[] parameter(2)
%zero = s32[] constant(0)
// The start index is dynamic (%idx) for dimension 0 and constant zero for
// dimensions 1 and 2.
ROOT dus = f32[40,40,300] dynamic-update-slice(%input, %update, %idx, %zero, %zero)
}

// CHECK-NOT: vector.transfer_read {{.*}} vector<4xf32>
// CHECK-NOT: vector.transfer_write {{.*}} vector<4xf32>
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \
// RUN: -xla-gpu-test-transform-loops | FileCheck %s
// RUN: test_correctness %s --bijection_inputs=dus:1

// Positive vectorization case for a dynamic-update-slice fusion: the update
// operand is f32[20,40,300]. After the optimize and loop-transform test
// passes, the emitted IR must contain a 4-element f32 vector transfer read
// followed by a 4-element f32 vector transfer write (see the FileCheck
// directives at the end of this file).
// NOTE(review): presumably this update is large enough for the loop-fusion
// launch config to pick an unroll factor of 4 -- confirm against
// ComputeLoopFusionConfig.
// NOTE(review): the correctness run presumably verifies that the indexing of
// operand 1 (the update) of `dus` is a bijection -- confirm against the
// test_correctness tool.

dus {
%input = f32[40,40,300] parameter(0)
%update = f32[20,40,300] parameter(1)
%idx = s32[] parameter(2)
%zero = s32[] constant(0)
// The start index is dynamic (%idx) for dimension 0 and constant zero for
// dimensions 1 and 2.
ROOT dus = f32[40,40,300] dynamic-update-slice(%input, %update, %idx, %zero, %zero)
}

// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: vector.transfer_write {{.*}} vector<4xf32>

0 comments on commit 46c95ae

Please sign in to comment.