
[mlir][ArmSME] Add masking support to memory ops #69148

Closed


c-rhodes
Collaborator

This patch series adds masking support to the ArmSME memory ops, as well as lowerings from vector.transfer_read and vector.transfer_write. The transfer_read lowering to ArmSME is more complete than the transfer_write lowering: for the latter, the VectorToArmSME lowering still needs fleshing out to support in-flight transpose via vertical store, and integration tests need to be added.

This support is part of a wider effort to lower linalg.matmul to SME. There are a lot of changes here, so I don't expect this to be reviewed as a whole, which is why I've created it as a draft PR. I plan to create separate PRs for each commit.

@c-rhodes
Collaborator Author

cc @banach-space @MacDue

c-rhodes added a commit that referenced this pull request Oct 16, 2023
This patch prefixes tile slice layout with `layout` in the
assemblyFormat:

  - `<vertical>`   -> `layout<vertical>`
  - `<horizontal>` -> `layout<horizontal>`

The reason for this change is the current format doesn't play nicely
with additional optional operands, required to support padding and
masking (#69148), as it becomes ambiguous.

This affects the following ops:

  - arm_sme.tile_load
  - arm_sme.tile_store
  - arm_sme.load_tile_slice
  - arm_sme.store_tile_slice

Padding and mask are optional, but if one is specified both must be
specified. This is consistent with vector.transfer_read.

This extends the lowering of vector.transfer_read to arm_sme.tile_load
to propagate pad and mask.

The restriction on the transfer_read being a transposition is also
removed; identity maps are lowered to normal horizontal loads.

This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.

There are two lowerings, one for pad of constant zero and another for
non-zero pad. For the following example:

  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                          vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:
-----------------------------------------------------

  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1 {
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx, %c0], %num_cols, %tile, %slice_idx :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
  }

The tile is zeroed to satisfy the padding and only active rows are
loaded.
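As a rough illustration of the zero-pad case, the lowering can be modelled in plain Python. This is only a sketch of the semantics described above, not the ArmSMEToSCF implementation: the function name `masked_tile_load_zero_pad` and the fixed `tile_dim` of 4 (standing in for the scalable `[4]`) are assumptions made for the example.

```python
def masked_tile_load_zero_pad(src, num_rows, num_cols, tile_dim=4):
    """Model of the zero-pad lowering: zero the tile, then load active rows.

    `tile_dim` is a hypothetical stand-in for the scalable tile size.
    """
    # Models arm_sme.zero: the zeroed tile supplies the padding.
    tile = [[0] * tile_dim for _ in range(tile_dim)]
    # Models the 1-D vector.create_mask over the columns of one slice.
    col_mask = [col < num_cols for col in range(tile_dim)]
    # Models the scf.for over active rows; inactive rows are never touched.
    for slice_idx in range(min(num_rows, tile_dim)):
        for col in range(tile_dim):
            if col_mask[col]:
                tile[slice_idx][col] = src[slice_idx][col]
    return tile


if __name__ == "__main__":
    # The source is smaller than the tile; masked-off elements are not read.
    src = [[1, 2, 3], [4, 5, 6]]
    for row in masked_tile_load_zero_pad(src, num_rows=2, num_cols=3):
        print(row)
```

With a 2x3 source, the first two rows hold the loaded values followed by one zero, and the remaining two rows stay all zero.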

The latter (non-zero pad) is lowered as follows:
------------------------------------------------

  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>
  }

The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (which will be lowered to SVE, not SME) loads each
slice for active rows, with the padding specified as a passthru. For
non-active rows the slice is the 1-D pad. The resulting slice is inserted
into the tile with 'arm_sme.move_vector_to_tile_slice'.

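The non-zero pad case can be modelled the same way. The sketch below is a plain-Python approximation of the semantics described above, under the assumption of a fixed `tile_dim` of 4; the function name `masked_tile_load_with_pad` is hypothetical and does not correspond to anything in the patch.

```python
def masked_tile_load_with_pad(src, num_rows, num_cols, pad, tile_dim=4):
    """Model of the non-zero pad lowering, one tile slice per iteration.

    `tile_dim` is a hypothetical stand-in for the scalable tile size.
    """
    # The scalar pad is broadcast to a 1-D vector.
    pad_1d = [pad] * tile_dim
    col_mask = [col < num_cols for col in range(tile_dim)]
    tile = [None] * tile_dim
    # Unlike the zero-pad case, the loop visits every tile slice.
    for slice_idx in range(tile_dim):
        if slice_idx < num_rows:
            # Models vector.maskedload with pad_1d as the passthru value.
            tile_slice = [
                src[slice_idx][col] if col_mask[col] else pad_1d[col]
                for col in range(tile_dim)
            ]
        else:
            # Inactive rows take the 1-D pad unchanged.
            tile_slice = list(pad_1d)
        # Models arm_sme.move_vector_to_tile_slice.
        tile[slice_idx] = tile_slice
    return tile


if __name__ == "__main__":
    src = [[1, 2, 3], [4, 5, 6]]
    for row in masked_tile_load_with_pad(src, num_rows=2, num_cols=3, pad=7):
        print(row)
```

The difference from the zero-pad model is that every slice is written, because the pad value has to be materialised in the inactive rows as well as in the masked-off columns.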
This patch extends ArmSMEToSCF to support lowering of masked tile_store
ops. Only masks created by 'vector.create_mask' are currently supported.

Example:

  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>,
                                                      vector<[4]x[4]xi32>

Produces:

  %num_rows = arith.constant 3 : index
  %num_cols = vector.create_mask %c2 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1 {
    arm_sme.store_tile_slice %tile, %slice_idx, %num_cols, %dest[%slice_idx, %c0]
      : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
  }

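The masked store lowering can be sketched in plain Python as well. This is a model of the semantics shown above, not the actual conversion pattern; `masked_tile_store` is a hypothetical name, and the tile is assumed to be a square list of lists.

```python
def masked_tile_store(tile, dest, num_rows, num_cols):
    """Model of the masked tile_store lowering: store only active elements."""
    tile_dim = len(tile)
    # Models the 1-D vector.create_mask over the columns of one slice.
    col_mask = [col < num_cols for col in range(tile_dim)]
    # Models the scf.for over active rows; other rows of dest are untouched.
    for slice_idx in range(min(num_rows, tile_dim)):
        for col in range(tile_dim):
            if col_mask[col]:
                dest[slice_idx][col] = tile[slice_idx][col]
    return dest


if __name__ == "__main__":
    tile = [[row * 4 + col for col in range(4)] for row in range(4)]
    dest = [[-1] * 4 for _ in range(4)]
    # Mirrors the example mask: 3 active rows, 2 active columns.
    for row in masked_tile_store(tile, dest, num_rows=3, num_cols=2):
        print(row)
```

Elements of `dest` outside the 3x2 active region keep their previous value, which is the property that distinguishes a masked store from a full tile store.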
This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
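To illustrate why a vertical store gives a transpose for free, here is a small Python model. It assumes that a horizontal tile slice is a row of the tile and a vertical tile slice is a column; `store_tile` is a hypothetical helper written for this example, not an op or function from the patch.

```python
def store_tile(tile, layout="horizontal"):
    """Write a square tile to a fresh 2-D buffer, one tile slice per row."""
    n = len(tile)
    dest = [[None] * n for _ in range(n)]
    for slice_idx in range(n):
        if layout == "horizontal":
            # A horizontal slice is row `slice_idx` of the tile.
            tile_slice = list(tile[slice_idx])
        elif layout == "vertical":
            # A vertical slice is column `slice_idx` of the tile.
            tile_slice = [tile[row][slice_idx] for row in range(n)]
        else:
            raise ValueError(f"unknown layout: {layout}")
        # Each slice is stored contiguously to row `slice_idx` of memory.
        dest[slice_idx] = tile_slice
    return dest


if __name__ == "__main__":
    tile = [[1, 2], [3, 4]]
    print(store_tile(tile, "horizontal"))
    print(store_tile(tile, "vertical"))
```

Storing columns of the tile as rows of memory yields the transposed matrix, so no separate transpose step is needed in this model.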
@c-rhodes c-rhodes force-pushed the mlir-arm-sme-transfer-readwrite-masking branch from 5a336ee to 14aac43 Compare October 16, 2023 15:08
MacDue pushed a commit that referenced this pull request Oct 25, 2023

This is used in #69148 when lowering masked tile_load with a non-zero
pad.

This updates:
 * `arm_sme.move_vector_to_tile_slice`
 * `arm_sme.move_tile_slice_to_vector`
@banach-space
Contributor

Makes sense to me, I don't have any specific suggestions. We can dive into details separately for every patch. Thanks!

@c-rhodes
Collaborator Author

c-rhodes commented Nov 3, 2023

PRs

#69195 - [mlir][ArmSME] Add optional padding and mask operands to tile_load
#70655 - [mlir][ArmSME] Add mask operand to load_tile_slice
#70814 - [mlir][ArmSME] Propagate pad and mask in vector.transfer_read lowering
#69186 - [mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops
#70915 - [mlir][ArmSME] Add support for lowering masked tile_load ops
#70657 - [mlir][ArmSME] Add optional mask operand to tile_store
#70838 - [mlir][ArmSME] Add mask operand to store_tile_slice
#71180 - [mlir][ArmSME] Add support for lowering masked tile_store ops
#71181 - [mlir][ArmSME] Lower transfer_write + transpose to vertical store

@c-rhodes
Collaborator Author

All changes have been landed in separate PRs, closing.

@c-rhodes c-rhodes closed this Nov 10, 2023