From e7b1522ff14a02d9e1c085b2e09ac28e5411a734 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Tue, 31 Oct 2023 11:28:21 +0000 Subject: [PATCH] address comments --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 4 ++-- mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 1 + mlir/test/Dialect/ArmSME/invalid.mlir | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 655fb0f55e1866..37a2257a0015ce 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -392,7 +392,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> { def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ AllTypesMatch<["tile", "result"]>, TypesMatchWith< - "mask has i1 element type and same shape as result", + "mask has i1 element type and is a slice of the result", "result", "mask", "VectorType(" "VectorType::Builder(" @@ -434,7 +434,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ ``` }]; let arguments = (ins - Arg:$base, AnyVector:$mask, + Arg:$base, SVEPredicate:$mask, SMETile:$tile, Variadic:$indices, Index:$tile_slice_index, ArmSME_TileSliceLayoutAttr:$layout ); diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 9cfb13216d9bfe..50cc818f1ffc09 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -80,6 +80,7 @@ struct TileLoadOpConversion : public OpRewritePattern { LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, PatternRewriter &rewriter) const override { if (tileLoadOp.getMask()) + // TODO: add masked patterns. return rewriter.notifyMatchFailure( tileLoadOp, "op has mask, needs masked pattern(s)"); diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 04daf2d94b6f98..1d6386bbf3828f 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -159,7 +159,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64 func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { %c0 = arith.constant 0 : index - // expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as result}} + // expected-error@+1 {{op failed to verify that mask has i1 element type and is a slice of the result}} %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[2]xi1>, vector<[16]x[16]xi8> return }