Skip to content

Commit

Permalink
wip: adding --emit-spirv-mlir path
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugobros3 committed Oct 9, 2024
1 parent 3086002 commit fd0c839
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 1 deletion.
32 changes: 32 additions & 0 deletions clang/include/clang/CIR/LowerToSPIRV.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//====- LowerToSPIRV.h- Lowering from CIR to LLVM -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares an interface for converting MLIR modules to SPIR-V
//
//===----------------------------------------------------------------------===//
#ifndef CLANG_CIR_LOWERTOSPIRV_H
#define CLANG_CIR_LOWERTOSPIRV_H

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include <memory>

namespace llvm {
class LLVMContext;
class Module;
} // namespace llvm

namespace mlir {
class MLIRContext;
class ModuleOp;

spirv::ModuleOp lowerFromMLIRToSPIRV(mlir::ModuleOp theModule,
mlir::MLIRContext *mlirCtx);
} // namespace mlir

#endif // CLANG_CIR_LOWERTOSPIRV_H
4 changes: 4 additions & 0 deletions clang/include/clang/CIR/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ namespace cir {
/// `Std`, to the LLVM dialect for codegen.
std::unique_ptr<mlir::Pass> createConvertMLIRToLLVMPass();

/// Create a pass for lowering from MLIR builtin dialects to the SPIRV dialect for codegen.
std::unique_ptr<mlir::Pass> createConvertMLIRToSPIRVPass();

/// Create a pass that fully lowers CIR to the MLIR in-tree dialects.
std::unique_ptr<mlir::Pass> createConvertCIRToMLIRPass();


namespace direct {
/// Create a pass that fully lowers CIR to the LLVMIR dialect.
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass();
Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/CIRFrontendAction/CIRGenAction.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
EmitAssembly,
EmitCIR,
EmitCIRFlat,
EmitSPIRV,
EmitLLVM,
EmitBC,
EmitMLIR,
Expand Down Expand Up @@ -101,6 +102,13 @@ class EmitMLIRAction : public CIRGenAction {
EmitMLIRAction(mlir::MLIRContext *mlirCtx = nullptr);
};

class EmitSPIRVAction : public CIRGenAction {
virtual void anchor();

public:
EmitSPIRVAction(mlir::MLIRContext *mlirCtx = nullptr);
};

class EmitLLVMAction : public CIRGenAction {
virtual void anchor();

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -3055,6 +3055,8 @@ def emit_cir_flat : Flag<["-"], "emit-cir-flat">, Visibility<[ClangOption, CC1Op
Group<Action_Group>, HelpText<"Similar to -emit-cir but also lowers structured CFG into basic blocks.">;
def emit_mlir : Flag<["-"], "emit-mlir">, Visibility<[CC1Option]>, Group<Action_Group>,
HelpText<"Build ASTs and then lower through ClangIR to MLIR, emit the .milr file">;
def emit_spirv_mlir : Flag<["-"], "emit-spirv-mlir">, Visibility<[CC1Option]>, Group<Action_Group>,
HelpText<"Build ASTs and then lower through ClangIR to MLIR to SPIR-V, emit the .spv file">;
/// ClangIR-specific options - END

def flto : Flag<["-"], "flto">,
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Frontend/FrontendOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ enum ActionKind {
/// Emit a .mlir file
EmitMLIR,

/// Emit a .spv file
EmitSPIRV,

/// Emit a .ll file.
EmitLLVM,

Expand Down
20 changes: 20 additions & 0 deletions clang/lib/CIR/FrontendAction/CIRGenAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/DeclCXX.h"
Expand All @@ -26,6 +27,7 @@
#include "clang/CIR/CIRToCIRPasses.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/LowerToSPIRV.h"
#include "clang/CIR/Passes.h"
#include "clang/CodeGen/BackendUtil.h"
#include "clang/CodeGen/ModuleBuilder.h"
Expand Down Expand Up @@ -278,6 +280,20 @@ class CIRGenConsumer : public clang::ASTConsumer {
loweredMlirModule->print(*outputStream, flags);
break;
}
case CIRGenAction::OutputType::EmitSPIRV: {
auto loweredMlirModule = lowerFromCIRToMLIR(mlirMod, mlirCtx.get());
auto spirvModule = ::mlir::lowerFromMLIRToSPIRV(loweredMlirModule, mlirCtx.get());
assert(outputStream && "Why are we here without an output stream?");
// FIXME: we cannot roundtrip prettyForm=true right now.
mlir::OpPrintingFlags flags;
flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/false);
// loweredMlirModule->print(*outputStream, flags);
SmallVector<uint32_t, 0> words;
::mlir::spirv::SerializationOptions options;
(void) ::mlir::spirv::serialize(spirvModule, words, options);
outputStream->write(reinterpret_cast<char*>(words.data()), words.size_in_bytes());
break;
}
case CIRGenAction::OutputType::EmitLLVM:
case CIRGenAction::OutputType::EmitBC:
case CIRGenAction::OutputType::EmitObj:
Expand Down Expand Up @@ -464,6 +480,10 @@ void EmitMLIRAction::anchor() {}
EmitMLIRAction::EmitMLIRAction(mlir::MLIRContext *_MLIRContext)
: CIRGenAction(OutputType::EmitMLIR, _MLIRContext) {}

void EmitSPIRVAction::anchor() {}
EmitSPIRVAction::EmitSPIRVAction(mlir::MLIRContext *_MLIRContext)
: CIRGenAction(OutputType::EmitSPIRV, _MLIRContext) {}

void EmitLLVMAction::anchor() {}
EmitLLVMAction::EmitLLVMAction(mlir::MLIRContext *_MLIRContext)
: CIRGenAction(OutputType::EmitLLVM, _MLIRContext) {}
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_clang_library(clangCIRLoweringThroughMLIR
LowerCIRLoopToSCF.cpp
LowerCIRToMLIR.cpp
LowerMLIRToLLVM.cpp
LowerMLIRToSPIRV.cpp

DEPENDS
MLIRCIROpsIncGen
Expand Down
25 changes: 25 additions & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerMLIRToSPIRV.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "clang/CIR/Passes.h"
#include "llvm/Support/TimeProfiler.h"

#include "mlir/Dialect/SPIRV/Transforms/Passes.h"

namespace mlir {

mlir::ModuleOp lowerFromMLIRToSPIRV(mlir::ModuleOp theModule,
mlir::MLIRContext *mlirCtx) {
llvm::TimeTraceScope scope("Lower from MLIR to SPIR-V");

mlir::PassManager pm(mlirCtx);

// TODO
//pm.addPass(cir::createConvertMLIRToSPIRVPass());

//auto result = !mlir::failed(pm.run(theModule));
//if (!result)
// report_fatal_error("The pass manager failed to lower CIR to SPIRV dialect!");
}

} // namespace mlir
1 change: 1 addition & 0 deletions clang/lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,7 @@ static const auto &getFrontendActionTable() {
{frontend::EmitCIRFlat, OPT_emit_cir_flat},
{frontend::EmitCIROnly, OPT_emit_cir_only},
{frontend::EmitMLIR, OPT_emit_mlir},
{frontend::EmitSPIRV, OPT_emit_spirv_mlir},
{frontend::EmitHTML, OPT_emit_html},
{frontend::EmitLLVM, OPT_emit_llvm},
{frontend::EmitLLVMOnly, OPT_emit_llvm_only},
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/FrontendTool/ExecuteCompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ CreateFrontendBaseAction(CompilerInstance &CI) {
auto UseCIR = CI.getFrontendOpts().UseClangIRPipeline;
auto Act = CI.getFrontendOpts().ProgramAction;
auto CIRAnalysisOnly = CI.getFrontendOpts().ClangIRAnalysisOnly;
auto EmitsCIR = Act == EmitCIR || Act == EmitCIRFlat || Act == EmitCIROnly;
auto EmitsCIR = Act == EmitCIR || Act == EmitCIRFlat || Act == EmitCIROnly || Act == EmitSPIRV;

if (!UseCIR && EmitsCIR)
llvm::report_fatal_error(
Expand Down Expand Up @@ -96,6 +96,7 @@ CreateFrontendBaseAction(CompilerInstance &CI) {
return std::make_unique<::cir::EmitCIRFlatAction>();
case EmitCIROnly: return std::make_unique<::cir::EmitCIROnlyAction>();
case EmitMLIR: return std::make_unique<::cir::EmitMLIRAction>();
case EmitSPIRV: return std::make_unique<::cir::EmitSPIRVAction>();
#else
case EmitCIR:
case EmitCIRFlat:
Expand Down

0 comments on commit fd0c839

Please sign in to comment.