Skip to content

Commit

Permalink
Fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
MacDue committed Jul 19, 2024
1 parent 09b0fcd commit 7af1229
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
RewriterBase &rewriter);

/// Structure to hold the range [vscaleMin, vscaleMax] `vector.vscale` can take.
// Structure to hold the range of `vector.vscale`.
struct VscaleRange {
unsigned vscaleMin;
unsigned vscaleMax;
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
Expand Down Expand Up @@ -105,10 +113,14 @@ void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
return;

OpBuilder::InsertionGuard g(rewriter);

// Build worklist so we can safely insert new ops in
// `resolveAllTrueCreateMaskOp()`.
SmallVector<vector::CreateMaskOp> worklist;
function.walk([&](vector::CreateMaskOp createMaskOp) {
worklist.push_back(createMaskOp);
});

rewriter.setInsertionPointToStart(&function.front());
for (auto mask : worklist)
(void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/Vector/eliminate-masks.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
%extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
%output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
// 1. Extract a slice.
%extracted_slice_1 = tensor.extract_slice %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
%extracted_slice_1 = tensor.extract_slice %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>

// 2. Create a mask for the slice.
%dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor<?xf32>
Expand All @@ -30,7 +30,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
%write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>

// 5. Insert and yield the new tensor value.
%result = tensor.insert_slice %write into %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
%result = tensor.insert_slice %write into %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
scf.yield %result : tensor<1x?xf32>
}
"test.some_use"(%output_tensor) : (tensor<1x?xf32>) -> ()
Expand All @@ -57,8 +57,8 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
%mask = vector.create_mask %dim : vector<[4]xi1>
"test.some_use"(%mask) : (vector<[4]xi1>) -> ()
// !!! Here the size of the mask could shrink in the next iteration.
%next_num_els = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
%new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_els] [1] : tensor<1000xf32> to tensor<?xf32>
%next_num_elts = affine.min affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
%new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_elts] [1] : tensor<1000xf32> to tensor<?xf32>
scf.yield %new_extracted_slice : tensor<?xf32>
}
"test.some_use"(%slice) : (tensor<?xf32>) -> ()
Expand Down Expand Up @@ -110,8 +110,8 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
%c4 = arith.constant 4 : index
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
// This is _very_ simple but since addi is not a constant value bounds will
// be used to resolve it.
// This is _very_ simple but since tensor.dim is not a constant value bounds
// will be used to resolve it.
%dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
%mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
"test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
Expand Down
9 changes: 2 additions & 7 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,15 +885,10 @@ struct TestEliminateVectorMasks
: PassWrapper(pass) {}

Option<unsigned> vscaleMin{
*this, "vscale-min",
llvm::cl::desc(
"Minimum value `vector.vscale` can possibly be at runtime."),
*this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
llvm::cl::init(1)};

Option<unsigned> vscaleMax{
*this, "vscale-max",
llvm::cl::desc(
"Maximum value `vector.vscale` can possibly be at runtime."),
*this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
llvm::cl::init(16)};

StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
Expand Down

0 comments on commit 7af1229

Please sign in to comment.