diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt index c9d962d2de23fa..dde561e9dbf4dc 100644 --- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass MLIRFuncToSPIRV MLIRIndexToSPIRV MLIRIR + MLIRMemRefToSPIRV MLIRPass MLIRRewrite MLIRSCFToSPIRV diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp index 4694a147e1e94d..fbf80a8b510dff 100644 --- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" @@ -62,12 +63,24 @@ struct ConvertToSPIRVPass final RewritePatternSet patterns(context); ScfToSPIRVContext scfToSPIRVContext; + // Map MemRef memory space to SPIR-V storage class. + spirv::TargetEnv targetEnv(targetAttr); + bool targetEnvSupportsKernelCapability = + targetEnv.allows(spirv::Capability::Kernel); + spirv::MemorySpaceToStorageClassMap memorySpaceMap = + targetEnvSupportsKernelCapability + ? spirv::mapMemorySpaceToOpenCLStorageClass + : spirv::mapMemorySpaceToVulkanStorageClass; + spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); + spirv::convertMemRefTypesAndAttrs(op, converter); + // Populate patterns for each dialect. arith::populateCeilFloorDivExpandOpsPatterns(patterns); arith::populateArithToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); populateFuncToSPIRVPatterns(typeConverter, patterns); index::populateIndexToSPIRVPatterns(typeConverter, patterns); + populateMemRefToSPIRVPatterns(typeConverter, patterns); populateVectorToSPIRVPatterns(typeConverter, patterns); populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir new file mode 100644 index 00000000000000..5af8bfc842ea13 --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse %s | FileCheck %s + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: @load_store_float_rank_zero +// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr [0])>, StorageBuffer> +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32 +// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : f32 +// CHECK: spirv.Return +func.func @load_store_float_rank_zero(%arg0: memref, %arg1: memref) { + %0 = memref.load %arg0[] : memref + memref.store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: @load_store_int_rank_one +// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr [0])>, StorageBuffer>, %[[ARG2:.*]]: i32 +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32 +// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32 +// CHECK: spirv.Return +func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2 : index) { + %0 = memref.load %arg0[%arg2] : memref<4xi32> + memref.store %0, %arg1[%arg2] : memref<4xi32> + return +} + +// CHECK-LABEL: @load_store_larger_memref +// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr [0])>, StorageBuffer>, %[[ARG2:.*]]: i32 +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32 +// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32 +// CHECK: spirv.Return +func.func @load_store_larger_memref(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2 : index) { + %0 = memref.load %arg0[%arg2] : memref<8xi32> + memref.store %0, %arg1[%arg2] : memref<8xi32> + return +} + + +// CHECK-LABEL: @load_store_vector +// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr, stride=16> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr, stride=16> [0])>, StorageBuffer> +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr, stride=16> [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : vector<4xi32> +// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr, stride=16> [0])>, StorageBuffer>, i32, i32 +// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : vector<4xi32> +// CHECK: spirv.Return +func.func @load_store_vector(%arg0: memref>, %arg1: memref>) { + %0 = memref.load %arg0[] : memref> + memref.store %0, %arg1[] : memref> + return +} + +} // end module diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 84938231140127..6373e53b16c975 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8316,6 +8316,7 @@ cc_library( ":FuncToSPIRV", ":IR", ":IndexToSPIRV", + ":MemRefToSPIRV", ":Pass", ":Rewrite", ":SCFToSPIRV",