[CodeGen]Support matmul codegen and runtime in mlir jit (#10283)
This PR adds support for running the matmul op in the MLIR JIT; it also fixes part of the robustness issues in the codegen stages that follow the gpu dialect.
howin98 authored May 30, 2023
1 parent 08ded68 commit 4ac3692
Showing 7 changed files with 272 additions and 58 deletions.
6 changes: 4 additions & 2 deletions oneflow/ir/lib/OneFlow/PDLL/AllocEliminationPatterns.pdll
@@ -1,12 +1,14 @@
#include "OneFlow/OneFlowOps.td"

Constraint IsFuncArguments(value: Value) [{
return success(llvm::dyn_cast<mlir::BlockArgument>(value) != nullptr);
return success(llvm::dyn_cast<mlir::BlockArgument>(value));
}];

Pattern {
arg: Value;
let alloc = op<memref.alloc>();
let copy = op<memref.copy>(alloc.0, arg: IsFuncArguments);
let copy = op<memref.copy>(alloc.0, arg);
IsFuncArguments(arg);

rewrite alloc with {
erase copy;
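The rewritten pattern applies IsFuncArguments as a standalone constraint call rather than inline on the operand, and the constraint itself now relies on the implicit bool conversion of the dyn_cast result. A minimal sketch of the IR this pattern targets, with a hypothetical function name and shapes:

func.func @f(%arg0: memref<2x10xf32>) {
  %alloc = memref.alloc() : memref<2x10xf32>
  // ... ops that write the result into %alloc ...
  memref.copy %alloc, %arg0 : memref<2x10xf32> to memref<2x10xf32>
  return
}

The visible part of the rewrite erases the memref.copy; with the staging copy gone, the result can be written into the output argument %arg0 directly.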
118 changes: 67 additions & 51 deletions oneflow/ir/lib/OneFlow/Passes.cpp
@@ -81,6 +81,7 @@ limitations under the License.
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -121,7 +122,6 @@ limitations under the License.
#include <string>

#ifdef WITH_MLIR_CUDA_CODEGEN
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -1016,79 +1016,95 @@ struct KernelLaunchWithCudaGraphPattern : public KernelLaunchSimplePattern {
}
};

-void AddLowerToLinalgMemRefPasses(PassManager& pm) {
-  pm.addPass(createConvertToSignlessForTosaPass());  // convert-to-signless-for-tosa
-  pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());  // llvm-request-c-wrappers
-  pm.addPass(createLowerOneFlowToTosaPass());  // lower-oneflow-to-tosa
-  pm.addNestedPass<func::FuncOp>(
-      tosa::createTosaMakeBroadcastablePass());  // tosa-make-broadcastable
-  pm.addPass(createCSEPass());  // cse
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());  // tosa-to-linalg-on-tensors
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());  // tosa-to-tensor
-  pm.addNestedPass<func::FuncOp>(
-      createLinalgElementwiseOpFusionPass());  // linalg-fuse-elementwise-ops
+void AddLoweringToLinalgMemRefPasses(PassManager& pm) {
+  pm.addPass(createConvertToSignlessForTosaPass());
+  pm.addNestedPass<func::FuncOp>(LLVM::createRequestCWrappersPass());
+  pm.addPass(createLowerOneFlowToTosaPass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
+  pm.addPass(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
+  pm.addNestedPass<func::FuncOp>(createLinalgElementwiseOpFusionPass());
  // TODO: more optimization pass
  // Note: OneShot bufferization with result extract realization.
-  pm.addPass(bufferization::createEmptyTensorEliminationPass());  // eliminate-empty-tensors
-  pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());  // empty-tensor-to-alloc-tensor
+  pm.addPass(bufferization::createEmptyTensorEliminationPass());
+  pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());

  auto oneshot_bufferize = bufferization::createOneShotBufferizePass();
  CHECK(
      oneshot_bufferize
          ->initializeOptions("create-deallocs=0 bufferize-function-boundaries allow-return-allocs")
          .succeeded());
  pm.addPass(std::move(oneshot_bufferize));
-  pm.addPass(bufferization::createBufferResultsToOutParamsPass());  // buffer-results-to-out-params
-  pm.addPass(mlir::oneflow::createEliminateAllocOpsPass());  // eliminate-alloc-ops
-  pm.addPass(createCanonicalizerPass());  // canonicalize
+  pm.addPass(bufferization::createBufferResultsToOutParamsPass());
+  pm.addPass(mlir::oneflow::createEliminateAllocOpsPass());
+  pm.addPass(createCanonicalizerPass());
}

LogicalResult LowerModuleToLLVM(mlir::MLIRContext* context, ModuleOp module) {
  mlir::PassManager pm(context);
  mlir::oneflow::CheckEnableIRPrinting(pm);
-  AddLowerToLinalgMemRefPasses(pm);
-  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());  // convert-linalg-to-loops
-  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());  // convert-scf-to-cf
-  pm.addPass(createConvertLinalgToLLVMPass());  // convert-linalg-to-llvm
-  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());  // fold-alloc-to-subview
-  pm.addPass(createInsertOneFlowMemPoolPass());  // insert-ofmempool
-  pm.addPass(createAppendOneFlowStreamPass());  // append-ofstream
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());  // convert-memref-to-llvm
-  pm.addPass(createConvertFuncToLLVMPass());  // convert-func-to-llvm
-  pm.addPass(memref::createExpandStridedMetadataPass());  // expand-strided-metadata
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());  // convert-memref-to-llvm
-  pm.addPass(createReconcileUnrealizedCastsPass());  // reconcile-unrealized-casts
+  AddLoweringToLinalgMemRefPasses(pm);
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
+  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());
+  pm.addPass(createInsertOneFlowMemPoolPass());
+  pm.addPass(createAppendOneFlowStreamPass());
+  pm.addPass(memref::createExpandOpsPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createConvertLinalgToLLVMPass());
+  pm.addPass(createConvertFuncToLLVMPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
  return pm.run(module);
}

#ifdef WITH_MLIR_CUDA_CODEGEN

+void AddLoweringLinalgOnBufferToGpuWithStdPasses(PassManager& pm) {
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToParallelLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createParallelLoopToGpuPass());
+  pm.addNestedPass<func::FuncOp>(createGpuLauchSinkIndexComputationsPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());
+  pm.addPass(createInsertOneFlowMemPoolPass());
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createGpuCopyArgPass());
+}
+
+void AddAdheringCubinToGpuModulePasses(PassManager& pm) {
+  pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToNVVMOpsPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createNVVMToCubinPass());
+}
+
+void AddLoweringGpuToLLVMPasses(PassManager& pm) {
+  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createAppendOneFlowStreamPass());
+  pm.addPass(createGpuToLLVMConversionPass());
+  pm.addPass(createMgpuToOneFlowStreamPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+}

LogicalResult LowerModuleToCUDALLVM(mlir::MLIRContext* context, ModuleOp module) {
  InitializeLLVMNVPTXBackend();
  mlir::PassManager pm(context);
  mlir::oneflow::CheckEnableIRPrinting(pm);
-  AddLowerToLinalgMemRefPasses(pm);
-  pm.addNestedPass<func::FuncOp>(
-      createConvertLinalgToParallelLoopsPass());  // convert-linalg-to-parallel-loops
-  pm.addNestedPass<func::FuncOp>(createGpuMapParallelLoopsPass());  // gpu-map-parallel-loops
-  pm.addPass(createParallelLoopToGpuPass());  // convert-parallel-loops-to-gpu
-  pm.addPass(createGpuLauchSinkIndexComputationsPass());
-  pm.addPass(createGpuKernelOutliningPass());  // gpu-kernel-outlining
-  pm.addPass(createCanonicalizerPass());  // canonicalize
-  pm.addNestedPass<func::FuncOp>(createFoldAllocToSubviewPass());  // fold-alloc-to-subview
-  pm.addPass(createInsertOneFlowMemPoolPass());  // insert-ofmempool
-  // -pass-pipeline='gpu.module([PASS1][PASS2]...)'
-  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());  // strip-debuginfo
-  pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());  // lower-affine
-  pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToNVVMOpsPass());  // convert-gpu-to-nvvm
-  pm.addNestedPass<gpu::GPUModuleOp>(createNVVMToCubinPass());  // out-of-tree-gpu-to-cubin
-  pm.addNestedPass<func::FuncOp>(createGpuCopyArgPass());  // buffer-host-register
-  pm.addPass(createAppendOneFlowStreamPass());  // append-ofstream
-  pm.addPass(createGpuToLLVMConversionPass());  // gpu-to-llvm
-  pm.addPass(createMgpuToOneFlowStreamPass());  // mgpu-to-oneflow-stream
-  pm.addPass(memref::createExpandStridedMetadataPass());  // expand-strided-metadata
-  pm.addPass(createFinalizeMemRefToLLVMConversionPass());  // convert-memref-to-llvm
-  pm.addPass(createReconcileUnrealizedCastsPass());  // reconcile-unrealized-casts
+  AddLoweringToLinalgMemRefPasses(pm);
+  AddLoweringLinalgOnBufferToGpuWithStdPasses(pm);
+  pm.addPass(memref::createExpandOpsPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  pm.addPass(createGpuKernelOutliningPass());
+  AddAdheringCubinToGpuModulePasses(pm);
+  AddLoweringGpuToLLVMPasses(pm);
  return pm.run(module);
}

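The CUDA pipeline is now assembled from three named helpers: AddLoweringLinalgOnBufferToGpuWithStdPasses maps bufferized Linalg ops onto gpu.launch regions, AddAdheringCubinToGpuModulePasses lowers each outlined gpu.module to NVVM and serializes it to cubin, and AddLoweringGpuToLLVMPasses converts the host side to LLVM. The gpu.module stage corresponds to the nested textual pipeline exercised by the new lit tests below; roughly, as a sketch (the input file name is hypothetical):

oneflow-opt lowered.mlir --pass-pipeline='builtin.module(gpu.module(lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'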
77 changes: 77 additions & 0 deletions oneflow/ir/test/OneFlow/cuda_code_gen/test_matmul.py
@@ -0,0 +1,77 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s
# CHECK: jit

import unittest
import numpy as np

import os


import oneflow as flow
import oneflow.unittest


class MatMulModule(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = flow.nn.Parameter(flow.Tensor(5, 10))
        self.b = flow.nn.Parameter(flow.Tensor(10))

    def forward(self, x):
        return flow.matmul(x, self.w) + self.b


def do_matmul_graph(test_case, with_cuda=False):
    x = flow.randn(2, 5)
    module_to_run = MatMulModule()
    if with_cuda:
        x = x.cuda()
        module_to_run = module_to_run.to("cuda")
    y_eager = module_to_run(x)

    class GraphToRun(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.fw = module_to_run

        def build(self, x):
            return self.fw(x)

    graph_to_run = GraphToRun()
    y_lazy = graph_to_run(x)
    test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))


@flow.unittest.skip_unless_1n1d()
class TestFuseCastScale(oneflow.unittest.MLIRTestCase):
    def setUp(self):
        os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
        os.environ["ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS"] = "1"
        os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"

    def test_relu_graph(test_case):
        import oneflow.sysconfig

        if oneflow.sysconfig.with_cuda():
            do_matmul_graph(test_case, True)

        do_matmul_graph(test_case)


if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions oneflow/ir/test/Transform/lit.local.cfg
@@ -0,0 +1,2 @@
if not config.WITH_MLIR_CUDA_CODEGEN:
    config.unsupported = True
36 changes: 36 additions & 0 deletions oneflow/ir/test/Transform/matmul.mlir
@@ -0,0 +1,36 @@
// RUN: oneflow-opt %s --insert-ofmempool --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand --gpu-kernel-outlining \
// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'

module {
  func.func @JITOpGenerated0(%arg0: memref<5x10xf32, strided<[?, ?], offset: ?>>, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>>, %arg2: memref<2x10xf32>) attributes {llvm.emit_c_interface} {
    %alloc = memref.alloc() : memref<512xi8>
    %c0 = arith.constant 0 : index
    %view = memref.view %alloc[%c0][] : memref<512xi8> to memref<1x2x10xf32>
    %c10 = arith.constant 10 : index
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    %c0_0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %cst = arith.constant 0.000000e+00 : f32
    %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] : memref<5x10xf32, strided<[?, ?], offset: ?>> into memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>
    %expand_shape_1 = memref.expand_shape %arg1 [[0, 1], [2]] : memref<2x5xf32, strided<[?, ?], offset: ?>> into memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>
    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
      memref.store %cst, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>
      gpu.terminator
    } {SCFToGPU_visited}
    gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c2, %arg11 = %c10) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
      scf.for %arg15 = %c0_0 to %c5 step %c1 {
        %0 = memref.load %expand_shape_1[%c0_0, %arg4, %arg15] : memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>
        %1 = memref.load %expand_shape[%c0_0, %arg15, %arg5] : memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>
        %2 = memref.load %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>
        %3 = arith.mulf %0, %1 : f32
        %4 = arith.addf %2, %3 : f32
        memref.store %4, %view[%c0_0, %arg4, %arg5] : memref<1x2x10xf32>
      }
      gpu.terminator
    } {SCFToGPU_visited}
    %collapse_shape = memref.collapse_shape %view [[0, 1], [2]] : memref<1x2x10xf32> into memref<2x10xf32>
    memref.copy %collapse_shape, %arg2 : memref<2x10xf32> to memref<2x10xf32>
    return
  }
}
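The two gpu.launch regions above are the loop form of a batched matmul: the first zero-fills the 1x2x10 accumulator view, the second multiply-accumulates along the reduction dimension of size 5 (an scf.for inside the kernel, since that dimension is not mapped to threads). Before lowering to loops, the same computation can be written at the Linalg level roughly as follows; this is a sketch reusing the buffer names above, not IR taken from the commit:

linalg.fill ins(%cst : f32) outs(%view : memref<1x2x10xf32>)
linalg.batch_matmul
    ins(%expand_shape_1, %expand_shape : memref<1x2x5xf32, strided<[?, ?, ?], offset: ?>>,
                                         memref<1x5x10xf32, strided<[?, ?, ?], offset: ?>>)
    outs(%view : memref<1x2x10xf32>)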
9 changes: 4 additions & 5 deletions oneflow/ir/test/Transform/softmax.mlir
@@ -1,13 +1,12 @@
-// RUN: oneflow-opt %s --pass-pipeline="builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax_codegen_spec.mlir})" \
-// RUN: | oneflow-opt --convert-vector-to-gpu=use-nvgpu=1 --convert-vector-to-scf --convert-scf-to-cf --insert-ofmempool --gpu-kernel-outlining \
-// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm))'
+// RUN: oneflow-opt %s --pass-pipeline="builtin.module(oneflow-transform-dialect-interpreter{transform-file-name=%p/softmax_codegen_spec_no_vectorize.mlir})" \
+// RUN: | oneflow-opt --insert-ofmempool --convert-linalg-to-loops --convert-scf-to-cf --canonicalize --cse --memref-expand --gpu-kernel-outlining \
+// RUN: | oneflow-opt --pass-pipeline='builtin.module(gpu.module(expand-strided-metadata,lower-affine,strip-debuginfo,convert-gpu-to-nvvm,nvvm-to-cubin))'


!tmp_tensor_t = tensor<16x128xf32>
!in_tensor_t = tensor<16x128x128xf32>
!out_tensor_t = tensor<16x128x128xf32>


func.func @softmax() -> !out_tensor_t {
  %cst_0 = arith.constant 0.0 : f32
  %cst_1 = arith.constant 1.0 : f32
@@ -64,4 +63,4 @@ func.func @softmax() -> !out_tensor_t {
  } -> !out_tensor_t

  return %res: !out_tensor_t
-}
+}