Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
c-rhodes committed Oct 31, 2023
1 parent 69e1891 commit e7b1522
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"mask has i1 element type and same shape as result",
"mask has i1 element type and is a slice of the result",
"result", "mask",
"VectorType("
"VectorType::Builder("
Expand Down Expand Up @@ -434,7 +434,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from">:$base, AnyVector:$mask,
Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
if (tileLoadOp.getMask())
// TODO: add masked patterns.
return rewriter.notifyMatchFailure(
tileLoadOp, "op has mask, needs masked pattern(s)");

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/ArmSME/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64

func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as result}}
// expected-error@+1 {{op failed to verify that mask has i1 element type and is a slice of the result}}
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
return
}
Expand Down

0 comments on commit e7b1522

Please sign in to comment.