diff --git a/oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll b/oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll
index 0eb574e0f99..6ba8996ac25 100644
--- a/oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll
+++ b/oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll
@@ -1,12 +1,14 @@
 #include "OneFlow/OneFlowOps.td"
 
 Constraint IsFuncArguments(value: Value) [{
-  return success(llvm::dyn_cast<mlir::BlockArgument>(value) != nullptr);
+  return success(llvm::dyn_cast<mlir::BlockArgument>(value));
 }];
 
 Pattern {
+  arg: Value;
   let alloc = op<memref.alloc>();
-  let copy = op<memref.copy>(alloc.0, arg: IsFuncArguments);
+  let copy = op<memref.copy>(alloc.0, arg);
+  IsFuncArguments(arg);
 
   rewrite alloc with {
     erase copy;
diff --git a/oneflow/ir/lib/OneFlow/Passes.cpp b/oneflow/ir/lib/OneFlow/Passes.cpp
index ba067c0894f..535c18aeabf 100644
--- a/oneflow/ir/lib/OneFlow/Passes.cpp
+++ b/oneflow/ir/lib/OneFlow/Passes.cpp
@@ -81,6 +81,7 @@ limitations under the License.
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -121,7 +122,6 @@ limitations under the License.
 #include 
 
 #ifdef WITH_MLIR_CUDA_CODEGEN
-#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -1016,21 +1016,19 @@ struct KernelLaunchWithCudaGraphPattern : public KernelLaunchSimplePattern {
   }
 };
 
-void AddLowerToLinalgMemRefPasses(PassManager& pm) {
-  pm.addPass(createConvertToSignlessForTosaPass());  // convert-to-signless-for-tosa
-  pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());  // llvm-request-c-wrappers
-  pm.addPass(createLowerOneFlowToTosaPass());  // lower-oneflow-to-tosa
-  pm.addNestedPass<func::FuncOp>(
-      tosa::createTosaMakeBroadcastablePass());  // tosa-make-broadcastable
-  pm.addPass(createCSEPass());  // cse
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());  // tosa-to-linalg-on-tensors
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());  // tosa-to-tensor
-  pm.addNestedPass<func::FuncOp>(
-      createLinalgElementwiseOpFusionPass());  // linalg-fuse-elementwise-ops
+void AddLoweringToLinalgMemRefPasses(PassManager& pm) {
+  pm.addPass(createConvertToSignlessForTosaPass());
+  pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());
+  pm.addPass(createLowerOneFlowToTosaPass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
+  pm.addPass(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
+  pm.addNestedPass<func::FuncOp>(createLinalgElementwiseOpFusionPass());
   // TODO: more optimization pass
   // Note: OneShot bufferization with result extract realization.
-  pm.addPass(bufferization::createEmptyTensorEliminationPass());  // eliminate-empty-tensors
-  pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());  // empty-tensor-to-alloc-tensor
+  pm.addPass(bufferization::createEmptyTensorEliminationPass());
+  pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
 
   auto oneshot_bufferize = bufferization::createOneShotBufferizePass();
   CHECK(
@@ -1038,57 +1036,75 @@ void AddLowerToLinalgMemRefPasses(PassManager& pm) {
       ->initializeOptions("create-deallocs=0 bufferize-function-boundaries allow-return-allocs")
      .succeeded());
   pm.addPass(std::move(oneshot_bufferize));
-  pm.addPass(bufferization::createBufferResultsToOutParamsPass());  // buffer-results-to-out-params
-  pm.addPass(mlir::oneflow::createEliminateAllocOpsPass());  // eliminate-alloc-ops
-  pm.addPass(createCanonicalizerPass());  // canonicalize
+  pm.addPass(bufferization::createBufferResultsToOutParamsPass());
+  pm.addPass(mlir::oneflow::createEliminateAllocOpsPass());
+  pm.addPass(createCanonicalizerPass());
 }
 
 LogicalResult LowerModuleToLLVM(mlir::MLIRContext* context, ModuleOp module) {
   mlir::PassManager pm(context);
   mlir::oneflow::CheckEnableIRPrinting(pm);
-  AddLowerToLinalgMemRefPasses(pm);
-  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());  // convert-linalg-to-loops
-  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());  // convert-scf-to-cf
-  pm.addPass(createConvertLinalgToLLVMPass());  // convert-linalg-to-llvm
-  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());  // fold-alloc-to-subview
-  pm.addPass(createInsertOneFlowMemPoolPass());  // insert-ofmempool
-  pm.addPass(createAppendOneFlowStreamPass());  // append-ofstream
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());  // convert-memref-to-llvm
-  pm.addPass(createConvertFuncToLLVMPass());  // convert-func-to-llvm
-  pm.addPass(memref::createExpandStridedMetadataPass());  // expand-strided-metadata
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());  // convert-memref-to-llvm
-  pm.addPass(createReconcileUnrealizedCastsPass());  // reconcile-unrealized-casts
+  AddLoweringToLinalgMemRefPasses(pm);
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
+  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());
+  pm.addPass(createInsertOneFlowMemPoolPass());
+  pm.addPass(createAppendOneFlowStreamPass());
+  pm.addPass(memref::createExpandOpsPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createConvertLinalgToLLVMPass());
+  pm.addPass(createConvertFuncToLLVMPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
   return pm.run(module);
 }
 
 #ifdef WITH_MLIR_CUDA_CODEGEN
 
+void AddLoweringLinalgOnBufferToGpuWithStdPasses(PassManager& pm) {
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToParallelLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());
+  pm.addNestedPass<func::FuncOp>(createGpuLauchSinkIndexComputationsPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());
+  pm.addPass(createInsertOneFlowMemPoolPass());
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createGpuCopyArgPass());
+}
+
+void AddAdheringCubinToGpuModulePasses(PassManager& pm) {
+  pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToNVVMOpsPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createNVVMToCubinPass());
+}
+
+void AddLoweringGpuToLLVMPasses(PassManager& pm) {
+  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createAppendOneFlowStreamPass());
+  pm.addPass(createGpuToLLVMConversionPass());
+  pm.addPass(createMgpuToOneFlowStreamPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+}
+
 LogicalResult LowerModuleToCUDALLVM(mlir::MLIRContext* context, ModuleOp module) {
   InitializeLLVMNVPTXBackend();
   mlir::PassManager pm(context);
   mlir::oneflow::CheckEnableIRPrinting(pm);
-  AddLowerToLinalgMemRefPasses(pm);
-  pm.addNestedPass<func::FuncOp>(
-      createConvertLinalgToParallelLoopsPass());  // convert-linalg-to-parallel-loops
-  pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());  // gpu-map-parallel-loops
-  pm.addPass(createParallelLoopToGpuPass());  // convert-parallel-loops-to-gpu
-  pm.addPass(createGpuLauchSinkIndexComputationsPass());
-  pm.addPass(createGpuKernelOutliningPass());  // gpu-kernel-outlining
-  pm.addPass(createCanonicalizerPass());  // canonicalize
-  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());  // fold-alloc-to-subview
-  pm.addPass(createInsertOneFlowMemPoolPass());  // insert-ofmempool
-  // -pass-pipeline='gpu.module([PASS1][PASS2]...)'
-  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());  // strip-debuginfo
-  pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());  // lower-affine
-  pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToNVVMOpsPass());  // convert-gpu-to-nvvm
-  pm.addNestedPass<gpu::GPUModuleOp>(createNVVMToCubinPass());  // out-of-tree-gpu-to-cubin
-  pm.addNestedPass<func::FuncOp>(createGpuCopyArgPass());  // buffer-host-register
-  pm.addPass(createAppendOneFlowStreamPass());  // append-ofstream
-  pm.addPass(createGpuToLLVMConversionPass());  // gpu-to-llvm
-  pm.addPass(createMgpuToOneFlowStreamPass());  // gpu-to-llvm
-  pm.addPass(memref::createExpandStridedMetadataPass());  // expand-strided-metadata
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());  // convert-memref-to-llvm
-  pm.addPass(createReconcileUnrealizedCastsPass());  // reconcile-unrealized-casts
+  AddLoweringToLinalgMemRefPasses(pm);
+  AddLoweringLinalgOnBufferToGpuWithStdPasses(pm);
+  pm.addPass(memref::createExpandOpsPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  pm.addPass(createGpuKernelOutliningPass());
+  AddAdheringCubinToGpuModulePasses(pm);
+  AddLoweringGpuToLLVMPasses(pm);
 
   return pm.run(module);
 }
diff --git a/oneflow/ir/test/OneFlow/cuda_code_gen/test_matmul.py b/oneflow/ir/test/OneFlow/cuda_code_gen/test_matmul.py
new file mode 100644
index 00000000000..7ae41688c2a
--- /dev/null
+++ b/oneflow/ir/test/OneFlow/cuda_code_gen/test_matmul.py
@@ -0,0 +1,77 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s
+# CHECK: jit
+
+import unittest
+import numpy as np
+
+import os
+
+
+import oneflow as flow
+import oneflow.unittest
+
+
+class MatMulModule(flow.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.w = flow.nn.Parameter(flow.Tensor(5, 10))
+        self.b = flow.nn.Parameter(flow.Tensor(10))
+
+    def forward(self, x):
+        return flow.matmul(x, self.w) + self.b
+
+
+def do_matmul_graph(test_case, with_cuda=False):
+    x = flow.randn(2, 5)
+    module_to_run = MatMulModule()
+    if with_cuda:
+        x = x.cuda()
+        module_to_run = module_to_run.to("cuda")
+    y_eager = module_to_run(x)
+
+    class GraphToRun(flow.nn.Graph):
+        def __init__(self):
+            super().__init__()
+            self.fw = module_to_run
+
+        def build(self, x):
+            return self.fw(x)
+
+    graph_to_run = GraphToRun()
+    y_lazy = graph_to_run(x)
+    test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestFuseCastScale(oneflow.unittest.MLIRTestCase):
+    def setUp(self):
+        os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
+        os.environ["ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS"] = "1"
+        os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
+
+    def test_relu_graph(test_case):
+        import oneflow.sysconfig
+
+        if oneflow.sysconfig.with_cuda():
+            do_matmul_graph(test_case, True)
+
+        do_matmul_graph(test_case)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/ir/test/Transform/lit.local.cfg b/oneflow/ir/test/Transform/lit.local.cfg
new file mode 100644
index 00000000000..27bf52421be
--- /dev/null
+++ b/oneflow/ir/test/Transform/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.WITH_MLIR_CUDA_CODEGEN:
+    config.unsupported = True
diff --git a/oneflow/ir/test/Transform/matmul.mlir b/oneflow/ir/test/Transform/matmul.mlir
new file mode 100644
index 00000000000..e2dacd1640a
--- /dev/null
+++ b/oneflow/ir/test/Transform/matmul.mlir
@@ -0,0 +1,36 @@
+// RUN: oneflow-opt %s --insert-ofmempool --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand --gpu-kernel-outlining \
+// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'
+
+module {
+  func.func @JITOpGenerated0(%arg0: memref<5x10xf32, strided<[?, ?], offset: ?>>, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>>, %arg2: memref<2x10xf32>) attributes {llvm.emit_c_interface} {
+    %alloc = memref.alloc() : memref<512xi8>
+    %c0 = arith.constant 0 : index
+    %view = memref.view %alloc[%c0][] : memref<512xi8> to memref<1x2x10xf32>
+    %c10 = arith.constant 10 : index
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c0_0 = arith.constant 0 : index
+    %c5 = arith.constant 5 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] : memref<5x10xf32, strided<[?, ?], offset: ?>> into memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>
+    %expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2]] : memref<2x5xf32, strided<[?, ?], offset: ?>> into memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>
+    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+      memref.store %cst, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>
+      gpu.terminator
+    } {SCFToGPU_visited}
+    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+      scf.for %arg15 = %c0_0 to %c5 step %c1 {
+        %0 = memref.load %expand_shape_1[%c0_0, %arg4, %arg15] : memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>
+        %1 = memref.load %expand_shape[%c0_0, %arg15, %arg5] : memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>
+        %2 = memref.load %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>
+        %3 = arith.mulf %0, %1 : f32
+        %4 = arith.addf %2, %3 : f32
+        memref.store %4, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>
+      }
+      gpu.terminator
+    } {SCFToGPU_visited}
+    %collapse_shape = memref.collapse_shape %view [[0, 1], [2]] : memref<1x2x10xf32> into memref<2x10xf32>
+    memref.copy %collapse_shape, %arg2 : memref<2x10xf32> to memref<2x10xf32>
+    return
+  }
+}
\ No newline at end of file
diff --git a/oneflow/ir/test/Transform/softmax.mlir b/oneflow/ir/test/Transform/softmax.mlir
index 5fddcc5f87e..8c1b24ed6ae 100644
--- a/oneflow/ir/test/Transform/softmax.mlir
+++ b/oneflow/ir/test/Transform/softmax.mlir
@@ -1,13 +1,12 @@
-// RUN: oneflow-opt %s --pass-pipeline="builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax_codegen_spec.mlir})" \
-// RUN: | oneflow-opt --convert-vector-to-gpu=use-nvgpu=1 --convert-vector-to-scf --convert-scf-to-cf --insert-ofmempool --gpu-kernel-outlining \
-// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm))'
+// RUN: oneflow-opt %s --pass-pipeline="builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax_codegen_spec_no_vectorize.mlir})" \
+// RUN: | oneflow-opt --insert-ofmempool --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand --gpu-kernel-outlining \
+// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'
 
 !tmp_tensor_t = tensor<16x128xf32>
 !in_tensor_t = tensor<16x128x128xf32>
 !out_tensor_t = tensor<16x128x128xf32>
 
-
 func.func @softmax() -> !out_tensor_t {
   %cst_0 = arith.constant 0.0 : f32
   %cst_1 = arith.constant 1.0 : f32
@@ -64,4 +63,4 @@ func.func @softmax() -> !out_tensor_t {
   } -> !out_tensor_t
 
   return %res: !out_tensor_t
-}
\ No newline at end of file
+}
diff --git a/oneflow/ir/test/Transform/softmax_codegen_spec_no_vectorize.mlir b/oneflow/ir/test/Transform/softmax_codegen_spec_no_vectorize.mlir
new file mode 100644
index 00000000000..dabb7d30e66
--- /dev/null
+++ b/oneflow/ir/test/Transform/softmax_codegen_spec_no_vectorize.mlir
@@ -0,0 +1,82 @@
+// RUN: oneflow-opt %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  // Note: step 1, tiling and fusing linalg ops in block level.
+  %ops = transform.structured.match ops{["linalg.fill", "linalg.generic"]}
+    in %module_op : (!pdl.operation) -> !pdl.operation
+
+  %match_0, %match_1, %match_2, %match_3, %match_end = transform.split_handle %ops
+    : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
+                           !pdl.operation, !pdl.operation)
+
+  %forall, %_ =
+    transform.structured.tile_to_forall_op %match_end tile_sizes [1, 4]
+      ( mapping = [#gpu.block<x>, #gpu.block<y>] )
+
+  transform.structured.fuse_into_containing_op %match_3 into %forall
+  transform.structured.fuse_into_containing_op %match_2 into %forall
+  transform.structured.fuse_into_containing_op %match_1 into %forall
+  transform.structured.fuse_into_containing_op %match_0 into %forall
+
+  transform.oneflow.canonicalization %module_op : (!pdl.operation) -> ()
+  transform.oneflow.cse %module_op : (!pdl.operation) -> ()
+
+
+  // Note: step 2, tiling and fusing linalg ops in thread level.
+  %ops_1 = transform.structured.match ops{["linalg.fill", "linalg.generic"]}
+    in %module_op : (!pdl.operation) -> !pdl.operation
+  %match_0_0,
+  %match_0_1,
+  %match_0_2,
+  %match_0_3,
+  %match_0_end = transform.split_handle %ops_1
+    : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation,
+                           !pdl.operation, !pdl.operation)
+
+  %reduction_linalg_ops = transform.merge_handles %match_0_1,
+                                                  %match_0_3
+    : !pdl.operation
+  transform.structured.tile_to_forall_op %reduction_linalg_ops tile_sizes [1, 1]
+    ( mapping = [#gpu.thread<z>, #gpu.thread<y>] )
+
+  %parallel_linalg_ops = transform.merge_handles %match_0_0,
+                                                 %match_0_2,
+                                                 %match_0_end
+    : !pdl.operation
+  transform.structured.tile_to_forall_op %parallel_linalg_ops num_threads [1, 4, 32]
+    ( mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] )
+  transform.oneflow.canonicalization %module_op : (!pdl.operation) -> ()
+  transform.oneflow.cse %module_op : (!pdl.operation) -> ()
+
+  // Note: step 3, bufferize
+  transform.oneflow.explicit_linalg_outcome %module_op : (!pdl.operation) -> ()
+
+  transform.bufferization.eliminate_empty_tensors %module_op
+
+  %empty = transform.structured.match ops{["tensor.empty"]} in %module_op : (!pdl.operation) -> !pdl.operation
+  %empty_id = transform.cast %empty : !pdl.operation to !transform.op<"tensor.empty">
+  transform.bufferization.empty_tensor_to_alloc_tensor %empty_id : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+
+  %bufferized_module_op = transform.bufferization.one_shot_bufferize %module_op
+    {create_deallocs = false, bufferize_function_boundaries = true, allow_return_allocs = true} : (!pdl.operation) -> !pdl.operation
+
+  // Note: step 4, post bufferize function-type-related transform
+  transform.oneflow.canonicalization %bufferized_module_op : (!pdl.operation) -> ()
+  transform.oneflow.cse %bufferized_module_op : (!pdl.operation) -> ()
+  transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> ()
+
+  %func = transform.structured.match ops{["func.func"]} in %bufferized_module_op : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_tensor_subsets %func
+    : (!pdl.operation) -> ()
+
+  // Note: step 5, post bufferize memory-buffer-pool transform
+  transform.oneflow.results_to_out_params %bufferized_module_op : (!pdl.operation) -> ()
+  transform.oneflow.eliminate_copy %bufferized_module_op : (!pdl.operation) -> ()
+  transform.oneflow.fold_alloc %func : (!pdl.operation) -> ()
+
+  // Note: step 6, mapping scf to gpu
+  %gpu_launch_op = transform.gpu.map_forall_to_blocks %bufferized_module_op { generate_gpu_launch }
+  transform.gpu.map_nested_forall_to_threads %gpu_launch_op block_dims = [32, 4, 1]
+}
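Illustrative note (not part of the patch): a minimal sketch of the kind of IR the updated `AllocEliminationPatterns.pdll` pattern targets after `buffer-results-to-out-params`, i.e. a `memref.alloc` whose result is copied into a function (out-parameter) block argument. The function name and shapes below are hypothetical.

```mlir
// Hypothetical IR before eliminate-alloc-ops: %tmp is a local buffer that is
// only copied into the out-parameter %out. The PDLL pattern matches the
// memref.copy whose destination satisfies IsFuncArguments and erases it,
// so the temporary allocation can presumably be folded into %out.
func.func @example_kernel(%in: memref<2x5xf32>, %out: memref<2x10xf32>) {
  %tmp = memref.alloc() : memref<2x10xf32>
  // ... computation that fills %tmp ...
  memref.copy %tmp, %out : memref<2x10xf32> to memref<2x10xf32>
  return
}
```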