Skip to content

Commit

Permalink
Add shardy build (#1932)
Browse files Browse the repository at this point in the history
Add Shardy build in env location
  • Loading branch information
wooseokTT authored Jan 23, 2025
1 parent 7214507 commit 3f0b569
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ include(TTMLIRPythonSitePackages)
add_subdirectory(third_party)
if (TTMLIR_ENABLE_STABLEHLO)
set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${TTMLIR_TOOLCHAIN_DIR}/src/shardy ${CMAKE_CURRENT_BINARY_DIR}/shardy EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/shardy)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/shardy)
add_subdirectory(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo ${CMAKE_CURRENT_BINARY_DIR}/stablehlo EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/stablehlo)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
Expand Down
11 changes: 11 additions & 0 deletions env/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ project(ttmlir-toolchain LANGUAGES CXX C)
set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")
set(SHARDY_VERSION "55f44c23b766be38bccb0b2394b0e8dfba45694e")

include(ExternalProject)

Expand Down Expand Up @@ -78,5 +79,15 @@ ExternalProject_Add(stablehlo
INSTALL_COMMAND ""
)

ExternalProject_Add(shardy
PREFIX ${TTMLIR_TOOLCHAIN_DIR}
GIT_REPOSITORY https://github.com/openxla/shardy.git
GIT_TAG ${SHARDY_VERSION}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
PATCH_COMMAND git apply "${CMAKE_CURRENT_LIST_DIR}/patches/shardy.patch"
)

add_custom_target(llvm-lit ALL COMMAND cp llvm-project-prefix/src/llvm-project-build/bin/llvm-lit ${TTMLIR_TOOLCHAIN_DIR}/bin/llvm-lit DEPENDS llvm-project)
add_custom_target(run-clang-tidy-install ALL COMMAND cp llvm-project-prefix/src/llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py ${TTMLIR_TOOLCHAIN_DIR}/bin/run-clang-tidy.py DEPENDS llvm-project)
152 changes: 152 additions & 0 deletions env/patches/shardy.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..5e3e97c
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,52 @@
+# Custom embedded build for Shardy, targgeting minimal build as a part of
+# tt-mlir MLIR project. This CMakeLists.txt file is mainly from StableHLO.
+
+cmake_minimum_required(VERSION 3.15.0)
+
+# CMP0116: Ninja generators transform `DEPFILE`s from `add_custom_command()`
+# New in CMake 3.20. https://cmake.org/cmake/help/latest/policy/CMP0116.html
+if(POLICY CMP0116)
+ cmake_policy(SET CMP0116 OLD)
+endif()
+
+option(SHARDY_EMBEDDED_BUILD "Build Shardy as part of another project" ON)
+option(SHARDY_ENABLE_LLD "Use LLD as the linker if available" OFF)
+
+message(STATUS "Building Shardy embedded in another project")
+project(shardy LANGUAGES CXX C)
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_CXX_STANDARD 17)
+
+find_package(MLIR REQUIRED CONFIG)
+
+set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
+set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
+list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
+list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
+include(HandleLLVMOptions)
+
+include(TableGen)
+include(AddLLVM)
+include(AddMLIR)
+
+include(CheckCXXCompilerFlag)
+include(CheckLinkerFlag)
+
+if (SHARDY_ENABLE_LLD)
+ message(STATUS "Enabling LLD as the linker")
+ add_link_options("-fuse-ld=lld")
+endif()
+
+include_directories(${LLVM_INCLUDE_DIRS})
+include_directories(${MLIR_INCLUDE_DIRS})
+include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+link_directories(${LLVM_BUILD_LIBRARY_DIR})
+add_definitions(${LLVM_DEFINITIONS})
+
+set(SHARDY_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(SHARDY_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
+
+add_compile_options(-Wno-deprecated-declarations -Wno-unused-but-set-variable -Wno-sign-compare)
+
+add_subdirectory(shardy/dialect/sdy/ir)
diff --git a/shardy/dialect/sdy/ir/CMakeLists.txt b/shardy/dialect/sdy/ir/CMakeLists.txt
new file mode 100644
index 0000000..07d6a97
--- /dev/null
+++ b/shardy/dialect/sdy/ir/CMakeLists.txt
@@ -0,0 +1,88 @@
+# Shardy MLIR dialect.
+
+set(LLVM_TARGET_DEFINITIONS dialect.td)
+mlir_tablegen(dialect.h.inc -gen-dialect-decls -dialect=sdy)
+mlir_tablegen(dialect.cc.inc -gen-dialect-defs -dialect=sdy)
+add_public_tablegen_target(SdyDialectIncGen)
+add_dependencies(mlir-headers SdyDialectIncGen)
+add_mlir_doc(dialect SdyDialect src/autogen/md/Dialect/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS ops.td)
+mlir_tablegen(ops.h.inc -gen-op-decls)
+mlir_tablegen(ops.cc.inc -gen-op-defs)
+add_public_tablegen_target(SdyOpsIncGen)
+add_dependencies(mlir-headers SdyOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS attrs.td)
+mlir_tablegen(attrs.h.inc -gen-attrdef-decls)
+mlir_tablegen(attrs.cc.inc -gen-attrdef-defs)
+add_public_tablegen_target(SdyAttrsIncGen)
+add_dependencies(mlir-headers SdyAttrsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS enums.td)
+mlir_tablegen(enums.h.inc -gen-enum-decls)
+mlir_tablegen(enums.cc.inc -gen-enum-defs)
+add_public_tablegen_target(SdyEnumsIncGen)
+add_dependencies(mlir-headers SdyEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS op_interface.td)
+mlir_tablegen(op_interface.h.inc -gen-op-interface-decls)
+mlir_tablegen(op_interface.cc.inc -gen-op-interface-defs)
+add_public_tablegen_target(SdyOpInterfaceIncGen)
+add_dependencies(mlir-headers SdyOpInterfaceIncGen)
+
+set(LLVM_TARGET_DEFINITIONS canonicalization.td)
+mlir_tablegen(canonicalization.cc.inc -gen-rewriters)
+add_public_tablegen_target(SdyCanonicalizationIncGen)
+add_dependencies(mlir-headers SdyCanonicalizationIncGen)
+
+add_mlir_library(SdyDialect
+ canonicalization.cc
+ data_flow_utils.cc
+ dialect.cc
+ parsers.cc
+ printers.cc
+ utils.cc
+ verifiers.cc
+
+ DEPENDS
+ SdyDialectIncGen
+ SdyOpsIncGen
+ SdyAttrsIncGen
+ SdyEnumsIncGen
+ SdyOpInterfaceIncGen
+ SdyCanonicalizationIncGen
+
+ LINK_LIBS PUBLIC
+ LLVMSupport
+ MLIRBytecodeOpInterface
+ MLIRFuncDialect
+ MLIRIR
+ MLIRInferTypeOpInterface
+ MLIRShapeDialect
+ MLIRSideEffectInterfaces
+ MLIRSupport
+ StablehloAssemblyFormat
+ StablehloOps
+ StablehloTypeInference
+)
+
+target_include_directories(SdyDialect INTERFACE
+ $<BUILD_INTERFACE:${SHARDY_SOURCE_DIR}>
+ $<BUILD_INTERFACE:${SHARDY_BINARY_DIR}>
+)
+
+add_mlir_dialect_library(SdyRegister
+ register.cc
+
+ LINK_LIBS PUBLIC
+ SdyDialect
+ MLIRFuncDialect
+ MLIRIR
+ StablehloOps
+)
+
+target_include_directories(SdyRegister INTERFACE
+ $<BUILD_INTERFACE:${SHARDY_SOURCE_DIR}>
+ $<BUILD_INTERFACE:${SHARDY_BINARY_DIR}>
+)
4 changes: 2 additions & 2 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ include "mlir/Pass/PassBase.td"
def ConvertStableHLOToTTIR : Pass<"convert-stablehlo-to-ttir", "::mlir::ModuleOp"> {
let summary = "Convert StableHLO dialect to TTIR dialect.";
let constructor = "createConvertStableHLOToTTIRPass()";
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::tt::ttir::TTIRDialect"];
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::sdy::SdyDialect", "mlir::tt::ttir::TTIRDialect"];
}
def ConvertArithToStableHLO : Pass<"convert-arith-to-stablehlo", "::mlir::ModuleOp"> {
let summary = "Convert Arith Dialect to StableHLO dialect.";
let constructor = "createConvertArithToStableHLOPass()";
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::arith::ArithDialect"];
let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::sdy::SdyDialect", "mlir::arith::ArithDialect"];
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ MLIRTTMetalPipelines
)

if (TTMLIR_ENABLE_STABLEHLO)
list(APPEND link_libs StablehloRegister)
list(APPEND link_libs StablehloRegister SdyRegister)
endif()

add_mlir_library(TTMLIRStatic STATIC RegisterAll.cpp
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>

#include <shardy/dialect/sdy/ir/dialect.h>
#include <stablehlo/dialect/StablehloOps.h>

#include "ttmlir/Dialect/TT/IR/TT.h"
Expand Down Expand Up @@ -55,6 +56,7 @@ struct ConvertArithToStableHLOPass

target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalDialect<mlir::stablehlo::StablehloDialect>();
target.addLegalDialect<mlir::sdy::SdyDialect>();
target.addLegalOp<mlir::tensor::EmptyOp>();
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<mlir::func::FuncOp>();
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/shardy)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/shardy-build)
include_directories(${TTMLIR_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/include)

Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>

#include <shardy/dialect/sdy/ir/dialect.h>
#include <stablehlo/dialect/StablehloOps.h>

#include "ttmlir/Dialect/TT/IR/TT.h"
Expand Down Expand Up @@ -93,6 +94,7 @@ struct ConvertStableHLOToTTIRPass

target.addIllegalDialect<mlir::stablehlo::StablehloDialect>();

target.addLegalDialect<mlir::sdy::SdyDialect>();
target.addLegalDialect<ttir::TTIRDialect>();
target.addLegalOp<mlir::tensor::EmptyOp>();
target.addLegalOp<mlir::ModuleOp>();
Expand Down
4 changes: 3 additions & 1 deletion lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#ifdef TTMLIR_ENABLE_STABLEHLO
#if TTMLIR_ENABLE_STABLEHLO
#include "shardy/dialect/sdy/ir/register.h"
#include "stablehlo/dialect/Register.h"
#endif

Expand All @@ -43,6 +44,7 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect>();
#if TTMLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerAllDialects(registry);
mlir::sdy::registerAllDialects(registry);
#endif
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
Expand Down

0 comments on commit 3f0b569

Please sign in to comment.