[mlir][ArmSME] Add masking support to memory ops #69148
Conversation
This patch prefixes the tile slice layout with `layout` in the assemblyFormat:

- `<vertical>` -> `layout<vertical>`
- `<horizontal>` -> `layout<horizontal>`

The reason for this change is that the current format doesn't play nicely with the additional optional operands required to support padding and masking (#69148), as it becomes ambiguous. This affects the following ops:

- arm_sme.tile_load
- arm_sme.tile_store
- arm_sme.load_tile_slice
- arm_sme.store_tile_slice
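For illustration, a sketch of how a vertical tile load looks before and after this change (operand names are illustrative):

    // Old format, ambiguous once optional pad/mask operands are added:
    //   %tile = arm_sme.tile_load %src[%c0, %c0] <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>

    // New format, with the layout keyword prefix:
    %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>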
Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.
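As a rough sketch (assuming `%src`, `%pad` and the index constants are defined earlier), the unmasked and masked forms of arm_sme.tile_load would look like:

    // Unmasked: neither pad nor mask is given.
    %tile_a = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

    // Masked: pad and mask must be specified together.
    %mask = vector.create_mask %c2, %c4 : vector<[4]x[4]xi1>
    %tile_b = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>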
This extends the lowering of vector.transfer_read -> arm_sme.tile_load lowering to propagate pad and mask. The restriction on the transfer_read being a transposition is also removed, identity maps are lowered to normal horizontal loads.
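A sketch of this lowering, assuming an identity permutation map and in-bounds accesses (names are illustrative):

    // Before: masked transfer_read with a pad value.
    %mask = vector.create_mask %c2, %c4 : vector<[4]x[4]xi1>
    %tile = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]}
      : memref<?x?xi32>, vector<[4]x[4]xi32>

    // After VectorToArmSME: pad and mask are propagated to the tile load (horizontal layout).
    %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>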
This patch extends ArmSMEToSCF to support lowering of masked tile_load ops. Only masks created by 'vector.create_mask' are currently supported. There are two lowerings, one for a pad of constant zero and another for non-zero pad. For the following example:

    %pad = arith.constant 0 : i32
    %num_rows = arith.constant 2 : index
    %num_cols = arith.constant 4 : index
    %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
    %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:

    %tile = arm_sme.zero : vector<[4]x[4]xi32>
    %num_cols = vector.create_mask %c4 : vector<[4]xi1>
    scf.for %slice_idx = %c0 to %num_rows step %c1
      %tile_update = arm_sme.load_tile_slice %src[%slice_idx, %c0], %num_cols, %tile, %slice_idx
        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>

The tile is zeroed to satisfy the padding and only active rows are loaded.

The latter (non-zero pad) is lowered as follows:

    scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
      %row_is_active = arith.cmpi ult, %slice_idx, %num_rows : index
      %slice = scf.if %row_is_active -> vector<[4]xi32> {
        %active_slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d
          : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
        scf.yield %active_slice : vector<[4]xi32>
      } else {
        scf.yield %pad_1d : vector<[4]xi32>
      }
      arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
        : vector<[4]xi32> into vector<[4]x[4]xi32>
    }

The scalar pad is broadcast to a 1-D vector and a regular 'vector.maskedload' (which will be lowered to SVE, not SME) loads each slice for active rows, with the padding specified as a passthru. For non-active rows the slice is the 1-D pad. The resulting slice is inserted into the tile with 'arm_sme.move_vector_to_tile_slice'.
This patch extends ArmSMEToSCF to support lowering of masked tile_store ops. Only masks created by 'vector.create_mask' are currently supported. Example:

    %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
    arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>

Produces:

    %num_rows = arith.constant 3 : index
    %num_cols = vector.create_mask %c2 : vector<[4]xi1>
    scf.for %slice_idx = %c0 to %num_rows step %c1
      arm_sme.store_tile_slice %tile, %slice_idx, %num_cols, %dest[%slice_idx, %c0]
        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
This patch extends the lowering of vector.transfer_write in VectorToArmSME to support in-flight transpose via SME vertical store.
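A sketch of what this looks like, assuming a transposing permutation map and in-bounds accesses (names are illustrative):

    // Before: transfer_write with a transposition.
    vector.transfer_write %tile, %dest[%c0, %c0]
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]}
      : vector<[4]x[4]xi32>, memref<?x?xi32>

    // After VectorToArmSME: the transpose happens in-flight via a vertical tile store.
    arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>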
Makes sense to me, I don't have any specific suggestions. We can dive into details separately for every patch. Thanks!
All changes have been landed in separate PRs, closing.
This patch series adds masking support to the ArmSME memory ops, as well as lowerings from `vector.transfer_read` and `vector.transfer_write`. The `transfer_read` lowering to ArmSME is more complete than the `transfer_write` one; for the latter, the VectorToArmSME lowering still needs fleshing out to support in-flight transpose via vertical store, and integration tests need to be added.

This support is part of a wider effort to lower linalg.matmul to SME. There are a lot of changes here, so I don't expect this to be reviewed as a whole, which is why I've created this as a draft PR. I plan to create separate PRs for each commit.